301 lines
11 KiB
Python
301 lines
11 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
lama_trt_inpaint.py (revised version)

- Runs lama_fp32.onnx with FastDeploy + TensorRT (FP16).
- Handles the model's fixed 512x512 input in one of two modes:
    * full : resize the whole image to 512, inpaint, then restore the size
    * roi  : crop only the masked regions, fit each crop to 512, inpaint,
             and composite the result into the original image (recommended)

Requirements:
    pip install fastdeploy-gpu-python pillow opencv-python numpy

Usage example:
    python lama_trt_inpaint.py \
        --onnx lama_fp32.onnx \
        --img image_1.png \
        --mask mask_1.png \
        --out result.png \
        --mode roi \
        --backend tensorrt --precision fp16
|
|
"""
|
|
|
|
import sys
|
|
import cv2
|
|
import time
|
|
import logging
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import List, Tuple, Literal, Optional
|
|
from PIL import Image
|
|
|
|
# ---------------------------
|
|
# Logger
|
|
# ---------------------------
|
|
def get_logger(name="lama_trt"):
    """Return the logger called *name*, wiring a stdout handler on first use.

    The logger level is set to DEBUG on every call. A ``StreamHandler`` that
    writes to ``sys.stdout`` is attached only when the logger has no handlers
    yet, so repeated calls do not produce duplicated log lines.

    Args:
        name: Name handed to ``logging.getLogger``.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    log = logging.getLogger(name)
    log.setLevel(logging.DEBUG)

    # Already configured by an earlier call: nothing more to attach.
    if log.handlers:
        return log

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.DEBUG)
    stdout_handler.setFormatter(
        logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s")
    )
    log.addHandler(stdout_handler)
    return log


# Module-wide logger used by the runner class and the CLI below.
logger = get_logger()
|
|
|
|
# ---------------------------
|
|
# Utils
|
|
# ---------------------------
|
|
def pil_to_chw_float(img: Image.Image | np.ndarray) -> np.ndarray:
    """Convert an image to a channel-first float32 array divided by 255.

    Args:
        img: A PIL image (converted with ``np.array``) or an array (copied).
            Two-dimensional input is treated as HW, three-dimensional as HWC.

    Returns:
        A float32 array laid out as CHW (1HW for two-dimensional input) whose
        values are the input values divided by 255.

    Raises:
        ValueError: If the input is neither two- nor three-dimensional.
    """
    pixels = np.array(img) if isinstance(img, Image.Image) else img.copy()

    if pixels.ndim not in (2, 3):
        raise ValueError("Unexpected ndim for image.")

    if pixels.ndim == 2:
        channel_first = pixels[np.newaxis, ...]  # HW -> 1HW
    else:
        channel_first = np.transpose(pixels, (2, 0, 1))  # HWC -> CHW

    return channel_first.astype(np.float32) / 255.0
|
|
|
|
def chw_to_pil(arr: np.ndarray) -> Image.Image:
    """Turn a channel-first array into a PIL image.

    Values are clipped to [0, 1], multiplied by 255 and cast to uint8
    (the cast truncates). A single-channel array becomes a mode "L" image;
    anything else is transposed to HWC and handed to ``Image.fromarray``.

    Args:
        arr: Channel-first (CHW) array.

    Returns:
        The corresponding PIL image.
    """
    quantized = (np.clip(arr, 0, 1) * 255).astype(np.uint8)

    single_channel = quantized.shape[0] == 1
    if single_channel:
        return Image.fromarray(quantized[0], mode="L")

    return Image.fromarray(np.transpose(quantized, (1, 2, 0)))
|
|
|
|
def resize_chw(arr: np.ndarray, h: int, w: int, is_mask=False) -> np.ndarray:
    """Resize a channel-first array to ``h`` x ``w`` with OpenCV.

    Args:
        arr: Channel-first (CHW) array.
        h: Target height.
        w: Target width.
        is_mask: When true, nearest-neighbour interpolation is used;
            otherwise ``cv2.INTER_AREA``.

    Returns:
        The resized array, again channel-first.
    """
    interpolation = cv2.INTER_NEAREST if is_mask else cv2.INTER_AREA

    if arr.shape[0] != 1:
        # Multi-channel: cv2.resize works on HWC, so go there and back.
        hwc = cv2.resize(np.transpose(arr, (1, 2, 0)), (w, h), interpolation=interpolation)
        return np.transpose(hwc, (2, 0, 1))

    # Single channel: resize the bare plane and put the channel axis back.
    plane = cv2.resize(arr[0], (w, h), interpolation=interpolation)
    return plane[np.newaxis, ...]
|
|
|
|
def merge_rects(rects: List[Tuple[int,int,int,int]], iou_thresh=0.2) -> List[Tuple[int,int,int,int]]:
    """Merge ``(x1, y1, x2, y2)`` rectangles until no two of them overlap.

    Two rectangles are merged into their bounding union when their IoU
    exceeds ``iou_thresh`` or when they intersect with positive area
    (rectangles that merely share an edge are left separate).

    Merging is repeated to a fixed point: after a rectangle grows by
    absorbing another one, it is re-checked against every remaining result,
    because the enlarged box may now overlap rectangles that the incoming one
    did not touch. A single pass would leave such overlapping boxes behind.

    Args:
        rects: Rectangles as ``(x1, y1, x2, y2)``.
        iou_thresh: IoU above which two rectangles are merged.

    Returns:
        The merged rectangles. A merged rectangle takes the position of the
        earliest result it absorbed; untouched rectangles keep the order
        given by sorting on ``(x1, y1)``.
    """
    if not rects:
        return []

    merged = []
    for rect in sorted(rects, key=lambda r: (r[0], r[1])):
        current = rect
        slot = None  # position of the earliest absorbed result, if any
        absorbed = True
        while absorbed:
            absorbed = False
            for i, other in enumerate(merged):
                if _iou_rect(current, other) > iou_thresh or _intersects(current, other):
                    current = _union_rect(other, current)
                    slot = i if slot is None else min(slot, i)
                    del merged[i]
                    # Restart the scan: the grown box may reach further.
                    absorbed = True
                    break
        if slot is None:
            merged.append(current)
        else:
            merged.insert(slot, current)
    return merged


def _union_rect(a, b):
    """Return the smallest rectangle containing both ``a`` and ``b``."""
    return (min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3]))


def _intersects(a, b):
    """Return True when ``a`` and ``b`` overlap with positive area."""
    return not (a[2] <= b[0] or a[0] >= b[2] or a[3] <= b[1] or a[1] >= b[3])


def _iou_rect(a, b):
    """Return intersection-over-union of ``a`` and ``b`` (0 if the union is empty)."""
    inter_w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0
|
|
|
|
# ---------------------------
|
|
# FastDeploy + TensorRT Runner
|
|
# ---------------------------
|
|
class FastDeployLamaTRT:
    """LaMa inpainting runner built on a FastDeploy ``Runtime``.

    The model is fed two tensors named ``"image"`` (1x3xHxW) and ``"mask"``
    (1x1xHxW) at the fixed resolution ``input_hw``. Two strategies are
    offered for images of other sizes: resizing the whole picture
    (``inpaint_full_resize``) or processing padded crops around regions of
    interest (``inpaint_with_rois``).

    NOTE(review): results are converted with ``chw_to_pil``, which clips to
    [0, 1] before scaling by 255. This assumes the model emits values in
    [0, 1]; confirm against the exported ONNX graph.
    """

    def __init__(self,
                 onnx_path: str,
                 backend: Literal["tensorrt","ort","openvino","paddle"] = "tensorrt",
                 precision: Literal["fp16","fp32","int8"] = "fp16",
                 device_id: int = 0,
                 input_hw: Tuple[int,int] = (512,512),
                 trt_cache: Optional[str] = "engines/lama_fp16.trt"):
        """Build the FastDeploy runtime for the given ONNX model.

        Args:
            onnx_path: Path to the ONNX model file.
            backend: Inference backend; any value other than "tensorrt",
                "ort" or "openvino" selects the Paddle backend.
            precision: "fp16" enables FP16 and "int8" enables INT8 on the
                TensorRT backend. GPU is selected for the TensorRT backend
                and for "fp16"/"int8" on any backend.
            device_id: GPU index passed to ``use_gpu``.
            input_hw: ``(height, width)`` of the model input. It is used for
                the TensorRT shape profile and for all resizing/padding.
            trt_cache: File assigned to ``trt_option.serialize_file`` (its
                parent directory is created). A falsy value skips this.
        """
        import fastdeploy as fd  # local import: only needed once a runner is built

        self.fd = fd
        self.input_hw = input_hw
        in_h, in_w = int(input_hw[0]), int(input_hw[1])

        opt = fd.RuntimeOption()
        # Register the ONNX model path.
        opt.set_model_path(onnx_path, "", fd.ModelFormat.ONNX)

        # Device
        if backend == "tensorrt" or precision in ("fp16","int8"):
            opt.use_gpu(device_id)
        else:
            opt.use_cpu()

        # Backend / precision
        if backend == "tensorrt":
            opt.use_trt_backend()
            # Newer option-object API.
            opt.trt_option.enable_fp16 = (precision == "fp16")
            if precision == "int8":
                opt.trt_option.enable_int8 = True
                opt.trt_option.calibration_data = "calibration.table"

            # Fixed shape: min, opt and max are all the same.
            image_shape = [1, 3, in_h, in_w]
            mask_shape = [1, 1, in_h, in_w]
            opt.trt_option.set_shape("image", image_shape, image_shape, image_shape)
            opt.trt_option.set_shape("mask", mask_shape, mask_shape, mask_shape)

            # Engine cache file.
            if trt_cache:
                Path(trt_cache).parent.mkdir(parents=True, exist_ok=True)
                opt.trt_option.serialize_file = trt_cache

        elif backend == "ort":
            opt.use_ort_backend()
        elif backend == "openvino":
            opt.use_openvino_backend()
        else:
            opt.use_paddle_backend()

        self.runtime = fd.Runtime(opt)
        logger.debug("[Init] FastDeploy Runtime ready.")

    def _forward_512(self, img_chw: np.ndarray, mask_chw: np.ndarray) -> np.ndarray:
        """Run one inference on tensors already at the model resolution.

        The name is historical; the expected size is ``self.input_hw``.

        Args:
            img_chw: Image as a CHW array.
            mask_chw: Mask as a 1HW array.

        Returns:
            The first batch element of the model output. If the runtime
            returns a dict, the ``"output"`` entry is used when present,
            otherwise the first value; for a sequence, the first element.
        """
        # Callers pass transposed views (see resize_chw). Hand the runtime
        # contiguous float32 buffers — presumably the bindings read the raw
        # data pointer; verify against the FastDeploy version in use.
        inputs = {
            "image": np.ascontiguousarray(img_chw[np.newaxis, ...], dtype=np.float32),
            "mask": np.ascontiguousarray(mask_chw[np.newaxis, ...], dtype=np.float32),
        }
        start = time.time()
        outputs = self.runtime.infer(inputs)
        ms = (time.time() - start) * 1000
        logger.debug(f"[Infer] {ms:.2f} ms")

        if isinstance(outputs, dict):
            # Do not use `get(...) or ...` here: the truth value of an
            # ndarray is ambiguous and raises ValueError.
            if "output" in outputs:
                out = outputs["output"]
            else:
                out = next(iter(outputs.values()))
        else:
            out = outputs[0]
        return out[0]

    # Whole-image resize strategy
    def inpaint_full_resize(self, image: Image.Image, mask: Image.Image, invert_mask=False) -> Image.Image:
        """Inpaint by resizing the whole image to the model resolution.

        Args:
            image: Source image.
            mask: Mask image; converted to single channel, then every
                non-zero pixel becomes 1.0.
            invert_mask: Flip the binarised mask before inference.

        Returns:
            The model output resized back to the original size (bicubic).
        """
        org_w, org_h = image.size
        in_h, in_w = self.input_hw

        img_chw = pil_to_chw_float(image)
        # Force one channel, matching what inpaint_with_rois does.
        mask_chw = pil_to_chw_float(mask.convert("L"))
        mask_chw = (mask_chw > 0).astype(np.float32)
        if invert_mask:
            mask_chw = 1.0 - mask_chw

        img_r = resize_chw(img_chw, in_h, in_w, is_mask=False)
        mask_r = resize_chw(mask_chw, in_h, in_w, is_mask=True)

        out_chw = self._forward_512(img_r, mask_r)
        return chw_to_pil(out_chw).resize((org_w, org_h), Image.BICUBIC)

    # ROI crop strategy
    def inpaint_with_rois(self,
                          image: Image.Image,
                          mask: Image.Image,
                          rois: List[Tuple[int,int,int,int]],
                          pad: int = 32,
                          invert_mask=False) -> Image.Image:
        """Inpaint each region of interest separately and composite the results.

        For every ROI the padded crop is scaled (aspect ratio kept) to fit
        the model resolution, padded at the bottom/right to the full input
        size, run through the model, scaled back and pasted into the output
        wherever the mask given to the model was set.

        Args:
            image: Source image. NOTE(review): compositing assumes a
                multi-channel (HWC) image such as RGB — confirm for callers
                other than the CLI.
            mask: Mask image, same size as ``image``; converted to "L" and
                binarised with ``> 0``.
            rois: Rectangles as ``(x1, y1, x2, y2)`` in image coordinates.
            pad: Context margin added around every ROI, clamped to the image.
            invert_mask: Flip the binarised mask before inference and
                compositing.

        Returns:
            A new image with the inpainted regions composited in.
        """
        in_h, in_w = self.input_hw
        base_np = np.array(image)
        mask_np = np.array(mask.convert("L"))

        for (x1, y1, x2, y2) in rois:
            x1p = max(0, x1 - pad)
            y1p = max(0, y1 - pad)
            x2p = min(image.width, x2 + pad)
            y2p = min(image.height, y2 + pad)
            if x2p <= x1p or y2p <= y1p:
                # ROI lies completely outside the image: nothing to crop.
                continue

            crop_img_chw = pil_to_chw_float(image.crop((x1p, y1p, x2p, y2p)))
            crop_mask_chw = pil_to_chw_float(mask_np[y1p:y2p, x1p:x2p])
            crop_mask_chw = (crop_mask_chw > 0).astype(np.float32)
            if invert_mask:
                crop_mask_chw = 1.0 - crop_mask_chw

            h, w = crop_img_chw.shape[1:]
            scale = min(in_h / h, in_w / w)
            # Clamp to at least one pixel: very elongated crops would
            # otherwise truncate to a zero-sized side.
            nh = min(in_h, max(1, int(h * scale)))
            nw = min(in_w, max(1, int(w * scale)))

            img_scaled = resize_chw(crop_img_chw, nh, nw, is_mask=False)
            mask_scaled = resize_chw(crop_mask_chw, nh, nw, is_mask=True)

            pad_h = in_h - nh
            pad_w = in_w - nw
            img_pad = np.pad(img_scaled, ((0,0),(0,pad_h),(0,pad_w)), mode="symmetric")
            # The mask is padded with constant 0 so the padding is never
            # marked in the mask.
            mask_pad = np.pad(mask_scaled, ((0,0),(0,pad_h),(0,pad_w)), mode="constant", constant_values=0)

            out_chw = self._forward_512(img_pad, mask_pad)
            out_crop = resize_chw(out_chw[:, :nh, :nw], h, w, is_mask=False)
            out_np = np.array(chw_to_pil(out_crop))

            # Composite with the same mask polarity that was fed to the
            # model, so invert_mask affects inference and pasting alike.
            region = (crop_mask_chw[0] > 0)[..., None]
            target = base_np[y1p:y2p, x1p:x2p]
            base_np[y1p:y2p, x1p:x2p] = np.where(region, out_np, target)

        return Image.fromarray(base_np)
|
|
|
|
# ---------------------------
|
|
# CLI
|
|
# ---------------------------
|
|
def main():
    """Command-line entry point: load image and mask, inpaint, save the result.

    In "roi" mode the regions of interest are the bounding boxes of the
    external contours of the effective mask (after ``--invert_mask``), merged
    with ``merge_rects``. In "full" mode the whole image is resized to the
    model resolution instead.
    """
    import argparse

    parser = argparse.ArgumentParser(description="LaMa TRT FP16 Inpainting")
    parser.add_argument("--onnx", required=True)
    parser.add_argument("--img", default="image_1.png")
    parser.add_argument("--mask", default="mask_1.png")
    parser.add_argument("--out", default="result.png")
    parser.add_argument("--backend", default="tensorrt", choices=["tensorrt","ort","openvino","paddle"])
    parser.add_argument("--precision", default="fp16", choices=["fp32","fp16","int8"])
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--mode", default="roi", choices=["roi","full"])
    parser.add_argument("--roi_pad", type=int, default=32)
    parser.add_argument("--trt_cache", default=None,
                        help="TensorRT engine cache file "
                             "(default: engines/lama_<precision>.trt; pass '' to disable)")
    parser.add_argument("--invert_mask", action="store_true", help="마스크 극성이 반대라면 켜세요")
    args = parser.parse_args()

    # Default cache name follows the precision so that an engine cached for
    # one precision is not picked up by a run using another. For fp16 this is
    # the previous default path. An explicit empty string still disables it.
    trt_cache = args.trt_cache
    if trt_cache is None:
        trt_cache = f"engines/lama_{args.precision}.trt"

    img = Image.open(args.img).convert("RGB")
    mask = Image.open(args.mask).convert("L")

    rects = []
    if args.mode == "roi":
        if mask.size != img.size:
            # ROI cropping and compositing index image and mask with the
            # same pixel coordinates, so they must have equal size.
            logger.warning(f"mask size {mask.size} != image size {img.size}; resizing mask")
            mask = mask.resize(img.size, Image.NEAREST)

        hole = np.array(mask) > 0
        if args.invert_mask:
            # Locate ROIs on the mask polarity that will actually be used.
            hole = ~hole
        # [-2] picks the contour list whether findContours returns
        # (contours, hierarchy) or (image, contours, hierarchy).
        contours = cv2.findContours(hole.astype(np.uint8),
                                    cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        boxes = [cv2.boundingRect(c) for c in contours]
        rects = merge_rects([(x, y, x + w, y + h) for (x, y, w, h) in boxes], 0.2)
        logger.debug(f"ROI count: {len(rects)}")

    inpainter = FastDeployLamaTRT(
        onnx_path=args.onnx,
        backend=args.backend,
        precision=args.precision,
        device_id=args.device,
        input_hw=(512,512),
        trt_cache=trt_cache,
    )

    if args.mode == "full":
        result = inpainter.inpaint_full_resize(img, mask, invert_mask=args.invert_mask)
    else:
        result = inpainter.inpaint_with_rois(img, mask, rects, pad=args.roi_pad, invert_mask=args.invert_mask)

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    result.save(args.out)
    logger.info(f"[Done] saved -> {args.out}")


if __name__ == "__main__":
    main()
|