ImageProcessor_MainServer/lama/lama_trt_inpaint.py

301 lines
11 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
lama_trt_inpaint.py (수정版)
- FastDeploy + TensorRT(FP16)로 lama_fp32.onnx 실행
- 512x512 고정 입력 모델 대응:
* full : 전체 이미지를 512로 리사이즈 후 복원
* roi : 마스크 영역만 크롭→512 맞춰 인페인트→원본에 합성 (권장)
필요:
pip install fastdeploy-gpu-python pillow opencv-python numpy
사용 예:
python lama_trt_inpaint.py \
--onnx lama_fp32.onnx \
--img image_1.png \
--mask mask_1.png \
--out result.png \
--mode roi \
--backend tensorrt --precision fp16
"""
import sys
import cv2
import time
import logging
import numpy as np
from pathlib import Path
from typing import List, Tuple, Literal, Optional
from PIL import Image
# ---------------------------
# Logger
# ---------------------------
def get_logger(name="lama_trt"):
    """Return the logger called *name*, configured for DEBUG output on stdout.

    The logger level is set to DEBUG on every call. A stdout stream handler
    with the ``[time][level] message`` format is attached only when the logger
    has no handlers yet, so repeated calls do not duplicate output.
    """
    log = logging.getLogger(name)
    log.setLevel(logging.DEBUG)
    if log.handlers:
        # Already configured by an earlier call; do not stack handlers.
        return log

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.DEBUG)
    stdout_handler.setFormatter(
        logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s")
    )
    log.addHandler(stdout_handler)
    return log


# Module-wide logger shared by every function in this file.
logger = get_logger()
# ---------------------------
# Utils
# ---------------------------
def pil_to_chw_float(img: "Image.Image | np.ndarray") -> np.ndarray:
    """Convert an image to a channel-first float32 array scaled by 1/255.

    Args:
        img: A PIL image or an array-like laid out as HxW or HxWxC.

    Returns:
        A newly allocated float32 array shaped (C, H, W) for 3-D input or
        (1, H, W) for 2-D input, with every value divided by 255.

    Raises:
        ValueError: If the input is neither 2-D nor 3-D.
    """
    # The annotation is a string on purpose: evaluating ``A | B`` between
    # classes at definition time fails on Python versions before 3.10.
    #
    # np.asarray handles PIL images and array-likes alike and avoids an
    # up-front copy; ``astype`` below always allocates the result, so the
    # caller's data is never modified.
    arr = np.asarray(img)
    if arr.ndim == 3:
        arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    elif arr.ndim == 2:
        arr = arr[np.newaxis, ...]  # HW -> 1HW
    else:
        raise ValueError("Unexpected ndim for image.")
    return arr.astype(np.float32) / 255.0
def chw_to_pil(arr: np.ndarray) -> Image.Image:
    """Turn a channel-first float array into an 8-bit PIL image.

    Values are clipped to [0, 1], multiplied by 255 and truncated to uint8.
    An array whose first axis has length 1 becomes a mode "L" image; any
    other array is transposed to HWC and handed to ``Image.fromarray``.
    """
    pixels = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
    is_single_channel = pixels.shape[0] == 1
    if not is_single_channel:
        hwc = np.transpose(pixels, (1, 2, 0))  # CHW -> HWC
        return Image.fromarray(hwc)
    return Image.fromarray(pixels[0], mode="L")
def resize_chw(arr: np.ndarray, h: int, w: int, is_mask=False) -> np.ndarray:
    """Resize a channel-first array to ``(h, w)`` using OpenCV.

    Args:
        arr: Array shaped (C, H, W).
        h: Target height.
        w: Target width.
        is_mask: When true, nearest-neighbour sampling is used so that no
            new intermediate values are introduced.

    Returns:
        A C-contiguous array shaped (C, h, w).
    """
    src_h, src_w = arr.shape[1:3]
    if is_mask:
        inter = cv2.INTER_NEAREST
    elif h <= src_h and w <= src_w:
        # Pure shrinking: area averaging avoids moire artifacts.
        inter = cv2.INTER_AREA
    else:
        # INTER_AREA degenerates to nearest-neighbour-like output when
        # enlarging, so use bilinear sampling whenever an axis grows.
        inter = cv2.INTER_LINEAR

    if arr.shape[0] == 1:
        # Single channel: resize the 2-D plane and restore the channel axis.
        plane = cv2.resize(np.ascontiguousarray(arr[0]), (w, h), interpolation=inter)
        return plane[np.newaxis, ...]

    hwc = np.ascontiguousarray(np.transpose(arr, (1, 2, 0)))  # CHW -> HWC
    hwc = cv2.resize(hwc, (w, h), interpolation=inter)
    # Return a contiguous CHW buffer instead of a strided view.
    return np.ascontiguousarray(np.transpose(hwc, (2, 0, 1)))
def merge_rects(rects: List[Tuple[int,int,int,int]], iou_thresh=0.2) -> List[Tuple[int,int,int,int]]:
    """Merge rectangles until no two results overlap.

    Rectangles are ``(x1, y1, x2, y2)``. Two rectangles are merged into their
    bounding union when their IoU exceeds ``iou_thresh`` or when they overlap
    with positive area. For non-negative thresholds the overlap test already
    covers the IoU test, so ``iou_thresh`` only matters if it is negative.

    Merging is transitive: when a union grows far enough to reach another
    already-kept rectangle, that one is absorbed as well. Rectangles that
    merely touch along an edge are left separate.

    Args:
        rects: Rectangles to merge; may be empty.
        iou_thresh: IoU above which two rectangles are merged.

    Returns:
        The merged rectangles sorted by ``(x1, y1)``.
    """
    if not rects:
        return []

    merged = []
    for rect in sorted(rects, key=lambda r: (r[0], r[1])):
        current = rect
        grew = True
        # Re-scan after every growth step: the enlarged rectangle may now
        # reach rectangles that the previous pass left untouched.
        while grew:
            grew = False
            untouched = []
            for other in merged:
                if _iou_rect(current, other) > iou_thresh or _intersects(current, other):
                    current = _union_rect(other, current)
                    grew = True
                else:
                    untouched.append(other)
            merged = untouched
        merged.append(current)

    merged.sort(key=lambda r: (r[0], r[1]))
    return merged


def _union_rect(a, b):
    """Return the smallest rectangle containing both ``a`` and ``b``."""
    return (min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3]))


def _intersects(a, b):
    """Return True when ``a`` and ``b`` overlap; shared edges do not count."""
    return not (a[2] <= b[0] or a[0] >= b[2] or a[3] <= b[1] or a[1] >= b[3])


def _iou_rect(a, b):
    """Return intersection-over-union of ``a`` and ``b`` (0 if both are empty)."""
    inter_w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0
# ---------------------------
# FastDeploy + TensorRT Runner
# ---------------------------
class FastDeployLamaTRT:
    """LaMa inpainting runner on top of a FastDeploy runtime.

    The model is fed two tensors named "image" (1x3xHxW) and "mask"
    (1x1xHxW) where HxW is ``input_hw``. Two strategies are offered:
    ``inpaint_full_resize`` squeezes the whole picture into the model input,
    while ``inpaint_with_rois`` processes padded crops around each region and
    pastes only the masked pixels back.
    """

    def __init__(self,
                 onnx_path: str,
                 backend: Literal["tensorrt","ort","openvino","paddle"] = "tensorrt",
                 precision: Literal["fp16","fp32","int8"] = "fp16",
                 device_id: int = 0,
                 input_hw: Tuple[int,int] = (512,512),
                 trt_cache: Optional[str] = "engines/lama_fp16.trt"):
        """Build the FastDeploy runtime for the given ONNX model.

        Args:
            onnx_path: Path of the ONNX model file.
            backend: Inference backend to select on the runtime option.
            precision: Requested numeric precision (used for TensorRT flags
                and for choosing GPU vs CPU).
            device_id: GPU index passed to ``use_gpu``.
            input_hw: Fixed (height, width) of the model input.
            trt_cache: Path for the serialized TensorRT engine; a falsy
                value leaves the option unset.
        """
        import fastdeploy as fd

        self.fd = fd
        self.input_hw = (int(input_hw[0]), int(input_hw[1]))
        in_h, in_w = self.input_hw

        opt = fd.RuntimeOption()
        # Register the ONNX model path.
        opt.set_model_path(onnx_path, "", fd.ModelFormat.ONNX)

        # Device.
        # NOTE(review): with the default precision "fp16" this picks the GPU
        # for every backend, not only TensorRT - confirm that is intended.
        if backend == "tensorrt" or precision in ("fp16", "int8"):
            opt.use_gpu(device_id)
        else:
            opt.use_cpu()

        # Backend / precision.
        if backend == "tensorrt":
            opt.use_trt_backend()
            opt.trt_option.enable_fp16 = (precision == "fp16")
            if precision == "int8":
                # NOTE(review): these two attributes are assumed to exist on
                # trt_option - confirm against the installed FastDeploy.
                opt.trt_option.enable_int8 = True
                opt.trt_option.calibration_data = "calibration.table"
            # Fixed shapes: min == opt == max, derived from input_hw.
            opt.trt_option.set_shape("image",
                                     [1, 3, in_h, in_w],
                                     [1, 3, in_h, in_w],
                                     [1, 3, in_h, in_w])
            opt.trt_option.set_shape("mask",
                                     [1, 1, in_h, in_w],
                                     [1, 1, in_h, in_w],
                                     [1, 1, in_h, in_w])
            # Engine cache file.
            if trt_cache:
                Path(trt_cache).parent.mkdir(parents=True, exist_ok=True)
                opt.trt_option.serialize_file = trt_cache
        elif backend == "ort":
            opt.use_ort_backend()
        elif backend == "openvino":
            opt.use_openvino_backend()
        else:
            opt.use_paddle_backend()

        self.runtime = fd.Runtime(opt)
        logger.debug("[Init] FastDeploy Runtime ready.")

    def _forward_512(self, img_chw: np.ndarray, mask_chw: np.ndarray) -> np.ndarray:
        """Run one inference on model-sized inputs and return the CHW result.

        Args:
            img_chw: Image shaped (3, H, W).
            mask_chw: Mask shaped (1, H, W).

        Returns:
            The first batch element of the output named "output" (or of the
            first output when that name is absent).
        """
        # Callers may pass transposed/sliced views; hand the runtime plain
        # C-contiguous float32 buffers so the result does not depend on how
        # the binding treats strides.
        inputs = {
            "image": np.ascontiguousarray(img_chw[np.newaxis, ...], dtype=np.float32),
            "mask": np.ascontiguousarray(mask_chw[np.newaxis, ...], dtype=np.float32),
        }
        start = time.perf_counter()
        outputs = self.runtime.infer(inputs)
        ms = (time.perf_counter() - start) * 1000
        logger.debug(f"[Infer] {ms:.2f} ms")
        if isinstance(outputs, dict):
            # Do not use ``get(...) or ...`` here: the truth value of a
            # multi-element ndarray is ambiguous and raises ValueError.
            if "output" in outputs:
                out = outputs["output"]
            else:
                out = next(iter(outputs.values()))
        else:
            out = outputs[0]
        return out[0]

    # Whole-image strategy: resize everything to the model input size.
    def inpaint_full_resize(self, image: "Image.Image", mask: "Image.Image", invert_mask=False) -> "Image.Image":
        """Inpaint by resizing the whole image to the model input size.

        The entire model output is resized back to the original size with
        bicubic resampling and returned; it is not composited with the
        original pixels.

        Args:
            image: Source picture (converted to RGB).
            mask: Mask picture (converted to "L"); pixels > 0 are masked.
            invert_mask: Swap masked and unmasked pixels before inference.
        """
        in_h, in_w = self.input_hw
        org_w, org_h = image.size
        # Normalize modes so the tensors always have 3 and 1 channels.
        img_chw = pil_to_chw_float(image.convert("RGB"))
        mask_chw = pil_to_chw_float(mask.convert("L"))
        mask_chw = (mask_chw > 0).astype(np.float32)
        if invert_mask:
            mask_chw = 1.0 - mask_chw
        img_r = resize_chw(img_chw, in_h, in_w, is_mask=False)
        mask_r = resize_chw(mask_chw, in_h, in_w, is_mask=True)
        out_chw = self._forward_512(img_r, mask_r)
        return chw_to_pil(out_chw).resize((org_w, org_h), Image.BICUBIC)

    # ROI crop strategy.
    def inpaint_with_rois(self,
                          image: "Image.Image",
                          mask: "Image.Image",
                          rois: List[Tuple[int,int,int,int]],
                          pad: int = 32,
                          invert_mask=False) -> "Image.Image":
        """Inpaint each ROI separately and paste masked pixels back.

        For every ROI the crop is enlarged by ``pad`` (clipped to the image),
        scaled to fit inside the model input while keeping its aspect ratio,
        padded at the bottom/right to the full input size, inferred, and the
        result is scaled back. Only pixels selected by the (optionally
        inverted) mask are replaced.

        Args:
            image: Source picture (converted to RGB).
            mask: Mask picture (converted to "L"); pixels > 0 are masked.
            rois: Rectangles as (x1, y1, x2, y2) in image coordinates.
            pad: Context margin added around each ROI, in pixels.
            invert_mask: Swap masked and unmasked pixels.

        Returns:
            A new RGB image with the ROIs inpainted.
        """
        in_h, in_w = self.input_hw
        image = image.convert("RGB")
        base_np = np.array(image)
        mask_np = np.array(mask.convert("L"))

        for (x1, y1, x2, y2) in rois:
            x1p = max(0, x1 - pad)
            y1p = max(0, y1 - pad)
            x2p = min(image.width, x2 + pad)
            y2p = min(image.height, y2 + pad)
            if x2p <= x1p or y2p <= y1p:
                continue  # ROI lies outside the image or is empty.

            # Crops are always taken from the untouched source image.
            crop_img_chw = pil_to_chw_float(image.crop((x1p, y1p, x2p, y2p)))
            crop_mask_chw = pil_to_chw_float(mask_np[y1p:y2p, x1p:x2p])
            crop_mask_chw = (crop_mask_chw > 0).astype(np.float32)
            if invert_mask:
                crop_mask_chw = 1.0 - crop_mask_chw

            # Composite with exactly the mask the model saw, so that
            # invert_mask affects both inference and pasting.
            region = crop_mask_chw[0] > 0
            if not region.any():
                continue  # Nothing to replace; skip the inference.

            h, w = crop_img_chw.shape[1:]
            scale = min(in_h / h, in_w / w)
            # Clamp: truncation may give 0 for extreme aspect ratios.
            nh = min(in_h, max(1, int(h * scale)))
            nw = min(in_w, max(1, int(w * scale)))
            img_scaled = resize_chw(crop_img_chw, nh, nw, is_mask=False)
            mask_scaled = resize_chw(crop_mask_chw, nh, nw, is_mask=True)

            pad_h = in_h - nh
            pad_w = in_w - nw
            img_pad = np.pad(img_scaled, ((0, 0), (0, pad_h), (0, pad_w)), mode="symmetric")
            # The mask is padded with constant 0 so the padding is never
            # treated as a hole.
            mask_pad = np.pad(mask_scaled, ((0, 0), (0, pad_h), (0, pad_w)),
                              mode="constant", constant_values=0)

            out_chw = self._forward_512(img_pad, mask_pad)
            out_crop = resize_chw(out_chw[:, :nh, :nw], h, w, is_mask=False)
            out_np = np.array(chw_to_pil(out_crop))

            current = base_np[y1p:y2p, x1p:x2p]
            base_np[y1p:y2p, x1p:x2p] = np.where(region[..., None], out_np, current)

        return Image.fromarray(base_np)
# ---------------------------
# CLI
# ---------------------------
def main():
    """Command-line entry point: load image and mask, inpaint, save result.

    In "roi" mode the bounding boxes of the external mask contours are merged
    and each one is inpainted separately; in "full" mode the whole image is
    resized to the model input.
    """
    import argparse

    parser = argparse.ArgumentParser(description="LaMa TRT FP16 Inpainting")
    parser.add_argument("--onnx", required=True)
    parser.add_argument("--img", default="image_1.png")
    parser.add_argument("--mask", default="mask_1.png")
    parser.add_argument("--out", default="result.png")
    parser.add_argument("--backend", default="tensorrt", choices=["tensorrt","ort","openvino","paddle"])
    parser.add_argument("--precision", default="fp16", choices=["fp32","fp16","int8"])
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--mode", default="roi", choices=["roi","full"])
    parser.add_argument("--roi_pad", type=int, default=32)
    parser.add_argument("--trt_cache", default=None,
                        help="TensorRT engine cache path "
                             "(default: engines/lama_<precision>.trt; empty string disables)")
    parser.add_argument("--invert_mask", action="store_true", help="마스크 극성이 반대라면 켜세요")
    args = parser.parse_args()

    # Keep one cache file per precision so an engine serialized for one
    # precision is not picked up by a run that asked for another. With the
    # default precision this is the previous default "engines/lama_fp16.trt".
    trt_cache = args.trt_cache
    if trt_cache is None:
        trt_cache = f"engines/lama_{args.precision}.trt"

    # convert() returns an independent image, so the files can be closed.
    with Image.open(args.img) as src:
        img = src.convert("RGB")
    with Image.open(args.mask) as src:
        mask = src.convert("L")

    if mask.size != img.size:
        # ROI compositing indexes image and mask with the same coordinates.
        logger.warning(f"Mask size {mask.size} != image size {img.size}; resizing mask.")
        mask = mask.resize(img.size, Image.NEAREST)

    rects = []
    if args.mode == "roi":
        holes = np.array(mask) > 0
        if args.invert_mask:
            # ROIs must enclose the pixels that will actually be inpainted.
            holes = ~holes
        # [-2] selects the contour list under both the 2-value and the
        # 3-value return conventions of findContours.
        contours = cv2.findContours(holes.astype(np.uint8),
                                    cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        for c in contours:
            x, y, w, h = cv2.boundingRect(c)
            rects.append((x, y, x + w, y + h))
        rects = merge_rects(rects, 0.2)
        logger.debug(f"ROI count: {len(rects)}")

    inpainter = FastDeployLamaTRT(
        onnx_path=args.onnx,
        backend=args.backend,
        precision=args.precision,
        device_id=args.device,
        input_hw=(512, 512),
        trt_cache=trt_cache,
    )

    if args.mode == "full":
        result = inpainter.inpaint_full_resize(img, mask, invert_mask=args.invert_mask)
    else:
        result = inpainter.inpaint_with_rois(img, mask, rects, pad=args.roi_pad,
                                             invert_mask=args.invert_mask)

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    result.save(args.out)
    logger.info(f"[Done] saved -> {args.out}")


if __name__ == "__main__":
    main()