340 lines
11 KiB
Python
340 lines
11 KiB
Python
# lama_trt_module.py
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
FastDeploy + TensorRT(FP16) 기반 LaMa 인페인팅 모듈
|
||
- simple-lama-inpainting 전처리/후처리 로직을 반영
|
||
- 입력: PIL.Image 또는 np.ndarray (RGB), 마스크(PIL/np) - 255가 인페인트 영역
|
||
- 출력: PIL.Image
|
||
|
||
필요 패키지:
|
||
pip install fastdeploy-python pillow opencv-python numpy
|
||
(onnx, onnxsim 등은 모델 변환 시 필요할 수 있지만, 여기서는 이미 onnx를 가정)
|
||
|
||
환경 변수:
|
||
CUDA GPU가 있어야 TensorRT 백엔드를 사용 가능.
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import cv2
|
||
import time
|
||
import json
|
||
import logging
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from typing import Tuple, List, Literal, Optional
|
||
|
||
from PIL import Image
|
||
|
||
# ==========================================
|
||
# 로거 설정
|
||
# ==========================================
|
||
def get_logger(name: str = "lama_trt"):
|
||
logger = logging.getLogger(name)
|
||
logger.setLevel(logging.DEBUG)
|
||
if not logger.handlers:
|
||
sh = logging.StreamHandler(sys.stdout)
|
||
sh.setLevel(logging.DEBUG)
|
||
fmt = logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s")
|
||
sh.setFormatter(fmt)
|
||
logger.addHandler(sh)
|
||
return logger
|
||
|
||
logger = get_logger()
|
||
|
||
|
||
# ==========================================
|
||
# 유틸 함수들 (simple-lama-inpainting util.py 참고)
|
||
# ==========================================
|
||
def np_from_image(img: Image.Image | np.ndarray) -> np.ndarray:
|
||
"""PIL 또는 ndarray 입력을 (C,H,W) float32 [0~1]로 변환."""
|
||
if isinstance(img, Image.Image):
|
||
arr = np.array(img)
|
||
elif isinstance(img, np.ndarray):
|
||
arr = img.copy()
|
||
else:
|
||
raise TypeError("image/mask는 PIL.Image 또는 np.ndarray 여야 합니다.")
|
||
|
||
if arr.ndim == 3: # HWC
|
||
arr = np.transpose(arr, (2, 0, 1)) # CHW
|
||
elif arr.ndim == 2: # HW
|
||
arr = arr[np.newaxis, ...] # 1HW
|
||
else:
|
||
raise ValueError("입력 배열 차원이 맞지 않습니다.")
|
||
|
||
arr = arr.astype(np.float32) / 255.0
|
||
return arr
|
||
|
||
|
||
def ceil_modulo(x: int, mod: int) -> int:
|
||
return x if x % mod == 0 else (x // mod + 1) * mod
|
||
|
||
|
||
def pad_to_modulo(arr: np.ndarray, mod: int) -> Tuple[np.ndarray, Tuple[int, int]]:
|
||
"""
|
||
arr: (C,H,W)
|
||
mod: 패딩 배수
|
||
return: (pad_arr, (orig_h, orig_w))
|
||
"""
|
||
c, h, w = arr.shape
|
||
out_h = ceil_modulo(h, mod)
|
||
out_w = ceil_modulo(w, mod)
|
||
pad_h = out_h - h
|
||
pad_w = out_w - w
|
||
if pad_h == 0 and pad_w == 0:
|
||
return arr, (h, w)
|
||
pad_arr = np.pad(arr, ((0, 0), (0, pad_h), (0, pad_w)), mode="symmetric")
|
||
return pad_arr, (h, w)
|
||
|
||
|
||
def depad(arr: np.ndarray, orig_hw: Tuple[int, int]) -> np.ndarray:
|
||
"""(C,H,W) -> 잘라서 원래 크기로."""
|
||
c, h, w = arr.shape
|
||
oh, ow = orig_hw
|
||
return arr[:, :oh, :ow]
|
||
|
||
|
||
def resize_hw(arr: np.ndarray, h: int, w: int, is_mask: bool = False) -> np.ndarray:
|
||
"""(C,H,W) -> (C,h,w) 리사이즈"""
|
||
if arr.shape[0] == 1: # mask 등
|
||
arr_hw = arr[0]
|
||
inter = cv2.INTER_NEAREST if is_mask else cv2.INTER_AREA
|
||
arr_hw = cv2.resize(arr_hw, (w, h), interpolation=inter)
|
||
return arr_hw[np.newaxis, ...]
|
||
else:
|
||
arr_hw = np.transpose(arr, (1, 2, 0)) # HWC
|
||
inter = cv2.INTER_NEAREST if is_mask else cv2.INTER_AREA
|
||
arr_hw = cv2.resize(arr_hw, (w, h), interpolation=inter)
|
||
return np.transpose(arr_hw, (2, 1, 0)) # (W,H,C)->(C,H,W) (주의)
|
||
|
||
|
||
def pil_from_chw(arr: np.ndarray) -> Image.Image:
|
||
"""(C,H,W) float[0~1] -> PIL RGB"""
|
||
arr = np.clip(arr, 0, 1)
|
||
arr = (arr * 255).astype(np.uint8)
|
||
if arr.shape[0] == 1:
|
||
arr = arr[0]
|
||
return Image.fromarray(arr, mode="L")
|
||
arr = np.transpose(arr, (1, 2, 0)) # HWC
|
||
return Image.fromarray(arr)
|
||
|
||
|
||
# ==========================================
|
||
# FastDeploy LaMa Inpainter
|
||
# ==========================================
|
||
class FastDeployLamaTRT:
|
||
"""
|
||
TensorRT 백엔드로 lama_fp32.onnx를 추론하는 클래스
|
||
- 고정 해상도 512×512용으로 최적화
|
||
- other mode: pad_to_modulo(8)도 제공 (원본과 비슷하게)
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
onnx_path: str,
|
||
backend: Literal["tensorrt", "ort", "openvino", "paddle"] = "tensorrt",
|
||
precision: Literal["fp16", "fp32", "int8"] = "fp16",
|
||
device_id: int = 0,
|
||
input_hw: Tuple[int, int] = (512, 512),
|
||
pad_modulo: Optional[int] = None,
|
||
engine_cache_dir: Optional[str] = None,
|
||
):
|
||
"""
|
||
onnx_path: lama_fp32.onnx 경로
|
||
backend: 'tensorrt' 권장
|
||
precision: fp16 권장 (GPU 없으면 FP32로 자동 폴백)
|
||
input_hw: 고정 입력 해상도 (512,512)
|
||
pad_modulo: None이면 해상도 고정. 값이 있으면 원본을 pad 후 resize 없이 추론(단, 모델이 고정형이라면 의미 없음)
|
||
|
||
engine_cache_dir: FastDeploy의 엔진 캐시 디렉토리(선택). 기본은 내부 캐시.
|
||
"""
|
||
import fastdeploy as fd
|
||
|
||
self.onnx_path = onnx_path
|
||
self.backend = backend
|
||
self.precision = precision
|
||
self.device_id = device_id
|
||
self.input_hw = input_hw
|
||
self.pad_modulo = pad_modulo
|
||
self.fd = fd
|
||
|
||
# RuntimeOption 설정
|
||
opt = fd.RuntimeOption()
|
||
opt.set_gpu_id(device_id)
|
||
|
||
if backend == "tensorrt":
|
||
opt.use_trt_backend()
|
||
if precision == "fp16":
|
||
opt.enable_trt_fp16()
|
||
elif precision == "int8":
|
||
# INT8Calibration 기능 필요. 여기선 예시만.
|
||
opt.enable_trt_int8("calibration.table")
|
||
elif backend == "ort":
|
||
opt.use_ort_backend()
|
||
elif backend == "openvino":
|
||
opt.use_openvino_backend()
|
||
else:
|
||
opt.use_paddle_backend()
|
||
|
||
if engine_cache_dir:
|
||
# FastDeploy 최신 버전은 engine 파일 경로 지정 기능 제공(버전에 따라 다름)
|
||
# opt.set_trt_cache_dir(engine_cache_dir)
|
||
pass
|
||
|
||
# 범용 Model 로딩
|
||
self.model = fd.vision.utils.Model(opt, onnx_path, "")
|
||
logger.debug(f"[Init] ONNX 로드 완료: {onnx_path} | backend={backend}, precision={precision}")
|
||
|
||
# -----------------------------
|
||
# 전처리/후처리
|
||
# -----------------------------
|
||
def _preprocess(
|
||
self,
|
||
image: Image.Image | np.ndarray,
|
||
mask: Image.Image | np.ndarray
|
||
):
|
||
"""
|
||
return: dict(inputs), org_info(dict)
|
||
"""
|
||
# 1) PIL/ndarray -> (C,H,W) float32
|
||
img_chw = np_from_image(image)
|
||
mask_chw = np_from_image(mask)
|
||
|
||
# 마스크 이진화
|
||
mask_chw = (mask_chw > 0.0).astype(np.float32)
|
||
|
||
# 2) 패딩 or 리사이즈
|
||
org_h, org_w = img_chw.shape[1], img_chw.shape[2]
|
||
|
||
if self.pad_modulo is not None:
|
||
# 원본 크기에 맞게 pad 후, (모델이 고정형이면 pad->resize 필요)
|
||
img_chw, orig_hw_after_pad = pad_to_modulo(img_chw, self.pad_modulo)
|
||
mask_chw, _ = pad_to_modulo(mask_chw, self.pad_modulo)
|
||
else:
|
||
orig_hw_after_pad = (org_h, org_w)
|
||
|
||
# 고정 입력 크기(512×512)로 리사이즈
|
||
in_h, in_w = self.input_hw
|
||
img_resized = resize_hw(img_chw, in_h, in_w, is_mask=False)
|
||
mask_resized = resize_hw(mask_chw, in_h, in_w, is_mask=True)
|
||
|
||
# 3) NCHW 배치 축
|
||
img_input = img_resized[np.newaxis, ...] # (1,C,H,W)
|
||
mask_input = mask_resized[np.newaxis, ...]
|
||
|
||
inputs = {"image": img_input, "mask": mask_input}
|
||
org_info = {
|
||
"orig_hw": (org_h, org_w),
|
||
"padded_hw": orig_hw_after_pad,
|
||
}
|
||
return inputs, org_info
|
||
|
||
def _postprocess(self, output: np.ndarray, org_info: dict) -> Image.Image:
|
||
"""
|
||
output: (1,3,H,W) float32
|
||
org_info: preprocess에서 저장한 정보
|
||
"""
|
||
out = output[0] # (3,H,W)
|
||
# 역정규화
|
||
pil_resized = pil_from_chw(out)
|
||
|
||
# 원본 크기로 복원(512->orig)
|
||
orig_h, orig_w = org_info["orig_hw"]
|
||
pil_out = pil_resized.resize((orig_w, orig_h), Image.BICUBIC)
|
||
|
||
return pil_out
|
||
|
||
# -----------------------------
|
||
# 추론
|
||
# -----------------------------
|
||
def inpaint(
|
||
self,
|
||
image: Image.Image | np.ndarray,
|
||
mask: Image.Image | np.ndarray
|
||
) -> Image.Image:
|
||
inputs, org_info = self._preprocess(image, mask)
|
||
|
||
# FastDeploy 추론
|
||
start = time.time()
|
||
outputs = self.model.predict(inputs)
|
||
elapsed = (time.time() - start) * 1000
|
||
logger.debug(f"[Infer] {elapsed:.2f} ms")
|
||
|
||
# outputs: dict or tuple
|
||
if isinstance(outputs, dict):
|
||
out_arr = outputs.get("output") or list(outputs.values())[0]
|
||
else:
|
||
out_arr = outputs[0]
|
||
|
||
result = self._postprocess(out_arr, org_info)
|
||
return result
|
||
|
||
|
||
# ==========================================
|
||
# CLI 실행부 (옵션)
|
||
# ==========================================
|
||
def build_mask_from_ocr_json(img: Image.Image, json_path: str, dilate_px: int = 4, feather: int = 3) -> Image.Image:
|
||
"""OCR bbox json -> 마스크 생성. [{'bbox':[x1,y1,x2,y2]}, ...] 형식 가정."""
|
||
with open(json_path, "r", encoding="utf-8") as f:
|
||
data = json.load(f)
|
||
boxes = [tuple(item["bbox"]) for item in data]
|
||
|
||
w, h = img.size
|
||
mask = np.zeros((h, w), dtype=np.uint8)
|
||
for (x1, y1, x2, y2) in boxes:
|
||
x1 = max(0, x1 - dilate_px); y1 = max(0, y1 - dilate_px)
|
||
x2 = min(w, x2 + dilate_px); y2 = min(h, y2 + dilate_px)
|
||
mask[y1:y2, x1:x2] = 255
|
||
|
||
if feather > 0:
|
||
k = feather * 2 + 1
|
||
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||
return Image.fromarray(mask)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser(description="LaMa Inpainting with FastDeploy TensorRT")
|
||
sub = parser.add_subparsers(dest="cmd")
|
||
|
||
# run
|
||
p_run = sub.add_parser("run")
|
||
p_run.add_argument("--img", required=True)
|
||
p_run.add_argument("--mask", help="mask 이미지 경로. 없으면 --ocr_json 사용")
|
||
p_run.add_argument("--ocr_json", help="OCR bbox json 경로")
|
||
p_run.add_argument("--onnx", required=True, help="lama_fp32.onnx 경로")
|
||
p_run.add_argument("--out", default="result.png")
|
||
p_run.add_argument("--backend", default="tensorrt", choices=["tensorrt", "ort", "openvino", "paddle"])
|
||
p_run.add_argument("--precision", default="fp16", choices=["fp32", "fp16", "int8"])
|
||
p_run.add_argument("--device", type=int, default=0)
|
||
p_run.add_argument("--engine_cache", default=None)
|
||
args = parser.parse_args()
|
||
|
||
if args.cmd == "run":
|
||
img = Image.open(args.img).convert("RGB")
|
||
|
||
if args.mask:
|
||
mask = Image.open(args.mask).convert("L")
|
||
elif args.ocr_json:
|
||
mask = build_mask_from_ocr_json(img, args.ocr_json)
|
||
else:
|
||
logger.error("mask 또는 ocr_json 중 하나는 제공해야 합니다.")
|
||
sys.exit(1)
|
||
|
||
inpainter = FastDeployLamaTRT(
|
||
onnx_path=args.onnx,
|
||
backend=args.backend,
|
||
precision=args.precision,
|
||
device_id=args.device,
|
||
input_hw=(512, 512),
|
||
pad_modulo=None,
|
||
engine_cache_dir=args.engine_cache
|
||
)
|
||
out_img = inpainter.inpaint(img, mask)
|
||
Path(args.out).parent.mkdir(parents=True, exist_ok=True)
|
||
out_img.save(args.out)
|
||
logger.info(f"[Done] 저장: {args.out}")
|
||
else:
|
||
parser.print_help()
|