244 lines
10 KiB
Python
244 lines
10 KiB
Python
# -*- coding: utf-8 -*-
|
|
from __future__ import annotations
|
|
import os, cv2, numpy as np
|
|
from typing import List, Tuple, Optional
|
|
from PIL import Image
|
|
|
|
# -- (optional) LaMa backend
# _HAVE_LAMA records whether the SimpleLama class could be imported; when the
# import fails for any reason the name SimpleLama is simply left undefined.
try:
    from simple_lama_inpainting.models.model import SimpleLama
except Exception:
    _HAVE_LAMA = False
else:
    _HAVE_LAMA = True
|
|
|
|
class InpaintBackends:
    """String identifiers for the available inpainting backends.

    The values are plain lowercase strings so they can be compared directly
    against user-supplied backend names.
    """

    # Classical OpenCV inpainting.
    OPENCV = "opencv"

    # LaMa model via the optional simple_lama_inpainting package.
    LAMA = "lama"

    LAMA_TORCH_AMP = "lama_torch_amp"  # placeholder
|
|
|
|
# ── 공통 유틸
|
|
def polygons_to_mask(shape: Tuple[int,int], polygons: List[List[List[int]]]) -> np.ndarray:
    """Rasterize polygons into a single-channel uint8 mask.

    Args:
        shape: ``(height, width)`` of the mask to create.
        polygons: one entry per polygon; each entry is converted to int32 and
            reshaped to ``(-1, 2)``, i.e. a sequence of 2-D points handed to
            ``cv2.fillPoly``.

    Returns:
        A ``(height, width)`` uint8 array that is 255 inside any polygon and
        0 elsewhere. Polygons that contain no points are skipped.
    """
    h, w = shape
    mask = np.zeros((h, w), dtype=np.uint8)
    for poly in polygons:
        pts = np.array(poly, dtype=np.int32).reshape(-1, 2)
        if pts.shape[0] == 0:
            # Nothing to draw; avoid handing OpenCV an empty point array.
            continue
        # Each polygon is filled in its own call so that overlapping polygons
        # are unioned rather than combined inside one multi-contour fill.
        cv2.fillPoly(mask, [pts], 255)
    return mask
|
|
|
|
def resize_long_side(img: np.ndarray, max_side: int) -> Tuple[np.ndarray, float]:
    """Shrink ``img`` so that its longer side equals ``max_side``.

    Args:
        img: image array whose first two axes are (height, width).
        max_side: maximum allowed length of the longer side, in pixels.

    Returns:
        ``(image, scale)``. When the longer side already fits, the input
        object itself is returned (no copy) with scale ``1.0``. Otherwise a
        resized image (``cv2.INTER_AREA``) and the factor
        ``max_side / longer_side`` are returned.
    """
    h, w = img.shape[:2]
    long_side = max(h, w)
    if long_side <= max_side:
        return img, 1.0

    s = max_side / float(long_side)
    # The short side is clamped to at least one pixel: for very elongated
    # images it would otherwise round to 0, which is not a valid target size.
    if h >= w:
        nh, nw = max_side, max(1, int(round(w * s)))
    else:
        nw, nh = max_side, max(1, int(round(h * s)))

    out = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
    return out, s
|
|
|
|
def _soften_mask(mask: np.ndarray, *, dilate_px: int, blur_px: int) -> np.ndarray:
    """Return a feathered copy of ``mask`` for use as a blending weight.

    The mask is optionally grown with an elliptical structuring element of
    radius ``dilate_px`` and then Gaussian-blurred. The input is not modified.

    Args:
        mask: single-channel mask image.
        dilate_px: dilation radius in pixels; values <= 0 skip the dilation.
        blur_px: Gaussian kernel size; the lowest bit is forced on so the
            kernel size passed to OpenCV is odd.
    """
    softened = mask.copy()

    if dilate_px > 0:
        side = 2 * dilate_px + 1
        element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
        softened = cv2.dilate(softened, element, iterations=1)

    ksize = blur_px | 1  # make the kernel size odd
    return cv2.GaussianBlur(softened, (ksize, ksize), 0)
|
|
|
|
# ── 컴포넌트/ROI 유틸
|
|
def _connected_components(mask: np.ndarray, *, min_area: int = 30) -> List[Tuple[int,int,int,int]]:
    """List bounding boxes of the 8-connected non-zero regions of ``mask``.

    Args:
        mask: array whose positive entries are treated as foreground.
        min_area: minimum component area (as reported by OpenCV's component
            statistics) for a component to be kept.

    Returns:
        ``(x, y, w, h)`` tuples of Python ints, in label order, for components
        that are more than one pixel wide and tall and meet ``min_area``.
    """
    foreground = (mask > 0).astype(np.uint8)
    count, _labels, stats, _centroids = cv2.connectedComponentsWithStats(
        foreground, connectivity=8
    )
    # Row 0 of the stats table is skipped (labels start at 1 in this scan).
    return [
        (int(x), int(y), int(w), int(h))
        for x, y, w, h, area in stats[1:count]
        if w > 1 and h > 1 and area >= min_area
    ]
|
|
|
|
def _expand_box(b: Tuple[int,int,int,int], pad_ratio: float, W: int, H: int) -> Tuple[int,int,int,int]:
    """Grow an ``(x, y, w, h)`` box on all sides and clip it to the image.

    The padding is ``round(max(w, h) * pad_ratio)`` pixels, applied equally to
    every side; the result is clamped to ``[0, W] x [0, H]``.

    Returns:
        The padded, clipped box as ``(x, y, w, h)``.
    """
    x, y, w, h = b
    longest = max(w, h)
    pad = int(round(longest * pad_ratio))

    left = max(0, x - pad)
    top = max(0, y - pad)
    right = min(W, x + w + pad)
    bottom = min(H, y + h + pad)

    return left, top, right - left, bottom - top
|
|
|
|
def _min_gap(a: Tuple[int,int,int,int], b: Tuple[int,int,int,int]) -> int:
    """Chebyshev (L-infinity) gap between two ``(x, y, w, h)`` boxes.

    Returns 0 when the boxes overlap or touch on both axes; otherwise the
    larger of the horizontal and vertical separations.
    """
    ax, ay, aw, ah = a
    bx, by, bw, bh = b

    # Signed separation per axis: positive only if the boxes are disjoint
    # along that axis, negative or zero when their extents overlap/touch.
    sep_x = max(ax - (bx + bw), bx - (ax + aw))
    sep_y = max(ay - (by + bh), by - (ay + ah))

    return max(0, sep_x, sep_y)
|
|
|
|
def _merge_close_boxes(boxes: List[Tuple[int,int,int,int]], *, thresh_px: int) -> List[Tuple[int,int,int,int]]:
    """Cluster boxes whose pairwise gap is at most ``thresh_px`` and return
    one enclosing ``(x, y, w, h)`` box per cluster.

    Clustering is transitive: boxes are linked through a small union-find
    whenever ``_min_gap`` of the two original boxes is within the threshold.
    Clusters are returned in order of first appearance of their root.
    """
    if not boxes:
        return []

    count = len(boxes)
    root_of = list(range(count))

    def find(idx):
        # Walk to the root, halving the path on the way up.
        while root_of[idx] != idx:
            root_of[idx] = root_of[root_of[idx]]
            idx = root_of[idx]
        return idx

    # Link every pair of boxes that lie close enough to each other.
    for first in range(count):
        for second in range(first + 1, count):
            if _min_gap(boxes[first], boxes[second]) > thresh_px:
                continue
            root_first, root_second = find(first), find(second)
            if root_first != root_second:
                root_of[root_second] = root_first

    # Collect the members of each cluster, keyed by its root index.
    clusters = {}
    for idx, box in enumerate(boxes):
        clusters.setdefault(find(idx), []).append(box)

    # Replace each cluster by the bounding box of its members.
    result = []
    for members in clusters.values():
        left = min(x for x, _, _, _ in members)
        top = min(y for _, y, _, _ in members)
        right = max(x + w for x, _, w, _ in members)
        bottom = max(y + h for _, y, _, h in members)
        result.append((left, top, right - left, bottom - top))
    return result
|
|
|
|
class Inpainter:
    """Erase polygon-marked regions from BGR images using OpenCV or LaMa.

    The LaMa model is created lazily on first use and cached on the instance.
    """

    def __init__(self, logger=None,
                 default_backend: str = InpaintBackends.LAMA,
                 lama_device: str = "cuda"):
        """
        Args:
            logger: optional object with a ``log(msg)`` method; when absent
                (or lacking ``log``) messages are printed instead.
            default_backend: backend name used when ``inpaint`` is called
                without an explicit ``backend``.
            lama_device: device string passed to ``SimpleLama(device=...)``.
        """
        self.logger = logger
        self.default_backend = default_backend
        self.lama_device = lama_device
        self._lama: Optional[SimpleLama] = None

    def _log(self, msg):
        """Send ``msg`` to ``self.logger.log`` if available, else ``print``."""
        if self.logger and hasattr(self.logger, "log"):
            self.logger.log(msg)
        else:
            print(msg)

    def _get_lama(self):
        """Return the cached SimpleLama instance, creating it on first call.

        Raises:
            RuntimeError: if the SimpleLama import failed at module load.
        """
        if not _HAVE_LAMA:
            raise RuntimeError("SimpleLama not installed")
        if self._lama is None:
            self._log("Init SimpleLama...")
            self._lama = SimpleLama(device=self.lama_device)
        return self._lama

    # OpenCV path; works well on flat backgrounds (use when needed).
    def _opencv_text_inpaint(self, img_bgr: np.ndarray, hard_mask: np.ndarray,
                             r1: int = 3, r2: int = 7) -> np.ndarray:
        """Two-pass Telea inpainting.

        The first pass inpaints ``hard_mask`` with radius ``r1``. Masked
        pixels whose value changed by more than 3 in any channel are then
        inpainted a second time with radius ``r2``; if there are none, the
        first-pass result is returned as is.
        """
        out1 = cv2.inpaint(img_bgr, hard_mask, r1, cv2.INPAINT_TELEA)
        # int16 avoids uint8 wrap-around when taking the difference.
        changed = np.abs(out1.astype(np.int16) - img_bgr.astype(np.int16)).max(axis=2) > 3
        remain = (hard_mask > 0) & changed
        if not remain.any():
            return out1
        return cv2.inpaint(out1, remain.astype(np.uint8) * 255, r2, cv2.INPAINT_TELEA)

    def _run_backend(self, roi_img: np.ndarray, roi_mask: np.ndarray, backend: str) -> np.ndarray:
        """Inpaint ``roi_img`` (BGR) under ``roi_mask`` with the named backend.

        Raises:
            NotImplementedError: for a backend name that is not handled here.
            RuntimeError: if a LaMa backend is requested but not installed.
        """
        if backend == InpaintBackends.OPENCV:
            return self._opencv_text_inpaint(roi_img, roi_mask)
        elif backend in (InpaintBackends.LAMA, InpaintBackends.LAMA_TORCH_AMP):
            lama = self._get_lama()
            # SimpleLama is called with PIL images: RGB picture + "L" mask.
            dst_pil = lama(Image.fromarray(cv2.cvtColor(roi_img, cv2.COLOR_BGR2RGB)),
                           Image.fromarray(roi_mask, "L"))
            return cv2.cvtColor(np.array(dst_pil), cv2.COLOR_RGB2BGR)
        else:
            raise NotImplementedError(f"Backend {backend} not wired.")

    def _inpaint_and_blend(self, img: np.ndarray, hard_mask: np.ndarray,
                           soft_mask: np.ndarray, backend: str) -> np.ndarray:
        """Run ``backend`` on ``img`` and alpha-blend the result over ``img``.

        ``soft_mask`` (0..255) is the per-pixel weight of the inpainted
        result. All three inputs must share the same height and width.
        """
        dst = self._run_backend(img, hard_mask, backend)
        # The backend may return a different size than it was given
        # (presumably model-side padding -- confirm against the backend);
        # bring it back to the input size before blending.
        if dst.shape[:2] != img.shape[:2]:
            dst = cv2.resize(dst, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
        alpha = (soft_mask.astype(np.float32) / 255.0)[..., None]
        blended = alpha * dst.astype(np.float32) + (1 - alpha) * img.astype(np.float32)
        return blended.astype(np.uint8)

    def inpaint(self, img_bgr: np.ndarray, polygons: List[List[List[int]]],
                *,
                backend: Optional[str] = None,
                # common
                roi_strategy: str = "components",  # "components" | "full"
                max_side: int = 1600,
                auto_opencv_if_few: bool = False,
                few_threshold: int = 0,
                # components mode only
                comp_min_area: int = 30,
                pad_ratio: float = 0.12,
                merge_thresh_factor: float = 0.7,
                merge_abs_min_px: int = 8,
                soft_dilate_px: int = 10,
                soft_blur_px: int = 17,
                # debug output
                debug_save_rois: bool = False,
                debug_dir: Optional[str] = None,
                request_id: Optional[str] = None) -> np.ndarray:
        """Inpaint the regions of ``img_bgr`` covered by ``polygons``.

        Args:
            img_bgr: 3-channel BGR image; it is not modified.
            polygons: polygons to erase, as accepted by ``polygons_to_mask``.
            backend: backend name; falls back to ``self.default_backend``.
                ``"opencv"`` selects the OpenCV path, any other value LaMa.
            roi_strategy: ``"full"`` processes the whole frame at once; any
                other value processes per-component ROIs.
            max_side: images/ROIs whose longer side exceeds this are
                downscaled before inpainting and upscaled afterwards.
            auto_opencv_if_few: in components mode, force the OpenCV backend
                when the number of merged boxes is <= ``few_threshold``.
            few_threshold: see ``auto_opencv_if_few``.
            comp_min_area: minimum mask component area to be processed.
            pad_ratio: ROI padding as a fraction of the box's longer side.
            merge_thresh_factor: boxes closer than this factor times the
                median box height are merged into one ROI.
            merge_abs_min_px: lower bound for the merge distance.
            soft_dilate_px: dilation radius of the blending mask.
            soft_blur_px: Gaussian kernel size of the blending mask.
            debug_save_rois: write per-ROI PNGs when ``debug_dir`` is set.
            debug_dir: directory for debug PNGs (created if missing).
            request_id: filename prefix for debug PNGs (``"req"`` if unset).

        Returns:
            A new BGR image of the same size as ``img_bgr``.

        Raises:
            RuntimeError: if LaMa is selected but SimpleLama is unavailable.
        """
        backend = (backend or self.default_backend).lower()
        # Only two execution paths exist: OpenCV when explicitly requested,
        # LaMa for everything else.
        if backend == InpaintBackends.OPENCV:
            chosen_backend = InpaintBackends.OPENCV
        else:
            chosen_backend = InpaintBackends.LAMA

        H, W = img_bgr.shape[:2]
        base_mask = polygons_to_mask((H, W), polygons)

        # -- full-frame mode
        if roi_strategy == "full":
            if not base_mask.any():
                # Nothing to erase: skip the backend entirely.
                return img_bgr.copy()

            img_small, s = resize_long_side(img_bgr, max_side)
            if s != 1.0:
                mask_small = cv2.resize(base_mask, (img_small.shape[1], img_small.shape[0]),
                                        interpolation=cv2.INTER_NEAREST)
            else:
                mask_small = base_mask

            # Soft blending keeps the seam around the mask thin.
            soft_small = _soften_mask(mask_small, dilate_px=soft_dilate_px, blur_px=soft_blur_px)
            blended_small = self._inpaint_and_blend(img_small, mask_small, soft_small, chosen_backend)

            if s != 1.0:
                return cv2.resize(blended_small, (W, H), interpolation=cv2.INTER_CUBIC)
            return blended_small

        # -- component-based ROI mode
        boxes = _connected_components(base_mask, min_area=comp_min_area)
        if not boxes:
            return img_bgr.copy()

        heights = [h for _, _, _, h in boxes]
        med_h = float(np.median(heights)) if heights else 0.0
        merge_px = max(merge_abs_min_px, int(round(med_h * merge_thresh_factor)))
        merged = _merge_close_boxes(boxes, thresh_px=merge_px)
        rois = [_expand_box(b, pad_ratio, W, H) for b in merged]
        # Process in reading order: 32-px row bands, then left to right.
        rois.sort(key=lambda r: (r[1] // 32, r[0]))

        # The backend choice does not depend on the individual ROI.
        if auto_opencv_if_few and len(merged) <= few_threshold:
            use_backend = InpaintBackends.OPENCV
        else:
            use_backend = chosen_backend

        # prepare debug output
        save_idx = 0
        save_debug = bool(debug_save_rois and debug_dir)
        if save_debug:
            os.makedirs(debug_dir, exist_ok=True)

        out = img_bgr.copy()

        for (x, y, w, h) in rois:
            if w <= 1 or h <= 1:
                continue

            # Copy instead of a view: ``out`` is overwritten below, and the
            # debug dump must still show the ROI as it was before inpainting.
            roi_img = out[y:y+h, x:x+w].copy()
            roi_mask = base_mask[y:y+h, x:x+w]
            roi_soft = _soften_mask(roi_mask, dilate_px=soft_dilate_px, blur_px=soft_blur_px)

            roi_img_small, s = resize_long_side(roi_img, max_side)
            if s != 1.0:
                small_size = (roi_img_small.shape[1], roi_img_small.shape[0])
                roi_mask_small = cv2.resize(roi_mask, small_size, interpolation=cv2.INTER_NEAREST)
                roi_soft_small = cv2.resize(roi_soft, small_size, interpolation=cv2.INTER_LINEAR)
            else:
                roi_mask_small = roi_mask
                roi_soft_small = roi_soft

            # inpaint + soft blending
            blended_small = self._inpaint_and_blend(roi_img_small, roi_mask_small,
                                                    roi_soft_small, use_backend)

            # restore the original ROI size
            if s != 1.0:
                dst_roi = cv2.resize(blended_small, (w, h), interpolation=cv2.INTER_CUBIC)
            else:
                dst_roi = blended_small
            out[y:y+h, x:x+w] = dst_roi

            # -- intermediate dump (original / mask / soft mask / result)
            if save_debug:
                base = f"{request_id or 'req'}_roi{save_idx:02d}"
                cv2.imwrite(os.path.join(debug_dir, base + "_img.png"), roi_img)
                cv2.imwrite(os.path.join(debug_dir, base + "_mask.png"), roi_mask)
                cv2.imwrite(os.path.join(debug_dir, base + "_soft.png"), roi_soft)
                cv2.imwrite(os.path.join(debug_dir, base + "_dst.png"), dst_roi)
                save_idx += 1

        return out
|