# inpaintServer/app/models/simple_lama.py
#
# 220 lines
# 8.5 KiB
# Python

"""
Simple LAMA inpainting model implementation.
"""
import torch
import numpy as np
import cv2
from PIL import Image
import logging
from typing import Union, Tuple, List
import asyncio
from concurrent.futures import ThreadPoolExecutor
from simple_lama_inpainting import SimpleLama
# Unused imports, disabled during cleanup
# from ..utils.image_utils import (
# decode_base64_to_image,
# encode_image_to_base64,
# get_image_size,
# resize_image_if_needed,
# )
# Module-level logger; records are emitted under this module's import name.
logger = logging.getLogger(__name__)
class SimpleLamaInpainter:
    """Batch inpainting wrapper around ``simple_lama_inpainting.SimpleLama``.

    The wrapper owns the model instance, converts NumPy images and masks into
    the fixed-size tensors fed to the network, and blends the network output
    back into each original image using the mask as the alpha channel.
    """

    # Value stored under "type" in the placeholder dict that replaces the
    # model when ``SimpleLama`` cannot be constructed.
    _FALLBACK_TYPE = "simple_lama_fallback"

    # (width, height) every image and mask is resized to before inference.
    _MODEL_INPUT_SIZE = (512, 512)

    def __init__(self, model_path: str, device: str = "cpu", fp16: bool = False):
        """Store the configuration; the model itself is created by ``load_model``.

        Args:
            model_path: Reported by ``get_model_info``. It is not passed to
                ``SimpleLama`` by this class.
            device: Device string understood by ``torch.device``.
            fp16: When true, ``preprocess_image`` / ``preprocess_mask`` return
                half-precision tensors.
        """
        self.model_path = model_path
        self._device = torch.device(device)
        self._fp16 = fp16
        self._model = None
        self.loaded = False

    def _fallback_model(self) -> dict:
        """Build the placeholder stored in ``_model`` when loading fails."""
        return {"type": self._FALLBACK_TYPE, "device": self._device, "fp16": self._fp16}

    def _in_fallback_mode(self) -> bool:
        """Return True when ``_model`` is the placeholder dict, not a real model."""
        return isinstance(self._model, dict)

    async def load_model(self):
        """Create the ``SimpleLama`` model once; later calls return immediately.

        Construction runs synchronously inside this coroutine. A construction
        failure is logged and replaced by a placeholder dict instead of being
        raised, so callers never see model-construction errors from here;
        ``inpaint`` rejects requests while the placeholder is installed.
        """
        if self.loaded:
            return

        logger.info("Loading Simple LAMA model...")
        try:
            self._model = SimpleLama(device=self._device)
        except ImportError as e:
            logger.warning(f"SimpleLama 라이브러리 import 실패: {e}")
            logger.info("fallback 모드로 전환합니다...")
            self._model = self._fallback_model()
        except Exception as e:
            logger.error(f"SimpleLama 모델 초기화 실패: {e}")
            logger.info("fallback 모드로 전환합니다...")
            self._model = self._fallback_model()
        else:
            # Success is only reported when the real model was built.
            logger.info("실제 SimpleLama 모델 로딩 완료")
            logger.info("Simple LAMA model loaded successfully")
        self.loaded = True

    def preprocess_image(self, image: "Union[Image.Image, np.ndarray]") -> torch.Tensor:
        """Convert an image into a normalized ``(1, 3, H, W)`` float tensor.

        Args:
            image: A PIL image (any mode; converted to RGB) or a NumPy array.
                2-D and ``(H, W, 1)`` arrays are treated as grayscale,
                4-channel arrays as RGBA, and 3-channel ``uint8`` arrays as
                BGR (OpenCV order) that must be swapped to RGB.

        Returns:
            The image scaled by 1/255, on the configured device, in half
            precision when ``fp16`` was requested.

        Raises:
            TypeError: If ``image`` is neither an ndarray nor a PIL image.
            ValueError: If an ndarray is not 2-D or 3-D.
        """
        if isinstance(image, np.ndarray):
            array = image
            if array.ndim not in (2, 3):
                raise ValueError(
                    f"Expected a 2-D or 3-D image array, got {array.ndim} dimensions."
                )
            if array.ndim == 2:
                # Grayscale: replicate the single plane into three channels.
                array = np.stack([array] * 3, axis=-1)
            elif array.shape[2] == 1:
                array = np.repeat(array, 3, axis=2)
            elif array.shape[2] == 4:
                array = cv2.cvtColor(array, cv2.COLOR_RGBA2RGB)
            elif array.shape[2] == 3 and array.dtype == np.uint8:
                array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
        elif isinstance(image, Image.Image):
            # PIL data is already RGB-ordered; swapping channels here would
            # hand the model a BGR image.
            array = np.array(image.convert("RGB"))
        else:
            raise TypeError(
                f"Unsupported image type: {type(image).__name__}; "
                "expected PIL.Image.Image or numpy.ndarray."
            )

        normalized = array.astype(np.float32) / 255.0
        tensor = torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0)
        if self._fp16:
            tensor = tensor.half()
        return tensor.to(self._device)

    def preprocess_mask(self, mask: "Union[Image.Image, np.ndarray]") -> torch.Tensor:
        """Convert a mask into a binary ``(1, 1, H, W)`` float tensor.

        Args:
            mask: A PIL image or NumPy array. Boolean masks map True to 1.
                Other masks are thresholded at 127 after multi-channel input
                has been reduced to a single channel.

        Returns:
            A tensor of zeros and ones on the configured device, in half
            precision when ``fp16`` was requested.

        Raises:
            TypeError: If ``mask`` is neither an ndarray nor a PIL image.
        """
        if isinstance(mask, np.ndarray):
            array = mask
        elif isinstance(mask, Image.Image):
            array = np.array(mask)
        else:
            raise TypeError(
                f"Unsupported mask type: {type(mask).__name__}; "
                "expected PIL.Image.Image or numpy.ndarray."
            )

        if array.dtype == np.bool_:
            # True compares as 1, which would never exceed the 127 threshold.
            array = array.astype(np.uint8) * 255
        if array.ndim == 3:
            if array.shape[2] == 1:
                array = array[:, :, 0]
            else:
                array = cv2.cvtColor(array, cv2.COLOR_RGB2GRAY)

        binary = (array > 127).astype(np.float32)
        tensor = torch.from_numpy(binary).unsqueeze(0).unsqueeze(0)
        if self._fp16:
            tensor = tensor.half()
        return tensor.to(self._device)

    def postprocess_result(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert a ``(1, C, H, W)`` tensor in [0, 1] into an ``(H, W, C)`` uint8 array.

        The tensor is detached, moved to the CPU and promoted to float32
        first, so tensors that track gradients, live on any device, or use a
        reduced-precision dtype are all accepted. Values are scaled by 255,
        clipped to [0, 255] and truncated to ``uint8``.
        """
        array = tensor.detach().cpu().float().squeeze(0).permute(1, 2, 0).numpy()
        return np.clip(array * 255.0, 0, 255).astype(np.uint8)

    async def inpaint(
        self,
        images: List[np.ndarray],
        masks: List[np.ndarray],
        **kwargs,
    ) -> List[np.ndarray]:
        """Inpaint a batch of images and return one result array per input.

        The model is loaded on first use. Every image/mask pair is resized to
        the fixed model resolution, run through the network as one batch, and
        the output is resized back and blended into the original image with
        the mask as alpha. Inference runs synchronously in this coroutine.

        Args:
            images: Image arrays accepted by ``PIL.Image.fromarray``.
            masks: One mask array per image, with the same width and height.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The blended results in input order; an empty list for empty input.

        Raises:
            RuntimeError: If no model is available, including the case where
                loading fell back to the placeholder.
            ValueError: If the number of masks differs from the number of
                images, or a mask's size differs from its image's size.
        """
        if not self.loaded:
            await self.load_model()
        if not self.is_ready:
            raise RuntimeError("SimpleLama model is not loaded yet.")
        if self._in_fallback_mode():
            # The placeholder dict has no ``model`` attribute to call.
            raise RuntimeError(
                "SimpleLama model failed to initialize (fallback mode); "
                "inpainting is unavailable."
            )
        if len(images) != len(masks):
            raise ValueError(
                f"Got {len(images)} images but {len(masks)} masks; counts must match."
            )
        if len(images) == 0:
            return []

        on_cuda = self._device.type == 'cuda'
        if on_cuda:
            torch.cuda.empty_cache()
            # cuDNN autotuning only applies to CUDA execution.
            torch.backends.cudnn.benchmark = True

        pil_images = [Image.fromarray(img) for img in images]
        pil_masks = [Image.fromarray(mask) for mask in masks]
        # The final blend needs matching sizes; fail before inference instead
        # of after it.
        for index, (img, msk) in enumerate(zip(pil_images, pil_masks)):
            if img.size != msk.size:
                raise ValueError(
                    f"Mask {index} has size {msk.size} but its image has size {img.size}."
                )

        prepared = [self._preprocess(img, msk) for img, msk in zip(pil_images, pil_masks)]
        image_batch = torch.stack([pair[0] for pair in prepared])
        mask_batch = torch.stack([pair[1] for pair in prepared])
        if on_cuda:
            # Pinned host memory lets the copies below run non-blocking.
            image_batch = image_batch.pin_memory()
            mask_batch = mask_batch.pin_memory()
        image_batch = image_batch.to(self._device, non_blocking=True)
        mask_batch = mask_batch.to(self._device, non_blocking=True)

        logger.info(f"실제 SimpleLama 모델로 {len(images)}개 이미지 인페인팅 수행")
        with torch.no_grad():
            if on_cuda:
                # NOTE(review): autocast is always enabled on CUDA, regardless
                # of the ``fp16`` constructor flag - confirm this is intended.
                with torch.cuda.amp.autocast(enabled=True):
                    inpainted_batch = self._model.model(image_batch, mask_batch)
            else:
                inpainted_batch = self._model.model(image_batch, mask_batch)

        return [
            np.array(self._postprocess(output, img.size, img, msk))
            for output, img, msk in zip(inpainted_batch, pil_images, pil_masks)
        ]

    def _preprocess(self, image: "Image.Image", mask: "Image.Image"):
        """Turn one image/mask pair into ``(3, H, W)`` and ``(1, H, W)`` float tensors.

        Both inputs are resized to the fixed model resolution (Lanczos for the
        image, nearest-neighbour for the mask) and scaled by 1/255.
        """
        resized_image = image.convert("RGB").resize(
            self._MODEL_INPUT_SIZE, Image.Resampling.LANCZOS
        )
        resized_mask = mask.convert("L").resize(
            self._MODEL_INPUT_SIZE, Image.Resampling.NEAREST
        )
        image_tensor = torch.from_numpy(
            np.array(resized_image, dtype=np.float32) / 255.0
        ).permute(2, 0, 1)
        mask_tensor = torch.from_numpy(
            np.array(resized_mask, dtype=np.float32) / 255.0
        ).unsqueeze(0)
        return image_tensor, mask_tensor

    def _postprocess(
        self,
        tensor: torch.Tensor,
        original_size: Tuple[int, int],
        original_image: "Image.Image",
        original_mask: "Image.Image",
    ) -> "Image.Image":
        """Resize one ``(C, H, W)`` model output and blend it into the original image.

        The output is converted to an 8-bit image, resized to
        ``original_size`` and composited over ``original_image`` using the
        grayscale mask as alpha.
        """
        # Promote to float32 so a reduced-precision output is not scaled and
        # truncated in half precision.
        array = tensor.detach().cpu().float().permute(1, 2, 0).numpy()
        array = np.clip(array * 255, 0, 255).astype(np.uint8)
        inpainted = Image.fromarray(array).resize(original_size, Image.Resampling.LANCZOS)
        alpha = original_mask.convert("L")
        return Image.composite(inpainted, original_image, alpha)

    @property
    def is_ready(self) -> bool:
        """True once ``_model`` holds anything, including the fallback placeholder."""
        return self._model is not None

    def get_model_info(self) -> dict:
        """Return the configuration and load state as a dict.

        NOTE(review): "device" is a ``torch.device`` object, not a string -
        confirm consumers can serialize it.
        """
        return {
            "model_type": "simple_lama",
            "device": self._device,
            "fp16": self._fp16,
            "loaded": self.loaded,
            "model_path": self.model_path,
        }