381 lines
17 KiB
Python
381 lines
17 KiB
Python
"""
|
|
Simple LAMA 인페인팅 모델 구현
|
|
"""
|
|
import torch
|
|
import numpy as np
|
|
import cv2
|
|
from PIL import Image
|
|
import logging
|
|
from typing import Union, Tuple, List
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from simple_lama_inpainting import SimpleLama
|
|
# 사용하지 않는 import 정리
|
|
# from ..utils.image_utils import (
|
|
# decode_base64_to_image,
|
|
# encode_image_to_base64,
|
|
# get_image_size,
|
|
# resize_image_if_needed,
|
|
# )
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SimpleLamaInpainter:
|
|
def __init__(self, model_path: str, device: str = "cpu", fp16: bool = False):
|
|
self.model_path = model_path
|
|
self._device = torch.device(device)
|
|
# LAMA 경로에서는 FP16을 사용하지 않습니다.
|
|
self._fp16 = False
|
|
self._model = None
|
|
self.loaded = False
|
|
# 동적 리사이즈 파라미터: 긴 변 상한 및 네트워크 호환 다중수
|
|
self._max_long_side = 1024
|
|
self._size_multiple = 8
|
|
# 자동 전체 인페인팅 전환 조건 (환경변수로 설정 가능)
|
|
import os
|
|
self._mask_area_ratio_threshold = float(os.getenv('LAMA_MASK_AREA_RATIO', '0.5')) # 마스크 면적이 전체의 50% 이상
|
|
self._roi_area_ratio_threshold = float(os.getenv('LAMA_ROI_AREA_RATIO', '0.7')) # ROI가 전체 이미지의 70% 이상
|
|
self._min_mask_components = int(os.getenv('LAMA_MIN_COMPONENTS', '5')) # 마스크 컴포넌트가 5개 이상 (분산도)
|
|
self._roi_margin = int(os.getenv('LAMA_ROI_MARGIN', '32')) # ROI 마진 (기본 32px)
|
|
|
|
async def load_model(self):
|
|
"""모델을 비동기적으로 로드합니다."""
|
|
if self.loaded:
|
|
return
|
|
|
|
try:
|
|
logger.info("Loading Simple LAMA model...")
|
|
|
|
# 실제 simple-lama-inpainting 라이브러리 사용
|
|
try:
|
|
self._model = SimpleLama(device=self._device)
|
|
logger.info("실제 SimpleLama 모델 로딩 완료")
|
|
except ImportError as e:
|
|
logger.warning(f"SimpleLama 라이브러리 import 실패: {e}")
|
|
logger.info("fallback 모드로 전환합니다...")
|
|
# fallback으로 시뮬레이션 모드 사용
|
|
self._model = {"type": "simple_lama_fallback", "device": self._device, "fp16": self._fp16}
|
|
except Exception as e:
|
|
logger.error(f"SimpleLama 모델 초기화 실패: {e}")
|
|
logger.info("fallback 모드로 전환합니다...")
|
|
self._model = {"type": "simple_lama_fallback", "device": self._device, "fp16": self._fp16}
|
|
|
|
self.loaded = True
|
|
logger.info("Simple LAMA model loaded successfully")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load Simple LAMA model: {e}")
|
|
raise
|
|
|
|
def _get_mask_bbox(self, mask: Image.Image) -> Union[Tuple[int, int, int, int], None]:
|
|
"""마스크의 유효 영역 바운딩 박스를 반환합니다. 없으면 None 반환."""
|
|
m = mask.convert("L")
|
|
m_bin = m.point(lambda p: 255 if p >= 128 else 0)
|
|
return m_bin.getbbox()
|
|
|
|
def _expand_bbox(self, bbox: Tuple[int, int, int, int], image_size: Tuple[int, int], margin: int = 16) -> Tuple[int, int, int, int]:
|
|
"""문맥을 확보하기 위해 bbox를 margin만큼 확장합니다."""
|
|
left, top, right, bottom = bbox
|
|
width, height = image_size
|
|
left = max(0, left - margin)
|
|
top = max(0, top - margin)
|
|
right = min(width, right + margin)
|
|
bottom = min(height, bottom + margin)
|
|
return left, top, right, bottom
|
|
|
|
def _pad_to_multiple(self, img_np: np.ndarray, mask_np: np.ndarray, multiple: int = 8) -> Tuple[np.ndarray, np.ndarray, Tuple[int, int]]:
|
|
"""하단/우측 0 패딩으로 (H,W)를 multiple 배수로 맞춥니다. pad_h, pad_w 반환."""
|
|
h, w = img_np.shape[:2]
|
|
pad_h = (multiple - (h % multiple)) % multiple
|
|
pad_w = (multiple - (w % multiple)) % multiple
|
|
if pad_h == 0 and pad_w == 0:
|
|
return img_np, mask_np, (0, 0)
|
|
img_padded = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0)
|
|
mask_padded = np.pad(mask_np, ((0, pad_h), (0, pad_w)), mode='constant', constant_values=0)
|
|
return img_padded, mask_padded, (pad_h, pad_w)
|
|
|
|
def _analyze_mask(self, mask: Image.Image, image_size: Tuple[int, int]) -> dict:
|
|
"""마스크를 분석하여 면적 비율, 컴포넌트 수 등을 반환합니다."""
|
|
mask_np = np.array(mask.convert("L"))
|
|
total_pixels = image_size[0] * image_size[1]
|
|
mask_pixels = np.sum(mask_np > 127)
|
|
mask_ratio = mask_pixels / total_pixels if total_pixels > 0 else 0
|
|
|
|
# 연결 컴포넌트 수 계산 (간단한 분산도 측정)
|
|
binary_mask = (mask_np > 127).astype(np.uint8)
|
|
num_labels, _ = cv2.connectedComponents(binary_mask)
|
|
|
|
return {
|
|
"mask_ratio": mask_ratio,
|
|
"num_components": num_labels - 1 if num_labels > 0 else 0 # 배경 제외
|
|
}
|
|
|
|
def _should_use_full_image(self, mask: Image.Image, bbox: Tuple[int, int, int, int], image_size: Tuple[int, int]) -> bool:
|
|
"""전체 이미지 인페인팅을 사용할지 결정합니다."""
|
|
if bbox is None:
|
|
return True # 마스크가 없으면 전체
|
|
|
|
# 마스크 분석
|
|
analysis = self._analyze_mask(mask, image_size)
|
|
|
|
# 조건 1: 마스크 면적이 전체의 50% 이상
|
|
if analysis["mask_ratio"] >= self._mask_area_ratio_threshold:
|
|
return True
|
|
|
|
# 조건 2: ROI 영역이 전체 이미지의 70% 이상
|
|
left, top, right, bottom = bbox
|
|
roi_area = (right - left) * (bottom - top)
|
|
total_area = image_size[0] * image_size[1]
|
|
roi_ratio = roi_area / total_area if total_area > 0 else 0
|
|
if roi_ratio >= self._roi_area_ratio_threshold:
|
|
return True
|
|
|
|
# 조건 3: 마스크 컴포넌트가 5개 이상 (분산된 작은 영역들)
|
|
if analysis["num_components"] >= self._min_mask_components:
|
|
return True
|
|
|
|
return False
|
|
|
|
def _compute_target_size(self, width: int, height: int) -> Tuple[int, int]:
|
|
"""원본 비율을 유지하면서 긴 변을 self._max_long_side로 제한하고
|
|
모델 호환을 위해 각 변을 self._size_multiple의 배수로 맞춥니다.
|
|
업스케일은 하지 않습니다.
|
|
"""
|
|
max_long_side = self._max_long_side
|
|
multiple = self._size_multiple
|
|
long_side = max(width, height)
|
|
scale = 1.0 if long_side <= max_long_side else (max_long_side / float(long_side))
|
|
target_w = int(round(width * scale))
|
|
target_h = int(round(height * scale))
|
|
# 다중수에 맞춤 (0 방지)
|
|
target_w = max(multiple, (target_w // multiple) * multiple)
|
|
target_h = max(multiple, (target_h // multiple) * multiple)
|
|
return target_w, target_h
|
|
|
|
def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
|
|
"""이미지를 전처리합니다."""
|
|
if isinstance(image, Image.Image):
|
|
image = np.array(image)
|
|
|
|
# RGB로 변환
|
|
if image.shape[2] == 4: # RGBA
|
|
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
|
elif image.shape[2] == 3 and image.dtype == np.uint8:
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
|
# 정규화 (0-1)
|
|
image = image.astype(np.float32) / 255.0
|
|
|
|
# 텐서로 변환 (B, C, H, W)
|
|
tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
|
|
|
|
if self._fp16:
|
|
tensor = tensor.half()
|
|
|
|
return tensor.to(self._device)
|
|
|
|
def preprocess_mask(self, mask: Union[Image.Image, np.ndarray]) -> torch.Tensor:
|
|
"""마스크를 전처리합니다."""
|
|
if isinstance(mask, Image.Image):
|
|
mask = np.array(mask)
|
|
|
|
# 그레이스케일로 변환
|
|
if len(mask.shape) == 3:
|
|
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
|
|
|
|
# 이진화 (0 또는 1)
|
|
mask = (mask > 127).astype(np.float32)
|
|
|
|
# 텐서로 변환 (B, 1, H, W)
|
|
tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)
|
|
|
|
if self._fp16:
|
|
tensor = tensor.half()
|
|
|
|
return tensor.to(self._device)
|
|
|
|
def postprocess_result(self, tensor: torch.Tensor) -> np.ndarray:
|
|
"""결과를 후처리합니다."""
|
|
# CPU로 이동하고 numpy로 변환
|
|
if tensor.is_cuda:
|
|
tensor = tensor.cpu()
|
|
if tensor.dtype == torch.float16:
|
|
tensor = tensor.float()
|
|
|
|
result = tensor.squeeze(0).permute(1, 2, 0).numpy()
|
|
|
|
# 0-255 범위로 변환
|
|
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
|
|
|
|
return result
|
|
|
|
async def inpaint(
|
|
self,
|
|
images: List[np.ndarray],
|
|
masks: List[np.ndarray],
|
|
**kwargs,
|
|
) -> List[np.ndarray]:
|
|
if not self.loaded:
|
|
await self.load_model()
|
|
|
|
if not self.is_ready:
|
|
raise RuntimeError("SimpleLama model is not loaded yet.")
|
|
|
|
# 모델이 GPU에 있는지 확인
|
|
if self._device.type != 'cpu':
|
|
torch.cuda.empty_cache()
|
|
|
|
# 전처리
|
|
pil_images = [Image.fromarray(img) for img in images]
|
|
pil_masks = [Image.fromarray(mask) for mask in masks]
|
|
|
|
# 전처리 (리사이즈 없이 마스크 ROI만 크롭 + 패딩)
|
|
preprocessed_items = []
|
|
for img, mask in zip(pil_images, pil_masks):
|
|
img_tensor, mask_tensor, meta = self._preprocess(img, mask)
|
|
preprocessed_items.append((img_tensor, mask_tensor, meta))
|
|
|
|
# 원본 이미지/사이즈 보관
|
|
original_images_and_sizes = list(zip(pil_images, [img.size for img in pil_images], pil_masks))
|
|
|
|
logger.info(f"실제 SimpleLama 모델로 {len(images)}개 이미지 인페인팅 수행")
|
|
# 성능 최적화: cuDNN benchmark
|
|
torch.backends.cudnn.benchmark = True
|
|
|
|
result_images = []
|
|
with torch.no_grad():
|
|
for i, (img_tensor, mask_tensor, meta) in enumerate(preprocessed_items):
|
|
# 배치 차원 추가
|
|
image_batch = img_tensor.unsqueeze(0)
|
|
mask_batch = mask_tensor.unsqueeze(0)
|
|
|
|
if self._device.type == 'cuda':
|
|
image_batch = image_batch.pin_memory().to(self._device, non_blocking=True)
|
|
mask_batch = mask_batch.pin_memory().to(self._device, non_blocking=True)
|
|
else:
|
|
image_batch = image_batch.to(self._device)
|
|
mask_batch = mask_batch.to(self._device)
|
|
|
|
# 모델 호출 (출력: [1, C, H, W])
|
|
inpainted = self._model.model(image_batch, mask_batch)
|
|
inpainted_tensor = inpainted[0] if isinstance(inpainted, torch.Tensor) else inpainted[0]
|
|
|
|
original_image, _, original_mask = original_images_and_sizes[i]
|
|
use_full_image = meta.get("use_full_image", False)
|
|
|
|
if use_full_image:
|
|
# 전체 이미지 처리: 패딩 제거 후 바로 최종 결과
|
|
roi_h, roi_w = meta["roi_size"]
|
|
pad_h, pad_w = meta["pad_hw"]
|
|
if pad_h or pad_w:
|
|
inpainted_tensor = inpainted_tensor[:, :roi_h, :roi_w]
|
|
|
|
# 텐서를 최종 이미지로 변환
|
|
final_np = inpainted_tensor.permute(1, 2, 0).detach().float().cpu().numpy()
|
|
final_np = np.nan_to_num(final_np, nan=0.0, posinf=1.0, neginf=0.0)
|
|
final_np = (np.clip(final_np, 0.0, 1.0) * 255.0).astype(np.uint8)
|
|
result_images.append(final_np)
|
|
else:
|
|
# ROI 처리: 기존 합성 로직
|
|
# 패딩 제거하여 원래 ROI 크기로 복원
|
|
roi_h, roi_w = meta["roi_size"]
|
|
pad_h, pad_w = meta["pad_hw"]
|
|
if pad_h or pad_w:
|
|
inpainted_tensor = inpainted_tensor[:, :roi_h, :roi_w]
|
|
|
|
# 텐서를 PIL ROI 이미지로 변환
|
|
roi_np = inpainted_tensor.permute(1, 2, 0).detach().float().cpu().numpy()
|
|
roi_np = np.nan_to_num(roi_np, nan=0.0, posinf=1.0, neginf=0.0)
|
|
roi_np = (np.clip(roi_np, 0.0, 1.0) * 255.0).astype(np.uint8)
|
|
roi_inpainted = Image.fromarray(roi_np)
|
|
|
|
# 원본 이미지에 ROI 합성
|
|
left, top, right, bottom = meta["bbox"]
|
|
original_roi = original_image.crop((left, top, right, bottom))
|
|
mask_bin = original_mask.convert("L").point(lambda p: 255 if p >= 128 else 0)
|
|
mask_roi = mask_bin.crop((left, top, right, bottom))
|
|
composited_roi = Image.composite(roi_inpainted, original_roi, mask_roi)
|
|
final_img = original_image.copy()
|
|
final_img.paste(composited_roi, (left, top))
|
|
result_images.append(np.array(final_img))
|
|
|
|
return result_images
|
|
|
|
def _preprocess(self, image: Image.Image, mask: Image.Image):
|
|
"""마스크 분석 후 ROI 크롭 또는 전체 이미지 처리로 자동 결정합니다."""
|
|
image = image.convert("RGB")
|
|
mask = mask.convert("L")
|
|
image_size = (image.width, image.height)
|
|
|
|
bbox = self._get_mask_bbox(mask)
|
|
use_full_image = self._should_use_full_image(mask, bbox, image_size)
|
|
|
|
if use_full_image:
|
|
# 전체 이미지 처리
|
|
left, top, right, bottom = 0, 0, image.width, image.height
|
|
# 8의 배수로 패딩
|
|
img_np = np.array(image, dtype=np.uint8)
|
|
mask_np = np.array(mask, dtype=np.uint8)
|
|
img_np_padded, mask_np_padded, pad_hw = self._pad_to_multiple(img_np, mask_np, multiple=8)
|
|
roi_h, roi_w = img_np.shape[0], img_np.shape[1]
|
|
else:
|
|
# ROI 크롭 + 마진 + 패딩
|
|
left, top, right, bottom = self._expand_bbox(bbox, image_size, margin=self._roi_margin)
|
|
# ROI 크롭
|
|
image_crop = image.crop((left, top, right, bottom))
|
|
mask_crop = mask.crop((left, top, right, bottom))
|
|
# numpy 변환
|
|
img_np = np.array(image_crop, dtype=np.uint8)
|
|
mask_np = np.array(mask_crop, dtype=np.uint8)
|
|
roi_h, roi_w = img_np.shape[0], img_np.shape[1]
|
|
# 8의 배수 패딩
|
|
img_np_padded, mask_np_padded, pad_hw = self._pad_to_multiple(img_np, mask_np, multiple=8)
|
|
|
|
# 정규화 및 텐서 변환 (마스크는 0..1 float32 유지)
|
|
image_tensor = torch.from_numpy(img_np_padded.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).squeeze(0)
|
|
mask_tensor = torch.from_numpy((mask_np_padded.astype(np.float32) / 255.0)).unsqueeze(0).unsqueeze(0).squeeze(0)
|
|
|
|
meta = {
|
|
"bbox": (left, top, right, bottom),
|
|
"pad_hw": pad_hw, # (pad_h, pad_w)
|
|
"roi_size": (roi_h, roi_w),
|
|
"use_full_image": use_full_image,
|
|
}
|
|
return image_tensor, mask_tensor, meta
|
|
|
|
def _postprocess(self, tensor: torch.Tensor, original_size: Tuple[int, int], original_image: Image.Image, original_mask: Image.Image) -> Image.Image:
|
|
"""모델 출력 텐서를 PIL 이미지로 후처리하고 원본에 합성합니다."""
|
|
# 텐서를 PIL 이미지로 변환
|
|
result_np = tensor.permute(1, 2, 0).detach().float().cpu().numpy()
|
|
# NaN/Inf 안전 처리 후 범위 클램프
|
|
result_np = np.nan_to_num(result_np, nan=0.0, posinf=1.0, neginf=0.0)
|
|
result_np = (np.clip(result_np, 0.0, 1.0) * 255.0).astype(np.uint8)
|
|
inpainted_image_512 = Image.fromarray(result_np)
|
|
|
|
# 원본 크기로 리사이즈
|
|
resized_inpainted_image = inpainted_image_512.resize(original_size, Image.Resampling.LANCZOS)
|
|
|
|
# 원본 마스크를 사용하여 원본 이미지와 합성
|
|
original_mask = original_mask.convert("L")
|
|
# 경계부 헤일로 방지를 위한 이진화
|
|
binary_mask = original_mask.point(lambda p: 255 if p >= 128 else 0)
|
|
final_image = Image.composite(resized_inpainted_image, original_image, binary_mask)
|
|
|
|
return final_image
|
|
|
|
|
|
@property
|
|
def is_ready(self) -> bool:
|
|
return self._model is not None
|
|
|
|
def get_model_info(self) -> dict:
|
|
"""모델 정보를 반환합니다."""
|
|
return {
|
|
"model_type": "simple_lama",
|
|
"device": self._device,
|
|
"fp16": self._fp16,
|
|
"loaded": self.loaded,
|
|
"model_path": self.model_path
|
|
}
|