918 lines
38 KiB
Python
918 lines
38 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
통합 인페인팅 모듈
|
|
- OpenCV 텍스트 최적화 인페인트
|
|
- SimpleLama (PyTorch)
|
|
- LaMa ONNX (Hugging Face: opencv/inpainting_lama_2025jan.onnx)
|
|
- MiGAN / EdgeConnect 어댑터 자리 마련
|
|
|
|
사용 예:
|
|
from worker.inpaint_module import Inpainter, InpaintBackends
|
|
inp = Inpainter(default_backend=InpaintBackends.LAMA_TORCH,
|
|
lama_device="cuda",
|
|
lama_onnx_ort_path="/app/worker/models/inpainting_lama_2025jan.onnx")
|
|
out = inp.inpaint(img_bgr, [poly1, poly2, ...], backend=None, max_side=1024,
|
|
auto_opencv_if_few=True, few_threshold=4)
|
|
"""
|
|
from __future__ import annotations
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
from typing import Dict, Any, List, Tuple, Optional
|
|
from PIL import Image
|
|
import threading
|
|
|
|
|
|
# ───────────────────────────────────────────────
|
|
# 백엔드 식별자
|
|
# ───────────────────────────────────────────────
|
|
class InpaintBackends:
    """String identifiers for the inpainting backends understood by ``Inpainter``.

    The values are plain lowercase strings so they can be taken directly from
    configuration or request payloads.
    """

    # Classic OpenCV (Telea) inpainting; no model needed.
    OPENCV = "opencv"
    # SimpleLama (PyTorch), from the locally patched simple_lama_inpainting module.
    LAMA_TORCH = "lama_torch"
    # SimpleLama run under CUDA autocast (see Inpainter.inpaint).
    LAMA_TORCH_AMP = "lama_torch_amp"
    # LaMa ONNX executed through FastDeploy.
    LAMA_ONNX_FD = "lama_onnx_fd"
    # LaMa ONNX executed through plain onnxruntime.
    LAMA_ONNX_ORT = "lama_onnx_ort"
    # Reserved names (placeholders); no adapter is wired up for them.
    MIGAN = "migan"
    EDGECONNECT = "edgeconnect"
|
|
|
|
|
|
# ───────────────────────────────────────────────
|
|
# 유틸
|
|
# ───────────────────────────────────────────────
|
|
def _log(logger, msg, level=20):
|
|
"""logger가 있으면 logger.log로, 없으면 print"""
|
|
if logger and hasattr(logger, "log"):
|
|
logger.log(msg, level=level)
|
|
else:
|
|
print(msg)
|
|
|
|
def polygons_to_mask(shape: Tuple[int, ...], polygons: List[List[List[int]]]) -> np.ndarray:
    """Rasterise *polygons* into a single binary uint8 mask (0 = keep, 255 = fill).

    Args:
        shape: ``(H, W)`` of the target image. Longer shapes such as an image's
            ``(H, W, C)`` are accepted; only the first two entries are used.
        polygons: iterable of polygons. Each polygon is anything reshapeable to
            ``(N, 2)`` integer ``(x, y)`` points — nested pairs or a flat
            ``[x, y, x, y, ...]`` list. Empty polygons are skipped.

    Returns:
        ``(H, W)`` uint8 array with 255 inside any polygon.
    """
    h, w = int(shape[0]), int(shape[1])
    mask = np.zeros((h, w), dtype=np.uint8)
    for poly in polygons:
        pts = np.asarray(poly, dtype=np.int32).reshape(-1, 2)
        if pts.shape[0] == 0:
            # Nothing to draw; don't hand OpenCV an empty contour.
            continue
        cv2.fillPoly(mask, [pts], 255)
    return mask
|
|
|
|
def union_bbox_of_mask(mask: np.ndarray, pad_ratio: float = 0.1) -> Tuple[int,int,int,int]:
    """Return ``(x, y, w, h)`` of the box enclosing every non-zero pixel of *mask*.

    The box is grown on every side by ``int(max(w, h) * pad_ratio)`` pixels and
    clipped to the mask bounds. An all-zero mask yields the full mask extent
    ``(0, 0, W, H)``.
    """
    img_h, img_w = mask.shape[0], mask.shape[1]
    fg = mask > 0
    cols = np.flatnonzero(fg.any(axis=0))
    rows = np.flatnonzero(fg.any(axis=1))
    if cols.size == 0:
        return 0, 0, img_w, img_h

    left, right = int(cols[0]), int(cols[-1])
    top, bottom = int(rows[0]), int(rows[-1])
    box_w = right - left + 1
    box_h = bottom - top + 1
    margin = int(max(box_w, box_h) * pad_ratio)

    x0 = max(0, left - margin)
    y0 = max(0, top - margin)
    x1 = min(img_w, left + box_w + margin)
    y1 = min(img_h, top + box_h + margin)
    return x0, y0, x1 - x0, y1 - y0
|
|
|
|
def resize_long_side(img: np.ndarray, max_side: int) -> Tuple[np.ndarray, float]:
    """Downscale *img* so its longer side equals *max_side*; never upscale.

    Returns:
        ``(resized, scale)`` where ``scale`` is the new/old size ratio. When no
        resize is needed the original array itself (not a copy) and ``1.0``
        are returned.

    The shorter side is truncated to an int but clamped to at least 1 pixel,
    so very elongated images never produce a zero-sized target, which
    ``cv2.resize`` rejects.
    """
    h, w = img.shape[:2]
    if max(h, w) <= max_side:
        return img, 1.0
    scale = max_side / float(max(h, w))
    if h >= w:
        nh, nw = max_side, max(1, int(w * scale))
    else:
        nh, nw = max(1, int(h * scale)), max_side
    # INTER_AREA is OpenCV's recommended interpolation for shrinking.
    out = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
    return out, scale
|
|
|
|
def _next_pow2(n: int) -> int:
|
|
return 1 if n <= 1 else 1 << (n - 1).bit_length()
|
|
|
|
def _reflect_pad_to(img: np.ndarray, target_h: int, target_w: int) -> Tuple[np.ndarray, Tuple[int,int,int,int]]:
|
|
import cv2, numpy as np
|
|
h, w = img.shape[:2]
|
|
top = max(0, (target_h - h) // 2)
|
|
bottom = max(0, target_h - h - top)
|
|
left = max(0, (target_w - w) // 2)
|
|
right = max(0, target_w - w - left)
|
|
if top or bottom or left or right:
|
|
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_REFLECT_101)
|
|
return img, (top, bottom, left, right)
|
|
|
|
def _crop_by_pad(img: np.ndarray, pad: Tuple[int,int,int,int]) -> np.ndarray:
|
|
top, bottom, left, right = pad
|
|
if not (top or bottom or left or right):
|
|
return img
|
|
h, w = img.shape[:2]
|
|
return img[top:h-bottom, left:w-right]
|
|
|
|
# ───────────────────────────────────────────────
|
|
# OpenCV 텍스트 특화 인페인트
|
|
# ───────────────────────────────────────────────
|
|
def _opencv_text_inpaint(img_bgr: np.ndarray, mask: np.ndarray,
                         small_radius: int = 3, large_radius: int = 7,
                         dilate_px: int = 2, smooth_kernel: int = 3) -> np.ndarray:
    """OpenCV (Telea) inpainting tuned for removing rendered text.

    Steps:
      1. Binarise the mask to 0/255, then dilate it by ``dilate_px`` so
         anti-aliased glyph edges are covered too.
      2. First Telea pass with ``small_radius``.
      3. Second Telea pass with ``large_radius`` over masked pixels whose value
         changed by more than 3 (any channel) in the first pass.
         NOTE(review): pass 1 changes most masked pixels, so this usually
         re-inpaints most of the mask.
      4. Feather: blend the result over the original using the Gaussian-blurred
         dilated mask as alpha.

    Args:
        img_bgr: 8-bit image to repair, 3-channel (presumably BGR, per the name)
            or single-channel.
        mask: single-channel mask; any non-zero pixel is treated as "fill".
        small_radius / large_radius: Telea radii of the two passes.
        dilate_px: dilation radius in pixels (0 disables dilation).
        smooth_kernel: feathering kernel size (forced odd).

    Returns:
        uint8 array with the same shape as ``img_bgr``.
    """
    # Strict 0/255 so the feathering alpha reaches 1.0 inside the mask even
    # when a caller passes a 0/1 mask (otherwise alpha ~= 1/255 and the
    # inpainted pixels are almost entirely discarded).
    mask_bin = np.where(mask > 0, 255, 0).astype(np.uint8)

    dil = max(0, int(dilate_px))
    if dil > 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * dil + 1, 2 * dil + 1))
        mask1 = cv2.dilate(mask_bin, kernel, iterations=1)
    else:
        mask1 = mask_bin

    out1 = cv2.inpaint(img_bgr, mask1, small_radius, cv2.INPAINT_TELEA)

    # Second, wider pass over masked pixels that changed noticeably.
    diff = np.abs(out1.astype(np.int16) - img_bgr.astype(np.int16))
    if diff.ndim == 3:
        diff = diff.max(axis=2)
    remain = (mask1 > 0) & (diff > 3)
    if remain.any():
        out2 = cv2.inpaint(out1, remain.astype(np.uint8) * 255, large_radius, cv2.INPAINT_TELEA)
    else:
        out2 = out1

    # Feathering: GaussianBlur needs an odd kernel size >= 1.
    k = max(1, int(smooth_kernel)) | 1
    alpha = cv2.GaussianBlur(mask1, (k, k), 0).astype(np.float32) / 255.0
    if img_bgr.ndim == 3:
        alpha = alpha[..., None]
    blended = alpha * out2.astype(np.float32) + (1.0 - alpha) * img_bgr.astype(np.float32)
    # Round rather than truncate so the blend has no downward bias.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
|
|
|
|
|
|
# ───────────────────────────────────────────────
|
|
# SimpleLama (PyTorch) 어댑터
|
|
# ───────────────────────────────────────────────
|
|
# _HAVE_LAMA_TORCH = False
|
|
# try:
|
|
# from simple_lama_inpainting.models.model import SimpleLama
|
|
# _HAVE_LAMA_TORCH = True
|
|
# except Exception:
|
|
# _HAVE_LAMA_TORCH = False
|
|
|
|
_HAVE_LAMA_TORCH = False
|
|
try:
|
|
# 패치한 파일 위치 그대로
|
|
from simple_lama_inpainting.models.model import SimpleLama
|
|
_HAVE_LAMA_TORCH = True
|
|
except Exception:
|
|
_HAVE_LAMA_TORCH = False
|
|
|
|
|
|
# ───────────────────────────────────────────────
|
|
# LaMa ONNX 어댑터 (opencv/inpainting_lama)
|
|
# ───────────────────────────────────────────────
|
|
class LamaOnnxORT:
    """LaMa inpainting ONNX model executed directly with onnxruntime.

    ``infer`` runs the model at a fixed ``TARGET_SIZE`` x ``TARGET_SIZE`` input
    and resizes the result back to the caller's size.
    """

    # Side length of the square input the model is run at.
    TARGET_SIZE = 512

    def __init__(self, model_path: str, logger=None, providers=None, backend_hint: Optional[str]=None):
        """Create the onnxruntime session.

        Args:
            model_path: path to the ``.onnx`` file.
            logger: optional object with ``log(msg)``; ``print`` is used otherwise.
            providers: explicit onnxruntime providers list. When ``None``, one
                is built from the available providers (see ``_build_providers``).
            backend_hint: ``"cpu"`` forces CPU only; ``"cuda"``/``"gpu"`` or no
                hint also allows the CUDA provider.
        """
        import onnxruntime as ort

        # Must be set before _log is used. The old lambda-based self._log
        # returned `print` instead of calling it, so logs without a logger
        # were silently dropped.
        self.logger = logger

        so = ort.SessionOptions()
        so.log_severity_level = 2  # 2 = warnings and above
        # so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        avail = ort.get_available_providers()
        self._log(f"[ORT] available providers={avail}")

        if providers is None:
            providers = self._build_providers(avail, backend_hint)

        self._log(f"[ORT] providers={providers}")
        self.sess = ort.InferenceSession(model_path, sess_options=so, providers=providers)

        self.input_name = self.sess.get_inputs()[0].name
        self.output_name = self.sess.get_outputs()[0].name
        self._log(f"[ORT] io: in={self.input_name}, out={self.output_name}")

    @staticmethod
    def _build_providers(avail, backend_hint: Optional[str] = None) -> list:
        """Build the provider list in priority order TensorRT -> CUDA -> CPU.

        * hint ``"cpu"``: ``["CPUExecutionProvider"]`` only.
        * TensorRT is added whenever available (for any other hint). Its
          options come from the ``ORT_TENSORRT_*`` environment variables.
        * CUDA is added when available and the hint is empty, ``"cuda"`` or
          ``"gpu"``.
        * CPU is always appended last as the fallback.
        """
        def _truthy(value) -> bool:
            return str(value).lower() in ("1", "true", "yes", "on")

        hint = (backend_hint or "").lower()
        if hint == "cpu":
            return ["CPUExecutionProvider"]

        providers: list = []
        if "TensorrtExecutionProvider" in avail:
            trt_opts = {
                "trt_engine_cache_enable": _truthy(os.getenv("ORT_TENSORRT_ENGINE_CACHE_ENABLE", "1")),
                "trt_engine_cache_path": os.getenv("ORT_TENSORRT_CACHE_PATH", "/app/trt_cache"),
                "trt_fp16_enable": _truthy(os.getenv("ORT_TENSORRT_FP16_ENABLE", "1")),
            }
            try:
                # Workspace size in bytes (default 1 GiB).
                trt_opts["trt_max_workspace_size"] = int(os.getenv("ORT_TENSORRT_MAX_WORKSPACE_SIZE", str(1 << 30)))
            except ValueError:
                pass  # malformed env value: keep ORT's default workspace size
            try:
                os.makedirs(trt_opts["trt_engine_cache_path"], exist_ok=True)
            except OSError:
                pass  # best effort; the engine cache is an optimisation only
            providers.append(("TensorrtExecutionProvider", trt_opts))

        if "CUDAExecutionProvider" in avail and hint in ("", "cuda", "gpu"):
            # Optional tuning; may be ignored depending on the ORT version.
            providers.append(("CUDAExecutionProvider", {"cudnn_conv_use_max_workspace": "1"}))

        providers.append("CPUExecutionProvider")
        return providers

    def infer(self, img_bgr: np.ndarray, mask_gray: np.ndarray) -> np.ndarray:
        """Inpaint *img_bgr* where *mask_gray* is set.

        Args:
            img_bgr: uint8 ``(H, W, 3)`` BGR image.
            mask_gray: uint8 ``(H, W)`` mask; values are scaled by 1/255, so
                255 means "fill".

        Returns:
            uint8 ``(H, W, 3)`` BGR result.
        """
        import numpy as np, cv2

        H, W = img_bgr.shape[:2]
        target = self.TARGET_SIZE
        need_resize = (H, W) != (target, target)
        if need_resize:
            # Interpolation must be passed by keyword: the third positional
            # argument of cv2.resize is `dst`, not the interpolation flag.
            img = cv2.resize(img_bgr, (target, target), interpolation=cv2.INTER_AREA)
            msk = cv2.resize(mask_gray, (target, target), interpolation=cv2.INTER_NEAREST)
        else:
            img, msk = img_bgr, mask_gray

        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        m = (msk.astype(np.float32) / 255.0)[..., None]
        # 1 x 4 x H x W: RGB in [0, 1] followed by the mask channel.
        blob = np.ascontiguousarray(
            np.concatenate([img_rgb, m], axis=2).transpose(2, 0, 1)[None], dtype=np.float32)

        out = self.sess.run([self.output_name], {self.input_name: blob})[0]
        # NOTE(review): assumes the model returns 1 x 3 x H x W RGB in [0, 1].
        out_rgb = np.clip(np.transpose(out[0], (1, 2, 0)), 0.0, 1.0)
        out_bgr = cv2.cvtColor((out_rgb * 255.0 + 0.5).astype(np.uint8), cv2.COLOR_RGB2BGR)
        if need_resize:
            out_bgr = cv2.resize(out_bgr, (W, H), interpolation=cv2.INTER_CUBIC)
        return out_bgr

    def _log(self, msg):
        """Send *msg* to ``self.logger.log`` when available, otherwise print it."""
        logger = getattr(self, "logger", None)
        if logger and hasattr(logger, "log"):
            logger.log(msg)
        else:
            print(msg)
|
|
|
|
# 추가: FastDeploy 기반 ONNX LaMa
|
|
class LamaOnnxFD:
    """Run the LaMa inpainting ONNX model through a FastDeploy Runtime.

    Args:
        model_path: path to the ONNX file.
        device: ``"gpu"`` or ``"cpu"``. With ``"cpu"``, ORT on CPU with 2
            threads is used and *backend* is ignored.
        device_id: GPU index.
        backend: on GPU, one of
            * ``"ort"`` (default): ORT on GPU;
            * ``"trt"``: TensorRT with a dynamic 256..1024 shape profile,
              falling back to ORT on GPU if the engine cannot be built;
            * ``"cuda"``: ORT on GPU;
            * ``"cpu"``: ORT on CPU.
        logger: optional object with ``log(msg)``; ``print`` otherwise.
    """

    # Side length of the square input used by `infer`.
    TARGET_SIZE = 512
    # TensorRT dynamic-shape profile (min / opt / max side length).
    TRT_MIN_HW, TRT_OPT_HW, TRT_MAX_HW = 256, 512, 1024

    def __init__(self, model_path: str,
                 device: str = "gpu",
                 device_id: int = 0,
                 backend: str = "ort",
                 logger=None):
        import fastdeploy as fd

        self.fd = fd
        self.logger = logger
        self.model_path = model_path
        self._log("LamaOnnxFD init")

        use_gpu = str(device).lower() == "gpu"
        backend = str(backend or "ort").lower()
        if use_gpu and backend == "trt":
            self.runtime = self._create_trt_runtime(device_id)
        else:
            # Every option gets its model path *before* a Runtime is built
            # from it, and exactly one runtime is created.
            self.runtime = self.fd.Runtime(self._ort_option(use_gpu, device_id, backend))
        self._log_io()

    # ── option / runtime construction ──────────────────────────────
    def _new_option(self):
        """Return a fresh RuntimeOption already pointing at the ONNX model."""
        opt = self.fd.RuntimeOption()
        opt.set_model_path(self.model_path, model_format=self.fd.ModelFormat.ONNX)
        return opt

    def _ort_option(self, use_gpu: bool, device_id: int, backend: str):
        """Return a RuntimeOption for the ORT backend on CPU or GPU."""
        opt = self._new_option()
        opt.use_ort_backend()
        if not use_gpu:
            opt.use_cpu()
            opt.set_cpu_thread_num(2)
        elif backend == "cpu":
            opt.use_cpu()
        else:  # "ort", "cuda" and unrecognised values -> ORT on GPU
            opt.use_gpu(device_id)
        return opt

    def _create_trt_runtime(self, device_id: int):
        """Build a TensorRT runtime; fall back to ORT on GPU if anything fails."""
        try:
            opt = self._new_option()
            opt.use_gpu(device_id)
            opt.use_trt_backend()
            opt.trt_option.enable_fp16 = True
            opt.trt_option.max_workspace_size = 1 << 28  # 256 MB
            try:
                self._log("TRT 프로필 설정 시작")
                for name in self._probe_input_names():
                    # N, C, H, W = 1, 4 (RGB + mask), dynamic H/W
                    opt.set_trt_input_shape(
                        name,
                        min_shape=[1, 4, self.TRT_MIN_HW, self.TRT_MIN_HW],
                        opt_shape=[1, 4, self.TRT_OPT_HW, self.TRT_OPT_HW],
                        max_shape=[1, 4, self.TRT_MAX_HW, self.TRT_MAX_HW],
                    )
            except Exception as e:
                # Without a profile TRT may still build for static shapes.
                self._log(f"TRT 프로필 설정 실패: {e}")
            self._log("TRT 런타임 생성 시작")
            return self.fd.Runtime(opt)
        except Exception as e:
            self._log(f"[TRT] engine build failed: {e}")
            self._log("[TRT] Fallback to ORT GPU")
            return self.fd.Runtime(self._ort_option(True, device_id, "ort"))

    def _probe_input_names(self) -> List[str]:
        """Read the model's input names with a throwaway CPU ORT runtime.

        Returns ``["input"]`` if probing fails or yields nothing.
        """
        try:
            opt = self._new_option()
            opt.use_ort_backend()
            opt.use_cpu()
            probe = self.fd.Runtime(opt)
            names = [info.name for info in self._tensor_infos(probe, "input")]
            del probe
            if names:
                return names
        except Exception as e:  # best effort: names only refine the TRT profile
            self._log(f"[TRT] input-name probe failed: {e}")
        return ["input"]

    @staticmethod
    def _tensor_infos(runtime, kind: str) -> list:
        """Collect ``runtime.get_<kind>_info(i)`` for every input or output.

        ``kind`` is ``"input"`` or ``"output"``. Assumes the FastDeploy Python
        Runtime exposes ``num_inputs()``/``num_outputs()`` and index-based
        ``get_input_info(i)``/``get_output_info(i)`` — TODO confirm against the
        pinned FastDeploy version.
        """
        count = getattr(runtime, f"num_{kind}s")()
        getter = getattr(runtime, f"get_{kind}_info")
        return [getter(i) for i in range(count)]

    def _log_io(self):
        """Log the runtime's input/output tensor names. Debug aid only; never raises."""
        try:
            names_in = [x.name for x in self._tensor_infos(self.runtime, "input")]
            names_out = [x.name for x in self._tensor_infos(self.runtime, "output")]
            self._log(f"[LaMa-ONNX-FD] inputs={names_in} outputs={names_out}")
        except Exception:
            pass  # purely informational

    def _log(self, msg):
        """Send *msg* to ``self.logger.log`` when available, otherwise print it."""
        logger = getattr(self, "logger", None)
        if logger and hasattr(logger, "log"):
            logger.log(msg)
        else:
            print(msg)

    # ── inference ──────────────────────────────────────────────────
    def infer(self, img_bgr, mask_gray):
        """Inpaint *img_bgr* (uint8 BGR) where *mask_gray* (uint8, 255 = fill) is set.

        The model is run at ``TARGET_SIZE`` x ``TARGET_SIZE``, and the result is
        resized back to the input size.

        Returns:
            uint8 BGR image of the input size.
        """
        import numpy as np, cv2

        H, W = img_bgr.shape[:2]
        target = self.TARGET_SIZE
        need_resize = (H, W) != (target, target)
        if need_resize:
            img_resized = cv2.resize(img_bgr, (target, target), interpolation=cv2.INTER_AREA)
            mask_resized = cv2.resize(mask_gray, (target, target), interpolation=cv2.INTER_NEAREST)
        else:
            img_resized, mask_resized = img_bgr, mask_gray

        # 1 x 4 x H x W float32: RGB in [0, 1] followed by the mask in [0, 1].
        img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        m = (mask_resized.astype(np.float32) / 255.0)[..., None]
        blob = np.ascontiguousarray(
            np.concatenate([img_rgb, m], axis=2).transpose(2, 0, 1)[None], dtype=np.float32)

        out = self.runtime.infer([blob])[0]
        # NOTE(review): assumes a 1 x 3 x H x W RGB output in [0, 1].
        out_rgb = np.clip(np.transpose(out[0], (1, 2, 0)), 0.0, 1.0)
        out_bgr = cv2.cvtColor((out_rgb * 255.0 + 0.5).astype(np.uint8), cv2.COLOR_RGB2BGR)
        if need_resize:
            out_bgr = cv2.resize(out_bgr, (W, H), interpolation=cv2.INTER_CUBIC)
        return out_bgr
|
|
|
|
# ───────────────────────────────────────────────
|
|
# 메인 Inpainter
|
|
# ───────────────────────────────────────────────
|
|
from PIL import Image
|
|
import threading
|
|
|
|
_SIMPLE_LAMA_SINGLETON = None
|
|
class Inpainter:
    """Turn text polygons into a mask and inpaint them with a selectable backend.

    Model-backed adapters are created lazily on first use and cached on the
    instance; SimpleLama is additionally shared process-wide. Creation is
    guarded by a class-level lock, so model loading is serialised across all
    instances.
    """

    _lock = threading.Lock()  # guards lazy adapter construction (shared by all instances)

    def __init__(self,
                 logger=None,
                 default_backend: str = InpaintBackends.LAMA_TORCH,
                 lama_device: str = "cuda",
                 lama_onnx_ort_path: Optional[str] = None,
                 lama_onnx_ort_providers: Optional[list] = None,
                 lama_onnx_fd_path: Optional[str] = None,
                 lama_onnx_fd_device: str = "gpu",
                 lama_onnx_fd_device_id: int = 0,
                 lama_onnx_fd_backend: str = "ort"):
        """
        Args:
            logger: optional object with ``log(msg)``; ``print`` otherwise.
            default_backend: an ``InpaintBackends`` value, used when ``inpaint``
                gets no explicit ``backend``.
            lama_device: torch device for the SimpleLama backends
                (``"cuda"``, ``"gpu"`` as an alias of cuda, or ``"cpu"``).
            lama_onnx_ort_path, lama_onnx_fd_path: ONNX model paths; default to
                ``$INPAINT_LAMA_ONNX`` or ``/app/worker/models/lama_fp32.onnx``.
            lama_onnx_ort_providers: explicit onnxruntime providers list.
            lama_onnx_fd_device, lama_onnx_fd_device_id, lama_onnx_fd_backend:
                FastDeploy settings forwarded to ``LamaOnnxFD``.
        """
        self.logger = logger
        self.default_backend = (default_backend or InpaintBackends.LAMA_TORCH).lower()
        self.lama_device = lama_device

        default_onnx = os.getenv("INPAINT_LAMA_ONNX", "/app/worker/models/lama_fp32.onnx")
        self.lama_onnx_ort_path = lama_onnx_ort_path or default_onnx
        self.lama_onnx_ort_providers = lama_onnx_ort_providers
        self.lama_onnx_fd_path = lama_onnx_fd_path or default_onnx
        self.lama_onnx_fd_device = lama_onnx_fd_device
        self.lama_onnx_fd_device_id = lama_onnx_fd_device_id
        self.lama_onnx_fd_backend = lama_onnx_fd_backend

        # Lazily created adapters (see the _get_* methods).
        self._lama_onnx_ort = None
        self._lama_onnx_fd = None
        self._lama_torch = None
        self._lama_torch_amp = None

        self._log(f"Inpainter init: default={self.default_backend}")

    def _log(self, msg):
        """Send *msg* to ``self.logger.log`` when available, otherwise print it."""
        if self.logger and hasattr(self.logger, "log"):
            self.logger.log(msg)
        else:
            print(msg)

    # ── lazy adapter construction (double-checked under _lock) ─────
    def _get_lama_onnx_ort(self, backend_hint: Optional[str] = None) -> LamaOnnxORT:
        """Return the cached ORT adapter. Only the first call's *backend_hint* is used."""
        if self._lama_onnx_ort is None:
            with self._lock:
                if self._lama_onnx_ort is None:
                    self._log("[Init] LamaOnnxORT")
                    self._lama_onnx_ort = LamaOnnxORT(
                        model_path=self.lama_onnx_ort_path,
                        logger=self.logger,
                        providers=self.lama_onnx_ort_providers,
                        backend_hint=backend_hint,
                    )
        return self._lama_onnx_ort

    def _get_lama_onnx_fd(self):
        """Return the cached FastDeploy adapter, built from the configured device settings."""
        if self._lama_onnx_fd is None:
            with self._lock:
                if self._lama_onnx_fd is None:
                    self._log("[Init] LamaOnnxFD")
                    # Previously device="gpu"/device_id=0 were hard-coded here,
                    # silently ignoring lama_onnx_fd_device/_device_id.
                    self._lama_onnx_fd = LamaOnnxFD(
                        model_path=self.lama_onnx_fd_path,
                        backend=self.lama_onnx_fd_backend,
                        device=self.lama_onnx_fd_device,
                        device_id=self.lama_onnx_fd_device_id,
                        logger=self.logger,
                    )
        return self._lama_onnx_fd

    def _get_lama_torch(self):
        """Return the (process-wide) SimpleLama model."""
        if self._lama_torch is None:
            with self._lock:
                if self._lama_torch is None:
                    self._log("[Init] SimpleLaMa (torch)")
                    self._lama_torch = self.get_simple_lama(device=self.lama_device)
        return self._lama_torch

    @staticmethod
    def get_simple_lama(device="cuda"):
        """Return the process-wide SimpleLama instance, creating it on first call.

        The *device* of the first call wins; later calls get the cached model
        whatever their argument. This method takes no lock itself;
        ``_get_lama_torch`` calls it while holding ``Inpainter._lock``.

        Raises:
            RuntimeError: if ``simple_lama_inpainting`` could not be imported.
        """
        global _SIMPLE_LAMA_SINGLETON
        if _SIMPLE_LAMA_SINGLETON is None:
            if not _HAVE_LAMA_TORCH:
                raise RuntimeError(
                    "simple_lama_inpainting is not available; cannot use the LaMa torch backend")
            # Local import: don't depend on a module-level torch import that
            # only appears further down in this file.
            import torch

            # Pin the torch hub cache dir (only if TORCH_HOME is not already set).
            torch_home = "/app/torch_cache"
            os.makedirs(torch_home, exist_ok=True)
            os.environ.setdefault("TORCH_HOME", torch_home)
            # Optional knobs for the patched SimpleLama, per the original notes
            # (the patched module itself is not visible here), set via env:
            #   LAMA_MODEL=/app/torch_cache/Big-LaMa.fp16.pt   alternate checkpoint
            #   SIMPLE_LAMA_JIT_ORDER=image_first|mask_first   argument-order hint
            #   SIMPLE_LAMA_DEBUG_SHAPES=1                     shape/order debug log
            _SIMPLE_LAMA_SINGLETON = SimpleLama(
                device=torch.device("cuda" if device == "gpu" else device))
        return _SIMPLE_LAMA_SINGLETON

    def _get_lama_torch_amp(self):
        """Return the cached LamaTorchAMP adapter.

        NOTE(review): ``inpaint`` currently serves LAMA_TORCH_AMP with SimpleLama
        under autocast and does not call this method.
        """
        if self._lama_torch_amp is None:
            with self._lock:
                if self._lama_torch_amp is None:
                    self._log("[Init] SimpleLaMa (torch AMP)")
                    # Checkpoint: $SIMPLE_LAMA_CKPT, else simple-lama's default download.
                    self._lama_torch_amp = LamaTorchAMP(device=self.lama_device)
        return self._lama_torch_amp

    # ── Public API ─────────────────────────────────────────────────
    def inpaint(self,
                image_bgr: np.ndarray,
                polygons: List[List[List[int]]],
                *,
                backend: Optional[str] = None,
                max_side: int = 1024,
                auto_opencv_if_few: bool = True,
                few_threshold: int = 4,
                backend_hint: Optional[str] = None,
                roi_pad_ratio: float = 0.1,
                roi_min_pad: int = 8) -> np.ndarray:
        """Remove the regions covered by *polygons* from *image_bgr*.

        Args:
            image_bgr: source BGR image ``(H, W, 3)``.
            polygons: ``[[[x, y], [x, y], ...], ...]``; flat
                ``[x, y, x, y, ...]`` polygons are accepted too.
            backend: force a backend (case-insensitive); ``None`` uses
                ``default_backend``. Unknown names fall back to OpenCV.
            max_side: the ROI is downscaled so its longer side is at most this
                (VRAM/speed trade-off); the result is scaled back up.
            auto_opencv_if_few: switch to OpenCV when the mask has at most
                ``few_threshold`` connected regions (external contours).
            few_threshold: threshold for ``auto_opencv_if_few``.
            backend_hint: ``"cuda"``/``"cpu"`` hint for the ORT backend; only
                the hint of the first ORT call is used.
            roi_pad_ratio, roi_min_pad: context margin added around the mask's
                bounding box, ``max(roi_min_pad, int(long_side * roi_pad_ratio))``
                pixels, clipped to the image. Without it a box-shaped mask
                fills its whole ROI and the inpainter has no surrounding pixels
                to draw from. Set both to 0 for the former tight ROI.

        Returns:
            A new image in which only masked pixels are replaced. When the
            polygons cover nothing, *image_bgr* itself is returned.
        """
        backend = (backend or self.default_backend).lower()

        # 1) polygons -> mask, ROI with context margin
        mask = polygons_to_mask(image_bgr.shape[:2], polygons)
        bounds = self._roi_bounds(mask, roi_pad_ratio, roi_min_pad)
        if bounds is None:
            return image_bgr

        # 2) few regions -> fast OpenCV path
        if auto_opencv_if_few and backend != InpaintBackends.OPENCV:
            # [-2] picks the contours from both the OpenCV 3 (3-tuple) and
            # OpenCV 4 (2-tuple) return shapes.
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
            if len(contours) <= few_threshold:
                backend = InpaintBackends.OPENCV

        # 3) crop + downscale
        x0, y0, x1, y1 = bounds
        roi_img = image_bgr[y0:y1, x0:x1]
        roi_mask = mask[y0:y1, x0:x1]
        h, w = roi_img.shape[:2]
        if max(h, w) > max_side:
            scale = max_side / float(max(h, w))
            size = (max(1, int(w * scale)), max(1, int(h * scale)))
            # Interpolation passed by keyword: the 3rd positional arg of
            # cv2.resize is `dst`. NEAREST keeps the mask binary.
            roi_img_small = cv2.resize(roi_img, size, interpolation=cv2.INTER_AREA)
            roi_mask_small = cv2.resize(roi_mask, size, interpolation=cv2.INTER_NEAREST)
        else:
            roi_img_small, roi_mask_small = roi_img, roi_mask

        # 4) backend
        out_small = self._run_backend(backend, roi_img_small, roi_mask_small, backend_hint)

        # 5) upscale + composite only the masked pixels
        if out_small.shape[:2] != (h, w):
            out_roi = cv2.resize(out_small, (w, h), interpolation=cv2.INTER_CUBIC)
        else:
            out_roi = out_small

        result = image_bgr.copy()
        take_new = (roi_mask > 0)[:, :, None]
        result[y0:y1, x0:x1] = np.where(take_new, out_roi, roi_img)
        return result

    # ── inpaint helpers ────────────────────────────────────────────
    @staticmethod
    def _roi_bounds(mask: np.ndarray, pad_ratio: float, min_pad: int) -> Optional[Tuple[int, int, int, int]]:
        """Return ``(x0, y0, x1, y1)`` (exclusive ends) around the mask plus a margin.

        Returns ``None`` for an empty mask.
        """
        ys, xs = np.nonzero(mask)
        if xs.size == 0:
            return None
        x_min, x_max = int(xs.min()), int(xs.max())
        y_min, y_max = int(ys.min()), int(ys.max())
        long_side = max(x_max - x_min + 1, y_max - y_min + 1)
        pad = max(0, int(min_pad), int(long_side * pad_ratio))
        img_h, img_w = mask.shape[:2]
        return (max(0, x_min - pad), max(0, y_min - pad),
                min(img_w, x_max + 1 + pad), min(img_h, y_max + 1 + pad))

    def _run_backend(self, backend, img, msk, backend_hint):
        """Dispatch one (already cropped/downscaled) ROI to *backend*."""
        if backend == InpaintBackends.LAMA_TORCH:
            return self._run_lama_torch(img, msk)
        if backend == InpaintBackends.LAMA_TORCH_AMP:
            return self._run_lama_torch_amp(img, msk)
        if backend == InpaintBackends.LAMA_ONNX_FD:
            return self._get_lama_onnx_fd().infer(img, msk)
        if backend == InpaintBackends.LAMA_ONNX_ORT:
            return self._get_lama_onnx_ort(backend_hint=backend_hint).infer(img, msk)
        # OPENCV, and the safe fallback for unknown/placeholder backends.
        return cv2.inpaint(img, msk, 3, cv2.INPAINT_TELEA)

    def _run_lama_torch(self, img, msk):
        """Run SimpleLama on a BGR ROI; returns BGR of the ROI size (or larger)."""
        h, w = img.shape[:2]
        mdl = self._get_lama_torch()
        img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        msk_pil = Image.fromarray(msk, "L")  # guarantees one channel
        out_pil = mdl(img_pil, msk_pil)
        out = cv2.cvtColor(np.array(out_pil), cv2.COLOR_RGB2BGR)
        # NOTE(review): SimpleLama appears to pad bottom/right to a multiple of
        # 8 and return the padded size; crop it instead of stretching it.
        return out[:h, :w]

    def _run_lama_torch_amp(self, img, msk):
        """SimpleLama (fp32 weights) under CUDA autocast, on power-of-two padded input."""
        import torch

        mdl = self._get_lama_torch()
        H, W = img.shape[:2]
        # Pad to powers of two (>= 8, hence multiples of 8) so the FFT-based
        # layers get sizes cuFFT handles well in fp16.
        th, tw = max(8, _next_pow2(H)), max(8, _next_pow2(W))
        pad_info = (0, 0, 0, 0)
        if (th, tw) != (H, W):
            img, pad_info = _reflect_pad_to(img, th, tw)
            msk, _ = _reflect_pad_to(msk, th, tw)

        img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        msk_pil = Image.fromarray(msk, "L")
        use_amp = str(self.lama_device).lower() in ("cuda", "gpu")
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
            out_pil = mdl(img_pil, msk_pil)

        out = cv2.cvtColor(np.array(out_pil), cv2.COLOR_RGB2BGR)[:th, :tw]
        return _crop_by_pad(out, pad_info)
|
|
|
|
import torch, numpy as np
|
|
import torch.nn.functional as F
|
|
from PIL import Image
|
|
|
|
class _SimpleLamaFPCompat:
|
|
"""
|
|
FP16 TorchScript / state_dict 체크포인트를 SimpleLama처럼 호출 가능하게 래핑.
|
|
__call__(image_pil|ndarray, mask_pil|ndarray) -> PIL.Image
|
|
- 가중치가 fp16이면 입력도 fp16으로 자동 캐스팅(AMP 포함)
|
|
- JIT 빌드(enesmsahin big-lama JIT)는 (mask, image) 순서를 기대
|
|
원본 SimpleLama는 (image, mask) 순서 → is_jit 플래그로 분기
|
|
"""
|
|
def __init__(self, model, device="cuda", is_jit=True, is_fp16=True):
|
|
self.model = model.eval()
|
|
self.device = torch.device("cuda" if device in ("cuda","gpu") else device)
|
|
self.is_jit = is_jit
|
|
self.is_fp16 = is_fp16
|
|
self.model.to(self.device)
|
|
if self.is_fp16:
|
|
self.model.half()
|
|
|
|
@classmethod
|
|
def load(cls, ckpt_path: str, device="cuda"):
|
|
# 1) TorchScript 시도
|
|
try:
|
|
m = torch.jit.load(ckpt_path, map_location="cpu")
|
|
# fp16 여부 대략 추정 (파라미터가 없으면 fp16 JIT로 가정)
|
|
is_fp16 = True
|
|
try:
|
|
p = next(m.parameters())
|
|
is_fp16 = (p.dtype == torch.float16)
|
|
except StopIteration:
|
|
pass
|
|
return cls(m, device=device, is_jit=True, is_fp16=is_fp16)
|
|
except Exception:
|
|
pass
|
|
|
|
# 2) state_dict 시도 (원 SimpleLama 구조 필요)
|
|
from simple_lama_inpainting import SimpleLama
|
|
base = SimpleLama(device="cpu")
|
|
sd = torch.load(ckpt_path, map_location="cpu")
|
|
core = getattr(base, "model", base)
|
|
core.load_state_dict(sd, strict=False)
|
|
is_fp16 = any(p.dtype == torch.float16 for p in core.parameters())
|
|
return cls(base, device=device, is_jit=False, is_fp16=is_fp16)
|
|
|
|
# ---------- 유틸 ----------
|
|
@staticmethod
|
|
def _to_pil(x, mode=None):
|
|
if isinstance(x, Image.Image):
|
|
return x.convert(mode) if mode else x
|
|
if isinstance(x, np.ndarray):
|
|
if mode == "L":
|
|
if x.ndim == 2:
|
|
return Image.fromarray(x.astype(np.uint8), "L")
|
|
return Image.fromarray(x[..., 0].astype(np.uint8), "L")
|
|
if x.ndim == 3 and x.shape[2] == 3: # BGR -> RGB
|
|
x = x[..., ::-1]
|
|
return Image.fromarray(x.astype(np.uint8), "RGB")
|
|
raise TypeError(f"Unsupported input type: {type(x)}")
|
|
|
|
@staticmethod
|
|
def _to_numpy_rgb(img: Image.Image) -> np.ndarray:
|
|
if img.mode != "RGB":
|
|
img = img.convert("RGB")
|
|
arr = np.asarray(img, dtype=np.uint8)
|
|
if not arr.flags['C_CONTIGUOUS']:
|
|
arr = np.ascontiguousarray(arr)
|
|
return arr
|
|
|
|
@staticmethod
|
|
def _to_numpy_mask1(mask: Image.Image) -> np.ndarray:
|
|
if mask.mode != "L":
|
|
mask = mask.convert("L")
|
|
m = np.asarray(mask, dtype=np.uint8)
|
|
if not m.flags['C_CONTIGUOUS']:
|
|
m = np.ascontiguousarray(m)
|
|
return (m > 127).astype(np.float32) # 0/1
|
|
|
|
@staticmethod
|
|
def _pad8_reflect(t: torch.Tensor, target_dtype: torch.dtype):
|
|
h, w = t.shape[-2:]
|
|
nh = (h + 7) // 8 * 8
|
|
nw = (w + 7) // 8 * 8
|
|
if nh == h and nw == w:
|
|
return t, (0,0,0,0)
|
|
ph, pw = nh - h, nw - w
|
|
t32 = t.to(torch.float32)
|
|
t32 = F.pad(t32, (0, pw, 0, ph), mode="reflect") # reflect는 fp16 미지원 버전 존재
|
|
return t32.to(target_dtype), (0, pw, 0, ph)
|
|
|
|
# ---------- 호출 ----------
|
|
@torch.inference_mode()
|
|
def __call__(self, image: Image.Image, mask: Image.Image) -> Image.Image:
|
|
# 모델 dtype/디바이스
|
|
try:
|
|
p0 = next(self.model.parameters())
|
|
target_dtype = p0.dtype
|
|
device = p0.device
|
|
except StopIteration:
|
|
target_dtype = torch.float16 if self.device.type == "cuda" and self.is_fp16 else torch.float32
|
|
device = self.device
|
|
|
|
# numpy → tensor
|
|
img_np = self._to_numpy_rgb(self._to_pil(image, "RGB")) # H,W,3 uint8
|
|
msk_np = self._to_numpy_mask1(self._to_pil(mask, "L")) # H,W float32 {0,1}
|
|
|
|
img_t = torch.from_numpy(img_np).permute(2,0,1).unsqueeze(0).to(device=device, dtype=torch.float32) / 255.0 # 1,3,H,W
|
|
msk_t = torch.from_numpy(msk_np).unsqueeze(0).unsqueeze(0).to(device=device, dtype=torch.float32) # 1,1,H,W
|
|
|
|
# pad (fp32) → target dtype
|
|
img_t, pad_hw = self._pad8_reflect(img_t, torch.float32)
|
|
msk_t, _ = self._pad8_reflect(msk_t, torch.float32)
|
|
img_t = img_t.to(dtype=target_dtype)
|
|
msk_t = msk_t.to(dtype=target_dtype)
|
|
|
|
# 호출 순서 분기
|
|
if self.is_jit:
|
|
# JIT big-lama는 (mask, image) 순서
|
|
out = self.model(msk_t, img_t)
|
|
else:
|
|
# 원 SimpleLama는 (image, mask) 순서
|
|
out = self.model(img_t, msk_t)
|
|
|
|
# unpad 및 to PIL
|
|
_, _, H, W = img_t.shape
|
|
_, pw, _, ph = pad_hw
|
|
if ph or pw:
|
|
out = out[..., :H-ph, :W-pw]
|
|
out = out.clamp(0, 1).to(torch.float32)
|
|
out_np = (out[0].permute(1,2,0).cpu().numpy() * 255.0 + 0.5).astype(np.uint8)
|
|
return Image.fromarray(out_np, "RGB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os, torch, torch.nn.functional as F, numpy as np
|
|
from PIL import Image
|
|
|
|
def _to_pil_rgb(x):
    """Coerce *x* to an RGB ``PIL.Image``.

    PIL images are converted with ``.convert("RGB")``. A 3-channel ndarray is
    treated as BGR and channel-flipped. Other arrays go through
    ``Image.fromarray`` unchanged. Arrays are cast to uint8 first.

    Raises:
        TypeError: if *x* is neither a PIL image nor an ndarray.
    """
    if isinstance(x, Image.Image):
        return x.convert("RGB")
    if not isinstance(x, np.ndarray):
        raise TypeError(f"unsupported image type: {type(x)}")
    looks_bgr = x.ndim == 3 and x.shape[2] == 3
    arr = x[..., ::-1] if looks_bgr else x
    return Image.fromarray(arr.astype(np.uint8)).convert("RGB")
|
|
|
|
def _to_pil_maskL(x):
    """Coerce *x* to a single-channel (mode ``"L"``) ``PIL.Image``.

    For a 3-D ndarray only channel 0 is kept. Arrays are cast to uint8 first.

    Raises:
        TypeError: if *x* is neither a PIL image nor an ndarray.
    """
    if isinstance(x, Image.Image):
        return x.convert("L")
    if isinstance(x, np.ndarray):
        plane = x[..., 0] if x.ndim == 3 else x
        return Image.fromarray(plane.astype(np.uint8)).convert("L")
    raise TypeError(f"unsupported mask type: {type(x)}")
|
|
|
|
def _pad_mod8_reflect_nchw(t: torch.Tensor):
|
|
# t: NCHW (float32)
|
|
_, _, h, w = t.shape
|
|
nh = (h + 7) // 8 * 8
|
|
nw = (w + 7) // 8 * 8
|
|
if nh == h and nw == w:
|
|
return t, (0,0,0,0)
|
|
ph, pw = nh - h, nw - w
|
|
top = ph // 2; bottom = ph - top
|
|
left = pw // 2; right = pw - left
|
|
t32 = F.pad(t, (left, right, top, bottom), mode="reflect")
|
|
return t32, (top, bottom, left, right)
|
|
|
|
def _crop_from_pad_nchw(t: torch.Tensor, pad):
|
|
top, bottom, left, right = pad
|
|
if top==bottom==left==right==0:
|
|
return t
|
|
return t[..., top:t.shape[-2]-bottom, left:t.shape[-1]-right]
|
|
|
|
def _detect_arg_order(script_module) -> str:
|
|
"""
|
|
TorchScript LaMa(JIT)의 forward 인자 순서 추정.
|
|
- enesmsahin big-lama.pt: (mask, image) 가 일반적.
|
|
- 안전하게 스키마/코드에서 먼저 감지, 실패 시 'mask_im' 기본.
|
|
"""
|
|
try:
|
|
sch = str(getattr(script_module, "forward").schema).lower()
|
|
if "tensor mask" in sch and "tensor image" in sch:
|
|
return "mask_im" if sch.index("tensor mask") < sch.index("tensor image") else "im_mask"
|
|
except Exception:
|
|
pass
|
|
code = getattr(script_module, "code", "")
|
|
if isinstance(code, str):
|
|
if "forward(mask, image" in code.replace(" ", ""):
|
|
return "mask_im"
|
|
if "forward(image, mask" in code.replace(" ", ""):
|
|
return "im_mask"
|
|
# 기본값
|
|
return os.getenv("SIMPLE_LAMA_ARG_ORDER", "mask_im").lower()
|
|
|
|
class LamaTorchAMP:
    """LaMa TorchScript model with fp32 weights, run under CUDA autocast (fp16).

    Inputs are converted to RGB in [0, 1] plus a 1-channel {0, 1} mask (NCHW).
    They are reflect-padded to multiples of 8, and the padding is removed from
    the output.
    """

    def __init__(self, device="cuda", ckpt_path: str|None=None):
        """
        Args:
            device: ``"cuda"``, ``"cuda:N"`` or ``"gpu"`` (alias of cuda) fall
                back to CPU when CUDA is unavailable; other strings (e.g.
                ``"cpu"``) are passed to ``torch.device`` as-is.
            ckpt_path: TorchScript checkpoint. Resolution order: this argument,
                then ``$SIMPLE_LAMA_CKPT``, then simple-lama's default download
                (also used when the chosen path is not an existing file).
        """
        wants_cuda = device == "gpu" or str(device).startswith("cuda")
        if wants_cuda:
            if torch.cuda.is_available():
                self.device = torch.device("cuda" if device == "gpu" else device)
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)
        # (Previously self.device was re-assigned later without the CUDA
        # availability check, which crashed on CPU-only hosts.)

        if ckpt_path is None:
            ckpt_path = os.getenv("SIMPLE_LAMA_CKPT")
        if ckpt_path is None or not os.path.isfile(ckpt_path):
            # Reuse simple-lama's downloader.
            from simple_lama_inpainting.utils.util import download_model
            from simple_lama_inpainting.models.model import LAMA_MODEL_URL
            ckpt_path = download_model(LAMA_MODEL_URL)

        m = torch.jit.load(ckpt_path, map_location="cpu").eval()
        try:
            m = m.to(dtype=torch.float32)
        except Exception:
            # Some TorchScript modules reject .to(dtype=...): upcast the
            # floating-point tensors by hand, leaving integer buffers untouched.
            for p in m.parameters(recurse=True):
                if p.is_floating_point() and p.dtype != torch.float32:
                    p.data = p.data.float()
            for b in m.buffers(recurse=True):
                if b.is_floating_point() and b.dtype != torch.float32:
                    b.data = b.data.float()

        self.model = m.to(self.device)  # weights stay fp32
        self.arg_order = _detect_arg_order(m)
        if self.device.type == "cuda":
            # Process-wide setting: lets cuDNN benchmark and pick conv algorithms.
            torch.backends.cudnn.benchmark = True

    @torch.inference_mode()
    def __call__(self, image: Image.Image|np.ndarray, mask: Image.Image|np.ndarray) -> Image.Image:
        """Inpaint *image* where *mask* > 127; returns an RGB PIL image of the input size.

        3-channel ndarray images are treated as BGR (see ``_to_pil_rgb``).
        """
        import contextlib

        im_np = np.asarray(_to_pil_rgb(image), dtype=np.uint8)
        mk_np = np.asarray(_to_pil_maskL(mask), dtype=np.uint8)

        im_t = torch.from_numpy(im_np).permute(2, 0, 1).unsqueeze(0).to(self.device, dtype=torch.float32) / 255.0  # 1,3,H,W
        mk_t = torch.from_numpy((mk_np > 127).astype(np.float32))[None, None].to(self.device, dtype=torch.float32)  # 1,1,H,W

        # Pad to multiples of 8 in float32.
        im_t, pad = _pad_mod8_reflect_nchw(im_t)
        mk_t, _ = _pad_mod8_reflect_nchw(mk_t)

        args = (mk_t, im_t) if self.arg_order == "mask_im" else (im_t, mk_t)
        amp = (torch.autocast(device_type="cuda", dtype=torch.float16)
               if self.device.type == "cuda" else contextlib.nullcontext())
        with amp:
            out = self.model(*args)

        out = out[0] if isinstance(out, (list, tuple)) else out  # NCHW
        out = _crop_from_pad_nchw(out, pad).clamp(0, 1).to(torch.float32)
        out_np = (out[0].permute(1, 2, 0).cpu().numpy() * 255.0 + 0.5).astype(np.uint8)
        return Image.fromarray(out_np, "RGB")
|