inpaintServer/app/models/simple_lama.py

381 lines
17 KiB
Python

"""
Simple LAMA 인페인팅 모델 구현
"""
import torch
import numpy as np
import cv2
from PIL import Image
import logging
from typing import Union, Tuple, List
import asyncio
from concurrent.futures import ThreadPoolExecutor
from simple_lama_inpainting import SimpleLama
# 사용하지 않는 import 정리
# from ..utils.image_utils import (
# decode_base64_to_image,
# encode_image_to_base64,
# get_image_size,
# resize_image_if_needed,
# )
logger = logging.getLogger(__name__)
class SimpleLamaInpainter:
def __init__(self, model_path: str, device: str = "cpu", fp16: bool = False):
self.model_path = model_path
self._device = torch.device(device)
# LAMA 경로에서는 FP16을 사용하지 않습니다.
self._fp16 = False
self._model = None
self.loaded = False
# 동적 리사이즈 파라미터: 긴 변 상한 및 네트워크 호환 다중수
self._max_long_side = 1024
self._size_multiple = 8
# 자동 전체 인페인팅 전환 조건 (환경변수로 설정 가능)
import os
self._mask_area_ratio_threshold = float(os.getenv('LAMA_MASK_AREA_RATIO', '0.5')) # 마스크 면적이 전체의 50% 이상
self._roi_area_ratio_threshold = float(os.getenv('LAMA_ROI_AREA_RATIO', '0.7')) # ROI가 전체 이미지의 70% 이상
self._min_mask_components = int(os.getenv('LAMA_MIN_COMPONENTS', '5')) # 마스크 컴포넌트가 5개 이상 (분산도)
self._roi_margin = int(os.getenv('LAMA_ROI_MARGIN', '32')) # ROI 마진 (기본 32px)
async def load_model(self):
"""모델을 비동기적으로 로드합니다."""
if self.loaded:
return
try:
logger.info("Loading Simple LAMA model...")
# 실제 simple-lama-inpainting 라이브러리 사용
try:
self._model = SimpleLama(device=self._device)
logger.info("실제 SimpleLama 모델 로딩 완료")
except ImportError as e:
logger.warning(f"SimpleLama 라이브러리 import 실패: {e}")
logger.info("fallback 모드로 전환합니다...")
# fallback으로 시뮬레이션 모드 사용
self._model = {"type": "simple_lama_fallback", "device": self._device, "fp16": self._fp16}
except Exception as e:
logger.error(f"SimpleLama 모델 초기화 실패: {e}")
logger.info("fallback 모드로 전환합니다...")
self._model = {"type": "simple_lama_fallback", "device": self._device, "fp16": self._fp16}
self.loaded = True
logger.info("Simple LAMA model loaded successfully")
except Exception as e:
logger.error(f"Failed to load Simple LAMA model: {e}")
raise
def _get_mask_bbox(self, mask: Image.Image) -> Union[Tuple[int, int, int, int], None]:
"""마스크의 유효 영역 바운딩 박스를 반환합니다. 없으면 None 반환."""
m = mask.convert("L")
m_bin = m.point(lambda p: 255 if p >= 128 else 0)
return m_bin.getbbox()
def _expand_bbox(self, bbox: Tuple[int, int, int, int], image_size: Tuple[int, int], margin: int = 16) -> Tuple[int, int, int, int]:
"""문맥을 확보하기 위해 bbox를 margin만큼 확장합니다."""
left, top, right, bottom = bbox
width, height = image_size
left = max(0, left - margin)
top = max(0, top - margin)
right = min(width, right + margin)
bottom = min(height, bottom + margin)
return left, top, right, bottom
def _pad_to_multiple(self, img_np: np.ndarray, mask_np: np.ndarray, multiple: int = 8) -> Tuple[np.ndarray, np.ndarray, Tuple[int, int]]:
"""하단/우측 0 패딩으로 (H,W)를 multiple 배수로 맞춥니다. pad_h, pad_w 반환."""
h, w = img_np.shape[:2]
pad_h = (multiple - (h % multiple)) % multiple
pad_w = (multiple - (w % multiple)) % multiple
if pad_h == 0 and pad_w == 0:
return img_np, mask_np, (0, 0)
img_padded = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0)
mask_padded = np.pad(mask_np, ((0, pad_h), (0, pad_w)), mode='constant', constant_values=0)
return img_padded, mask_padded, (pad_h, pad_w)
def _analyze_mask(self, mask: Image.Image, image_size: Tuple[int, int]) -> dict:
"""마스크를 분석하여 면적 비율, 컴포넌트 수 등을 반환합니다."""
mask_np = np.array(mask.convert("L"))
total_pixels = image_size[0] * image_size[1]
mask_pixels = np.sum(mask_np > 127)
mask_ratio = mask_pixels / total_pixels if total_pixels > 0 else 0
# 연결 컴포넌트 수 계산 (간단한 분산도 측정)
binary_mask = (mask_np > 127).astype(np.uint8)
num_labels, _ = cv2.connectedComponents(binary_mask)
return {
"mask_ratio": mask_ratio,
"num_components": num_labels - 1 if num_labels > 0 else 0 # 배경 제외
}
def _should_use_full_image(self, mask: Image.Image, bbox: Tuple[int, int, int, int], image_size: Tuple[int, int]) -> bool:
"""전체 이미지 인페인팅을 사용할지 결정합니다."""
if bbox is None:
return True # 마스크가 없으면 전체
# 마스크 분석
analysis = self._analyze_mask(mask, image_size)
# 조건 1: 마스크 면적이 전체의 50% 이상
if analysis["mask_ratio"] >= self._mask_area_ratio_threshold:
return True
# 조건 2: ROI 영역이 전체 이미지의 70% 이상
left, top, right, bottom = bbox
roi_area = (right - left) * (bottom - top)
total_area = image_size[0] * image_size[1]
roi_ratio = roi_area / total_area if total_area > 0 else 0
if roi_ratio >= self._roi_area_ratio_threshold:
return True
# 조건 3: 마스크 컴포넌트가 5개 이상 (분산된 작은 영역들)
if analysis["num_components"] >= self._min_mask_components:
return True
return False
def _compute_target_size(self, width: int, height: int) -> Tuple[int, int]:
"""원본 비율을 유지하면서 긴 변을 self._max_long_side로 제한하고
모델 호환을 위해 각 변을 self._size_multiple의 배수로 맞춥니다.
업스케일은 하지 않습니다.
"""
max_long_side = self._max_long_side
multiple = self._size_multiple
long_side = max(width, height)
scale = 1.0 if long_side <= max_long_side else (max_long_side / float(long_side))
target_w = int(round(width * scale))
target_h = int(round(height * scale))
# 다중수에 맞춤 (0 방지)
target_w = max(multiple, (target_w // multiple) * multiple)
target_h = max(multiple, (target_h // multiple) * multiple)
return target_w, target_h
def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
"""이미지를 전처리합니다."""
if isinstance(image, Image.Image):
image = np.array(image)
# RGB로 변환
if image.shape[2] == 4: # RGBA
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
elif image.shape[2] == 3 and image.dtype == np.uint8:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 정규화 (0-1)
image = image.astype(np.float32) / 255.0
# 텐서로 변환 (B, C, H, W)
tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
if self._fp16:
tensor = tensor.half()
return tensor.to(self._device)
def preprocess_mask(self, mask: Union[Image.Image, np.ndarray]) -> torch.Tensor:
"""마스크를 전처리합니다."""
if isinstance(mask, Image.Image):
mask = np.array(mask)
# 그레이스케일로 변환
if len(mask.shape) == 3:
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
# 이진화 (0 또는 1)
mask = (mask > 127).astype(np.float32)
# 텐서로 변환 (B, 1, H, W)
tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)
if self._fp16:
tensor = tensor.half()
return tensor.to(self._device)
def postprocess_result(self, tensor: torch.Tensor) -> np.ndarray:
"""결과를 후처리합니다."""
# CPU로 이동하고 numpy로 변환
if tensor.is_cuda:
tensor = tensor.cpu()
if tensor.dtype == torch.float16:
tensor = tensor.float()
result = tensor.squeeze(0).permute(1, 2, 0).numpy()
# 0-255 범위로 변환
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
return result
async def inpaint(
self,
images: List[np.ndarray],
masks: List[np.ndarray],
**kwargs,
) -> List[np.ndarray]:
if not self.loaded:
await self.load_model()
if not self.is_ready:
raise RuntimeError("SimpleLama model is not loaded yet.")
# 모델이 GPU에 있는지 확인
if self._device.type != 'cpu':
torch.cuda.empty_cache()
# 전처리
pil_images = [Image.fromarray(img) for img in images]
pil_masks = [Image.fromarray(mask) for mask in masks]
# 전처리 (리사이즈 없이 마스크 ROI만 크롭 + 패딩)
preprocessed_items = []
for img, mask in zip(pil_images, pil_masks):
img_tensor, mask_tensor, meta = self._preprocess(img, mask)
preprocessed_items.append((img_tensor, mask_tensor, meta))
# 원본 이미지/사이즈 보관
original_images_and_sizes = list(zip(pil_images, [img.size for img in pil_images], pil_masks))
logger.info(f"실제 SimpleLama 모델로 {len(images)}개 이미지 인페인팅 수행")
# 성능 최적화: cuDNN benchmark
torch.backends.cudnn.benchmark = True
result_images = []
with torch.no_grad():
for i, (img_tensor, mask_tensor, meta) in enumerate(preprocessed_items):
# 배치 차원 추가
image_batch = img_tensor.unsqueeze(0)
mask_batch = mask_tensor.unsqueeze(0)
if self._device.type == 'cuda':
image_batch = image_batch.pin_memory().to(self._device, non_blocking=True)
mask_batch = mask_batch.pin_memory().to(self._device, non_blocking=True)
else:
image_batch = image_batch.to(self._device)
mask_batch = mask_batch.to(self._device)
# 모델 호출 (출력: [1, C, H, W])
inpainted = self._model.model(image_batch, mask_batch)
inpainted_tensor = inpainted[0] if isinstance(inpainted, torch.Tensor) else inpainted[0]
original_image, _, original_mask = original_images_and_sizes[i]
use_full_image = meta.get("use_full_image", False)
if use_full_image:
# 전체 이미지 처리: 패딩 제거 후 바로 최종 결과
roi_h, roi_w = meta["roi_size"]
pad_h, pad_w = meta["pad_hw"]
if pad_h or pad_w:
inpainted_tensor = inpainted_tensor[:, :roi_h, :roi_w]
# 텐서를 최종 이미지로 변환
final_np = inpainted_tensor.permute(1, 2, 0).detach().float().cpu().numpy()
final_np = np.nan_to_num(final_np, nan=0.0, posinf=1.0, neginf=0.0)
final_np = (np.clip(final_np, 0.0, 1.0) * 255.0).astype(np.uint8)
result_images.append(final_np)
else:
# ROI 처리: 기존 합성 로직
# 패딩 제거하여 원래 ROI 크기로 복원
roi_h, roi_w = meta["roi_size"]
pad_h, pad_w = meta["pad_hw"]
if pad_h or pad_w:
inpainted_tensor = inpainted_tensor[:, :roi_h, :roi_w]
# 텐서를 PIL ROI 이미지로 변환
roi_np = inpainted_tensor.permute(1, 2, 0).detach().float().cpu().numpy()
roi_np = np.nan_to_num(roi_np, nan=0.0, posinf=1.0, neginf=0.0)
roi_np = (np.clip(roi_np, 0.0, 1.0) * 255.0).astype(np.uint8)
roi_inpainted = Image.fromarray(roi_np)
# 원본 이미지에 ROI 합성
left, top, right, bottom = meta["bbox"]
original_roi = original_image.crop((left, top, right, bottom))
mask_bin = original_mask.convert("L").point(lambda p: 255 if p >= 128 else 0)
mask_roi = mask_bin.crop((left, top, right, bottom))
composited_roi = Image.composite(roi_inpainted, original_roi, mask_roi)
final_img = original_image.copy()
final_img.paste(composited_roi, (left, top))
result_images.append(np.array(final_img))
return result_images
def _preprocess(self, image: Image.Image, mask: Image.Image):
"""마스크 분석 후 ROI 크롭 또는 전체 이미지 처리로 자동 결정합니다."""
image = image.convert("RGB")
mask = mask.convert("L")
image_size = (image.width, image.height)
bbox = self._get_mask_bbox(mask)
use_full_image = self._should_use_full_image(mask, bbox, image_size)
if use_full_image:
# 전체 이미지 처리
left, top, right, bottom = 0, 0, image.width, image.height
# 8의 배수로 패딩
img_np = np.array(image, dtype=np.uint8)
mask_np = np.array(mask, dtype=np.uint8)
img_np_padded, mask_np_padded, pad_hw = self._pad_to_multiple(img_np, mask_np, multiple=8)
roi_h, roi_w = img_np.shape[0], img_np.shape[1]
else:
# ROI 크롭 + 마진 + 패딩
left, top, right, bottom = self._expand_bbox(bbox, image_size, margin=self._roi_margin)
# ROI 크롭
image_crop = image.crop((left, top, right, bottom))
mask_crop = mask.crop((left, top, right, bottom))
# numpy 변환
img_np = np.array(image_crop, dtype=np.uint8)
mask_np = np.array(mask_crop, dtype=np.uint8)
roi_h, roi_w = img_np.shape[0], img_np.shape[1]
# 8의 배수 패딩
img_np_padded, mask_np_padded, pad_hw = self._pad_to_multiple(img_np, mask_np, multiple=8)
# 정규화 및 텐서 변환 (마스크는 0..1 float32 유지)
image_tensor = torch.from_numpy(img_np_padded.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).squeeze(0)
mask_tensor = torch.from_numpy((mask_np_padded.astype(np.float32) / 255.0)).unsqueeze(0).unsqueeze(0).squeeze(0)
meta = {
"bbox": (left, top, right, bottom),
"pad_hw": pad_hw, # (pad_h, pad_w)
"roi_size": (roi_h, roi_w),
"use_full_image": use_full_image,
}
return image_tensor, mask_tensor, meta
def _postprocess(self, tensor: torch.Tensor, original_size: Tuple[int, int], original_image: Image.Image, original_mask: Image.Image) -> Image.Image:
"""모델 출력 텐서를 PIL 이미지로 후처리하고 원본에 합성합니다."""
# 텐서를 PIL 이미지로 변환
result_np = tensor.permute(1, 2, 0).detach().float().cpu().numpy()
# NaN/Inf 안전 처리 후 범위 클램프
result_np = np.nan_to_num(result_np, nan=0.0, posinf=1.0, neginf=0.0)
result_np = (np.clip(result_np, 0.0, 1.0) * 255.0).astype(np.uint8)
inpainted_image_512 = Image.fromarray(result_np)
# 원본 크기로 리사이즈
resized_inpainted_image = inpainted_image_512.resize(original_size, Image.Resampling.LANCZOS)
# 원본 마스크를 사용하여 원본 이미지와 합성
original_mask = original_mask.convert("L")
# 경계부 헤일로 방지를 위한 이진화
binary_mask = original_mask.point(lambda p: 255 if p >= 128 else 0)
final_image = Image.composite(resized_inpainted_image, original_image, binary_mask)
return final_image
@property
def is_ready(self) -> bool:
return self._model is not None
def get_model_info(self) -> dict:
"""모델 정보를 반환합니다."""
return {
"model_type": "simple_lama",
"device": self._device,
"fp16": self._fp16,
"loaded": self.loaded,
"model_path": self.model_path
}