ImageProcessor_MainServer/worker/inpaint_module.py

244 lines
10 KiB
Python

# -*- coding: utf-8 -*-
from __future__ import annotations
import os, cv2, numpy as np
from typing import List, Tuple, Optional
from PIL import Image
# ── (optional) LaMa backend: this module still imports when simple-lama-inpainting is unavailable.
try:
    from simple_lama_inpainting.models.model import SimpleLama
except Exception:  # any import-time failure, not only ImportError, just disables LaMa
    _HAVE_LAMA = False
else:
    _HAVE_LAMA = True
class InpaintBackends:
    """String identifiers for the inpainting backends used by ``Inpainter``.

    These are plain ``str`` class attributes (not an Enum), so they compare
    equal to the lower-cased backend names that callers pass in.
    """

    OPENCV: str = "opencv"  # OpenCV cv2.inpaint (Telea)
    LAMA: str = "lama"  # SimpleLama model
    LAMA_TORCH_AMP: str = "lama_torch_amp"  # placeholder name; dispatched like LAMA
# ── Common utilities
def polygons_to_mask(shape: Tuple[int, ...], polygons: List[List[List[int]]]) -> np.ndarray:
    """Rasterize polygons into a binary ``uint8`` mask.

    Args:
        shape: ``(height, width)`` of the output mask. Any extra trailing
            entries (e.g. the channel count in ``img.shape``) are ignored, so an
            image's ``.shape`` can be passed directly.
        polygons: Polygons given as sequences of ``[x, y]`` vertices. Each
            polygon is filled on its own with value 255. Empty polygons are
            skipped.

    Returns:
        A ``(height, width)`` ``uint8`` array that is 255 inside any polygon and
        0 elsewhere.
    """
    h, w = shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    for poly in polygons:
        pts = np.array(poly, dtype=np.int32).reshape(-1, 2)
        if pts.size == 0:
            # Nothing to rasterize; don't hand a zero-point contour to OpenCV.
            continue
        cv2.fillPoly(mask, [pts], 255)
    return mask
def resize_long_side(img: np.ndarray, max_side: int) -> Tuple[np.ndarray, float]:
    """Shrink *img* so that its longer side is at most ``max_side`` pixels.

    Images that already fit are returned as-is (the same object, not a copy).
    Downscaling uses ``cv2.INTER_AREA``. The short side is rounded and kept at
    least 1 pixel, so extreme aspect ratios never produce an empty target size.

    Returns:
        ``(resized, scale)``. ``scale`` is ``max_side / long_side``, or ``1.0``
        when no resize happened. Because of rounding, the short side's
        effective scale can differ slightly from ``scale``.
    """
    h, w = img.shape[:2]
    long_side = max(h, w)
    if long_side <= max_side:
        return img, 1.0
    s = max_side / float(long_side)
    if h >= w:
        nh, nw = max_side, max(1, int(round(w * s)))
    else:
        nh, nw = max(1, int(round(h * s))), max_side
    out = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
    return out, s
def _soften_mask(mask: np.ndarray, *, dilate_px: int, blur_px: int) -> np.ndarray:
    """Return a feathered copy of *mask* for alpha blending.

    The mask is first grown by an elliptical dilation of radius ``dilate_px``
    (skipped when ``dilate_px <= 0``). It is then Gaussian-blurred with a square
    kernel whose size is ``blur_px`` forced to be odd. The input is not modified.
    """
    grown = mask
    if dilate_px > 0:
        side = 2 * dilate_px + 1
        disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
        grown = cv2.dilate(mask, disk, iterations=1)
    ksize = blur_px | 1  # GaussianBlur needs an odd kernel size
    return cv2.GaussianBlur(grown, (ksize, ksize), 0)
# ── Component / ROI utilities
def _connected_components(mask: np.ndarray, *, min_area: int = 30) -> List[Tuple[int,int,int,int]]:
    """Return ``(x, y, w, h)`` boxes of the 8-connected non-zero blobs in *mask*.

    The background label 0 is excluded. Blobs narrower or shorter than 2 px, or
    with fewer than ``min_area`` pixels, are dropped.
    """
    binary = (mask > 0).astype(np.uint8)
    count, _labels, stats, _centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
    return [
        (int(bx), int(by), int(bw), int(bh))
        for bx, by, bw, bh, area in stats[1:count]
        if bw > 1 and bh > 1 and area >= min_area
    ]
def _expand_box(b: Tuple[int,int,int,int], pad_ratio: float, W: int, H: int) -> Tuple[int,int,int,int]:
x, y, w, h = b
pad = int(round(max(w, h) * pad_ratio))
x0 = max(0, x - pad); y0 = max(0, y - pad)
x1 = min(W, x + w + pad); y1 = min(H, y + h + pad)
return x0, y0, x1 - x0, y1 - y0
def _min_gap(a: Tuple[int,int,int,int], b: Tuple[int,int,int,int]) -> int:
ax, ay, aw, ah = a; bx, by, bw, bh = b
ar, ab = ax + aw, ay + ah
br, bb = bx + bw, by + bh
dx = max(0, max(ax - br, bx - ar))
dy = max(0, max(ay - bb, by - ab))
return max(dx, dy) # L∞ gap
def _merge_close_boxes(boxes: List[Tuple[int,int,int,int]], *, thresh_px: int) -> List[Tuple[int,int,int,int]]:
if not boxes: return []
n = len(boxes)
parent = list(range(n))
def find(i):
while parent[i] != i:
parent[i] = parent[parent[i]]
i = parent[i]
return i
def union(i, j):
ri, rj = find(i), find(j)
if ri != rj: parent[rj] = ri
for i in range(n):
for j in range(i+1, n):
if _min_gap(boxes[i], boxes[j]) <= thresh_px:
union(i, j)
groups = {}
for i, b in enumerate(boxes):
r = find(i)
groups.setdefault(r, []).append(b)
merged = []
for grp in groups.values():
xs = [x for x,_,_,_ in grp]; ys = [y for _,y,_,_ in grp]
rs = [x+w for x,_,w,_ in grp]; bs = [y+h for _,y,_,h in grp]
x0, y0, x1, y1 = min(xs), min(ys), max(rs), max(bs)
merged.append((x0, y0, x1 - x0, y1 - y0))
return merged
class Inpainter:
    """Erase polygon-masked regions (e.g. text) from an image with OpenCV or LaMa.

    The LaMa model is created lazily on first use and then reused by this
    instance.
    """

    def __init__(self, logger=None,
                 default_backend: str = InpaintBackends.LAMA,
                 lama_device: str = "cuda"):
        """
        Args:
            logger: Optional object with a ``log(msg)`` method. When it is
                missing or falsy, messages go to ``print``.
            default_backend: Backend name used when ``inpaint`` receives
                ``backend=None``.
            lama_device: Device string passed to ``SimpleLama(device=...)``.
        """
        self.logger = logger
        self.default_backend = default_backend
        self.lama_device = lama_device
        self._lama: Optional[SimpleLama] = None  # built on the first _get_lama() call

    def _log(self, msg):
        """Send *msg* to ``self.logger.log`` when available, otherwise ``print`` it."""
        if self.logger and hasattr(self.logger, "log"):
            self.logger.log(msg)
        else:
            print(msg)

    def _get_lama(self):
        """Return the cached SimpleLama model, creating it on the first call.

        Raises:
            RuntimeError: simple-lama-inpainting could not be imported.
        """
        if not _HAVE_LAMA:
            raise RuntimeError("SimpleLama not installed")
        if self._lama is None:
            self._log("Init SimpleLama...")
            self._lama = SimpleLama(device=self.lama_device)
        return self._lama

    def _opencv_text_inpaint(self, img_bgr: np.ndarray, hard_mask: np.ndarray,
                             r1: int = 3, r2: int = 7) -> np.ndarray:
        """Two-pass OpenCV Telea inpaint (original note: works well on flat backgrounds).

        Pass 1 inpaints ``hard_mask`` with radius ``r1``. Pass 2 re-inpaints, with
        the larger radius ``r2``, only the masked pixels whose value changed by
        more than 3 (max absolute difference over channels) in pass 1. Masked
        pixels that barely changed are left out of the pass-2 mask.
        """
        out1 = cv2.inpaint(img_bgr, hard_mask, r1, cv2.INPAINT_TELEA)
        # NOTE(review): the intent of the ">3 changed" criterion is not evident from
        # this code alone — confirm before tuning.
        changed = np.abs(out1.astype(np.int16) - img_bgr.astype(np.int16)).max(axis=2) > 3
        remain = (hard_mask > 0) & changed
        if not remain.any():
            return out1
        return cv2.inpaint(out1, remain.astype(np.uint8) * 255, r2, cv2.INPAINT_TELEA)

    def _run_backend(self, roi_img: np.ndarray, roi_mask: np.ndarray, backend: str) -> np.ndarray:
        """Inpaint *roi_img* (BGR) where *roi_mask* is non-zero using *backend*.

        The returned image may be larger than the input for LaMa; see
        ``_fit_output``.

        Raises:
            RuntimeError: a LaMa backend was requested but is not installed.
            NotImplementedError: *backend* is not a known backend name.
        """
        if backend == InpaintBackends.OPENCV:
            return self._opencv_text_inpaint(roi_img, roi_mask)
        elif backend in (InpaintBackends.LAMA, InpaintBackends.LAMA_TORCH_AMP):
            lama = self._get_lama()
            dst_pil = lama(Image.fromarray(cv2.cvtColor(roi_img, cv2.COLOR_BGR2RGB)),
                           Image.fromarray(roi_mask, "L"))
            return cv2.cvtColor(np.array(dst_pil), cv2.COLOR_RGB2BGR)
        else:
            raise NotImplementedError(f"Backend {backend} not wired.")

    @staticmethod
    def _resolve_backend(backend: str, *, prefer_opencv: bool) -> str:
        """Map a requested backend name to the backend that actually runs.

        Returns OpenCV when it was requested or when *prefer_opencv* is set.
        Every other name runs LaMa, including the ``lama_torch_amp`` placeholder
        and unknown names.
        """
        if prefer_opencv or backend == InpaintBackends.OPENCV:
            return InpaintBackends.OPENCV
        return InpaintBackends.LAMA

    @staticmethod
    def _fit_output(dst: np.ndarray, shape_hw: Tuple[int, ...]) -> np.ndarray:
        """Make a backend result match the ``(h, w)`` it was asked to inpaint.

        An output at least as large as the input in both dimensions is cropped
        from the top-left corner. Any other size mismatch falls back to a cubic
        resize.
        """
        h, w = shape_hw[:2]
        dh, dw = dst.shape[:2]
        if (dh, dw) == (h, w):
            return dst
        if dh >= h and dw >= w:
            # NOTE(review): assumes the backend pads on the bottom/right, as
            # simple-lama-inpainting does when it pads inputs to a multiple of 8.
            # Cropping keeps pixels aligned, whereas resizing would shift them.
            return dst[:h, :w]
        return cv2.resize(dst, (w, h), interpolation=cv2.INTER_CUBIC)

    @staticmethod
    def _downscale_mask(mask: np.ndarray, size_wh: Tuple[int, int]) -> np.ndarray:
        """Downscale a 0/255 mask conservatively.

        Any output pixel with non-zero coverage becomes 255. This keeps thin
        strokes that nearest-neighbour sampling could drop.
        """
        shrunk = cv2.resize(mask, size_wh, interpolation=cv2.INTER_AREA)
        return np.where(shrunk > 0, 255, 0).astype(np.uint8)

    @staticmethod
    def _alpha_blend(base: np.ndarray, patch: np.ndarray, soft: np.ndarray) -> np.ndarray:
        """Blend *patch* over *base* using ``soft / 255`` as a per-pixel alpha.

        Assumes *base* and *patch* are ``(h, w, C)`` and *soft* is ``(h, w)``
        ``uint8``. Pixels with alpha 0 come out exactly equal to *base*.
        """
        alpha = (soft.astype(np.float32) / 255.0)[..., None]
        mixed = alpha * patch.astype(np.float32) + (1.0 - alpha) * base.astype(np.float32)
        return np.clip(np.rint(mixed), 0, 255).astype(np.uint8)

    @staticmethod
    def _merged_regions(base_mask: np.ndarray, *, comp_min_area: int,
                        merge_thresh_factor: float,
                        merge_abs_min_px: int) -> List[Tuple[int, int, int, int]]:
        """Group the mask's connected components into merged ``(x, y, w, h)`` regions.

        Components closer than ``max(merge_abs_min_px, median_height *
        merge_thresh_factor)`` pixels are merged into one region.
        """
        boxes = _connected_components(base_mask, min_area=comp_min_area)
        if not boxes:
            return []
        med_h = float(np.median([bh for _, _, _, bh in boxes]))
        merge_px = max(merge_abs_min_px, int(round(med_h * merge_thresh_factor)))
        return _merge_close_boxes(boxes, thresh_px=merge_px)

    def _inpaint_at_scale(self, img: np.ndarray, mask: np.ndarray,
                          backend: str, max_side: int) -> np.ndarray:
        """Run *backend* with the long side capped at ``max_side`` and return the result at *img*'s size.

        The result is the raw backend output (not blended). When a downscale
        happened, the output is brought back to full size with cubic
        interpolation.
        """
        h, w = img.shape[:2]
        small, s = resize_long_side(img, max_side)
        if s != 1.0:
            small_mask = self._downscale_mask(mask, (small.shape[1], small.shape[0]))
        else:
            small_mask = mask
        dst = self._fit_output(self._run_backend(small, small_mask, backend), small.shape[:2])
        if s != 1.0:
            dst = cv2.resize(dst, (w, h), interpolation=cv2.INTER_CUBIC)
        return dst

    def inpaint(self, img_bgr: np.ndarray, polygons: List[List[List[int]]],
                *,
                backend: Optional[str] = None,
                # common
                roi_strategy: str = "components",  # "components" | "full"
                max_side: int = 1600,
                auto_opencv_if_few: bool = False,
                few_threshold: int = 0,
                # region grouping (components mode; also used for the auto-OpenCV count)
                comp_min_area: int = 30,
                pad_ratio: float = 0.12,
                merge_thresh_factor: float = 0.7,
                merge_abs_min_px: int = 8,
                soft_dilate_px: int = 10,
                soft_blur_px: int = 17,
                # debug output
                debug_save_rois: bool = False,
                debug_dir: Optional[str] = None,
                request_id: Optional[str] = None) -> np.ndarray:
        """Inpaint the area covered by ``polygons`` and return a new image.

        ``img_bgr`` is not modified. The inpainted result is alpha-blended in
        using a feathered mask (dilated by ``soft_dilate_px``, blurred by
        ``soft_blur_px``, both in full-resolution pixels). Pixels outside that
        feathered mask are copied from the input unchanged. The image is assumed
        to be 3-channel BGR: the LaMa path converts BGR->RGB and blending
        broadcasts a per-pixel alpha over channels.

        Args:
            img_bgr: Source image.
            polygons: Regions to erase, as lists of ``[x, y]`` vertices.
            backend: ``"opencv"`` or a LaMa name; defaults to
                ``self.default_backend``. Any name other than ``"opencv"`` runs
                LaMa.
            roi_strategy: ``"full"`` runs the backend once on the whole
                (downscaled) image. Any other value runs it per padded region of
                merged connected components.
            max_side: Long-side cap applied to each backend input.
            auto_opencv_if_few: Use OpenCV when the number of merged regions is
                ``<= few_threshold``.
            few_threshold: See ``auto_opencv_if_few``.
            comp_min_area: Minimum component area kept when grouping regions.
            pad_ratio: ROI padding as a fraction of the region's longer side.
            merge_thresh_factor: Merge distance as a fraction of the median
                component height.
            merge_abs_min_px: Lower bound on the merge distance.
            soft_dilate_px: Dilation radius of the blend mask.
            soft_blur_px: Gaussian kernel size of the blend mask (made odd).
            debug_save_rois: Write per-ROI PNGs (input, mask, soft mask, result)
                to ``debug_dir`` (components mode only).
            debug_dir: Directory for debug PNGs; created if missing.
            request_id: Filename prefix for debug PNGs (default ``"req"``).

        Returns:
            The inpainted image, with the same shape as ``img_bgr``.

        Raises:
            RuntimeError: LaMa was selected but simple-lama-inpainting is not
                installed.
        """
        backend = (backend or self.default_backend).lower()
        H, W = img_bgr.shape[:2]
        base_mask = polygons_to_mask((H, W), polygons)
        region_kwargs = dict(comp_min_area=comp_min_area,
                             merge_thresh_factor=merge_thresh_factor,
                             merge_abs_min_px=merge_abs_min_px)

        # ── Full-frame mode: one backend call on the whole (downscaled) image.
        if roi_strategy == "full":
            if not base_mask.any():
                return img_bgr.copy()
            # Only group regions when the count actually matters (short-circuit).
            prefer_opencv = bool(auto_opencv_if_few) and (
                len(self._merged_regions(base_mask, **region_kwargs)) <= few_threshold)
            use_backend = self._resolve_backend(backend, prefer_opencv=prefer_opencv)
            dst = self._inpaint_at_scale(img_bgr, base_mask, use_backend, max_side)
            soft = _soften_mask(base_mask, dilate_px=soft_dilate_px, blur_px=soft_blur_px)
            out = img_bgr.copy()
            # Blend only inside the soft mask's bounding box to keep float buffers small.
            x, y, w, h = cv2.boundingRect(soft)
            if w > 0 and h > 0:
                win = (slice(y, y + h), slice(x, x + w))
                out[win] = self._alpha_blend(img_bgr[win], dst[win], soft[win])
            return out

        # ── Component-based ROI mode
        merged = self._merged_regions(base_mask, **region_kwargs)
        if not merged:
            return img_bgr.copy()
        use_backend = self._resolve_backend(
            backend, prefer_opencv=bool(auto_opencv_if_few) and len(merged) <= few_threshold)
        # Process roughly top-to-bottom in 32-px bands, then left-to-right.
        rois = sorted((_expand_box(b, pad_ratio, W, H) for b in merged),
                      key=lambda r: (r[1] // 32, r[0]))

        save_debug = bool(debug_save_rois and debug_dir)
        if save_debug:
            os.makedirs(debug_dir, exist_ok=True)

        out = img_bgr.copy()
        save_idx = 0
        for (x, y, w, h) in rois:
            if w <= 1 or h <= 1:
                continue
            # Snapshot rather than view: out[...] is overwritten below, and both the
            # blend base and the "_img" debug dump must show the pre-inpaint pixels.
            # Overlapping ROIs still see results of earlier ROIs via `out`.
            roi_img = out[y:y + h, x:x + w].copy()
            roi_mask = base_mask[y:y + h, x:x + w]
            roi_soft = _soften_mask(roi_mask, dilate_px=soft_dilate_px, blur_px=soft_blur_px)
            dst = self._inpaint_at_scale(roi_img, roi_mask, use_backend, max_side)
            # Blend at full resolution so pixels outside the soft mask stay exact.
            dst_roi = self._alpha_blend(roi_img, dst, roi_soft)
            out[y:y + h, x:x + w] = dst_roi

            # ── Intermediate dump (original / mask / soft mask / result)
            if save_debug:
                base = f"{request_id or 'req'}_roi{save_idx:02d}"
                cv2.imwrite(os.path.join(debug_dir, base + "_img.png"), roi_img)
                cv2.imwrite(os.path.join(debug_dir, base + "_mask.png"), roi_mask)
                cv2.imwrite(os.path.join(debug_dir, base + "_soft.png"), roi_soft)
                cv2.imwrite(os.path.join(debug_dir, base + "_dst.png"), dst_roi)
                save_idx += 1
        return out