inpaintServer/app/models/simple_lama.py

181 lines
6.7 KiB
Python

"""
Simple LAMA 인페인팅 모델 구현
"""
import torch
import numpy as np
import cv2
from PIL import Image
import logging
from typing import Union, Tuple
import asyncio
from concurrent.futures import ThreadPoolExecutor
logger = logging.getLogger(__name__)
class SimpleLamaInpainter:
def __init__(self, model_path: str = None, device: str = "cuda", fp16: bool = True):
self.model_path = model_path
self.device = device
self.fp16 = fp16
self.model = None
self.loaded = False
async def load_model(self):
"""모델을 비동기적으로 로드합니다."""
if self.loaded:
return
try:
logger.info("Loading Simple LAMA model...")
# 실제 simple-lama-inpainting 라이브러리 사용
try:
from simple_lama_inpainting import SimpleLama
self.model = SimpleLama(device=self.device)
logger.info("실제 SimpleLama 모델 로딩 완료")
except ImportError as e:
logger.warning(f"SimpleLama 라이브러리 import 실패: {e}")
logger.info("fallback 모드로 전환합니다...")
# fallback으로 시뮬레이션 모드 사용
self.model = {"type": "simple_lama_fallback", "device": self.device, "fp16": self.fp16}
except Exception as e:
logger.error(f"SimpleLama 모델 초기화 실패: {e}")
logger.info("fallback 모드로 전환합니다...")
self.model = {"type": "simple_lama_fallback", "device": self.device, "fp16": self.fp16}
self.loaded = True
logger.info("Simple LAMA model loaded successfully")
except Exception as e:
logger.error(f"Failed to load Simple LAMA model: {e}")
raise
def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
"""이미지를 전처리합니다."""
if isinstance(image, Image.Image):
image = np.array(image)
# RGB로 변환
if image.shape[2] == 4: # RGBA
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
elif image.shape[2] == 3 and image.dtype == np.uint8:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 정규화 (0-1)
image = image.astype(np.float32) / 255.0
# 텐서로 변환 (B, C, H, W)
tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
if self.fp16:
tensor = tensor.half()
return tensor.to(self.device)
def preprocess_mask(self, mask: Union[Image.Image, np.ndarray]) -> torch.Tensor:
"""마스크를 전처리합니다."""
if isinstance(mask, Image.Image):
mask = np.array(mask)
# 그레이스케일로 변환
if len(mask.shape) == 3:
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
# 이진화 (0 또는 1)
mask = (mask > 127).astype(np.float32)
# 텐서로 변환 (B, 1, H, W)
tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)
if self.fp16:
tensor = tensor.half()
return tensor.to(self.device)
def postprocess_result(self, tensor: torch.Tensor) -> np.ndarray:
"""결과를 후처리합니다."""
# CPU로 이동하고 numpy로 변환
if tensor.is_cuda:
tensor = tensor.cpu()
if tensor.dtype == torch.float16:
tensor = tensor.float()
result = tensor.squeeze(0).permute(1, 2, 0).numpy()
# 0-255 범위로 변환
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
return result
async def inpaint(self, image: Union[Image.Image, np.ndarray],
mask: Union[Image.Image, np.ndarray]) -> np.ndarray:
"""인페인팅을 수행합니다."""
if not self.loaded:
await self.load_model()
try:
# 전처리
image_tensor = self.preprocess_image(image)
mask_tensor = self.preprocess_mask(mask)
# 실제 모델 추론
with torch.no_grad():
if hasattr(self.model, '__call__') and not isinstance(self.model, dict):
# 실제 SimpleLama 모델 사용
logger.info("실제 SimpleLama 모델로 인페인팅 수행")
# SimpleLama는 PIL Image를 받으므로 변환
if isinstance(image, np.ndarray):
pil_image = Image.fromarray(image)
else:
pil_image = image
if isinstance(mask, np.ndarray):
pil_mask = Image.fromarray(mask)
else:
pil_mask = mask
# 실제 추론 수행
result_pil = self.model(pil_image, pil_mask)
result_np = np.array(result_pil)
return result_np
else:
# Fallback: 시뮬레이션 모드
logger.warning("Fallback 모드: 시뮬레이션 인페인팅 사용")
result = await self._simulate_inpainting(image_tensor, mask_tensor)
result_np = self.postprocess_result(result)
return result_np
except Exception as e:
logger.error(f"Inpainting failed: {e}")
raise
async def _simulate_inpainting(self, image_tensor: torch.Tensor,
mask_tensor: torch.Tensor) -> torch.Tensor:
"""인페인팅 시뮬레이션 (실제 구현에서는 제거)"""
# 비동기 처리 시뮬레이션
await asyncio.sleep(0.1)
# 마스크 영역을 이미지의 평균 색상으로 채우기
result = image_tensor.clone()
mask_bool = mask_tensor.bool()
# 각 채널별 평균 계산
for c in range(3):
channel_mean = image_tensor[0, c][~mask_bool[0, 0]].mean()
result[0, c][mask_bool[0, 0]] = channel_mean
return result
def get_model_info(self) -> dict:
"""모델 정보를 반환합니다."""
return {
"model_type": "simple_lama",
"device": self.device,
"fp16": self.fp16,
"loaded": self.loaded,
"model_path": self.model_path
}