181 lines
6.7 KiB
Python
181 lines
6.7 KiB
Python
"""
|
|
Simple LAMA 인페인팅 모델 구현
|
|
"""
|
|
import torch
|
|
import numpy as np
|
|
import cv2
|
|
from PIL import Image
|
|
import logging
|
|
from typing import Union, Tuple
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SimpleLamaInpainter:
|
|
def __init__(self, model_path: str = None, device: str = "cuda", fp16: bool = True):
|
|
self.model_path = model_path
|
|
self.device = device
|
|
self.fp16 = fp16
|
|
self.model = None
|
|
self.loaded = False
|
|
|
|
async def load_model(self):
|
|
"""모델을 비동기적으로 로드합니다."""
|
|
if self.loaded:
|
|
return
|
|
|
|
try:
|
|
logger.info("Loading Simple LAMA model...")
|
|
|
|
# 실제 simple-lama-inpainting 라이브러리 사용
|
|
try:
|
|
from simple_lama_inpainting import SimpleLama
|
|
self.model = SimpleLama(device=self.device)
|
|
logger.info("실제 SimpleLama 모델 로딩 완료")
|
|
except ImportError as e:
|
|
logger.warning(f"SimpleLama 라이브러리 import 실패: {e}")
|
|
logger.info("fallback 모드로 전환합니다...")
|
|
# fallback으로 시뮬레이션 모드 사용
|
|
self.model = {"type": "simple_lama_fallback", "device": self.device, "fp16": self.fp16}
|
|
except Exception as e:
|
|
logger.error(f"SimpleLama 모델 초기화 실패: {e}")
|
|
logger.info("fallback 모드로 전환합니다...")
|
|
self.model = {"type": "simple_lama_fallback", "device": self.device, "fp16": self.fp16}
|
|
|
|
self.loaded = True
|
|
logger.info("Simple LAMA model loaded successfully")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load Simple LAMA model: {e}")
|
|
raise
|
|
|
|
def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
|
|
"""이미지를 전처리합니다."""
|
|
if isinstance(image, Image.Image):
|
|
image = np.array(image)
|
|
|
|
# RGB로 변환
|
|
if image.shape[2] == 4: # RGBA
|
|
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
|
elif image.shape[2] == 3 and image.dtype == np.uint8:
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
|
# 정규화 (0-1)
|
|
image = image.astype(np.float32) / 255.0
|
|
|
|
# 텐서로 변환 (B, C, H, W)
|
|
tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
|
|
|
|
if self.fp16:
|
|
tensor = tensor.half()
|
|
|
|
return tensor.to(self.device)
|
|
|
|
def preprocess_mask(self, mask: Union[Image.Image, np.ndarray]) -> torch.Tensor:
|
|
"""마스크를 전처리합니다."""
|
|
if isinstance(mask, Image.Image):
|
|
mask = np.array(mask)
|
|
|
|
# 그레이스케일로 변환
|
|
if len(mask.shape) == 3:
|
|
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
|
|
|
|
# 이진화 (0 또는 1)
|
|
mask = (mask > 127).astype(np.float32)
|
|
|
|
# 텐서로 변환 (B, 1, H, W)
|
|
tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)
|
|
|
|
if self.fp16:
|
|
tensor = tensor.half()
|
|
|
|
return tensor.to(self.device)
|
|
|
|
def postprocess_result(self, tensor: torch.Tensor) -> np.ndarray:
|
|
"""결과를 후처리합니다."""
|
|
# CPU로 이동하고 numpy로 변환
|
|
if tensor.is_cuda:
|
|
tensor = tensor.cpu()
|
|
if tensor.dtype == torch.float16:
|
|
tensor = tensor.float()
|
|
|
|
result = tensor.squeeze(0).permute(1, 2, 0).numpy()
|
|
|
|
# 0-255 범위로 변환
|
|
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
|
|
|
|
return result
|
|
|
|
async def inpaint(self, image: Union[Image.Image, np.ndarray],
|
|
mask: Union[Image.Image, np.ndarray]) -> np.ndarray:
|
|
"""인페인팅을 수행합니다."""
|
|
if not self.loaded:
|
|
await self.load_model()
|
|
|
|
try:
|
|
# 전처리
|
|
image_tensor = self.preprocess_image(image)
|
|
mask_tensor = self.preprocess_mask(mask)
|
|
|
|
# 실제 모델 추론
|
|
with torch.no_grad():
|
|
if hasattr(self.model, '__call__') and not isinstance(self.model, dict):
|
|
# 실제 SimpleLama 모델 사용
|
|
logger.info("실제 SimpleLama 모델로 인페인팅 수행")
|
|
|
|
# SimpleLama는 PIL Image를 받으므로 변환
|
|
if isinstance(image, np.ndarray):
|
|
pil_image = Image.fromarray(image)
|
|
else:
|
|
pil_image = image
|
|
|
|
if isinstance(mask, np.ndarray):
|
|
pil_mask = Image.fromarray(mask)
|
|
else:
|
|
pil_mask = mask
|
|
|
|
# 실제 추론 수행
|
|
result_pil = self.model(pil_image, pil_mask)
|
|
result_np = np.array(result_pil)
|
|
|
|
return result_np
|
|
else:
|
|
# Fallback: 시뮬레이션 모드
|
|
logger.warning("Fallback 모드: 시뮬레이션 인페인팅 사용")
|
|
result = await self._simulate_inpainting(image_tensor, mask_tensor)
|
|
result_np = self.postprocess_result(result)
|
|
return result_np
|
|
|
|
except Exception as e:
|
|
logger.error(f"Inpainting failed: {e}")
|
|
raise
|
|
|
|
async def _simulate_inpainting(self, image_tensor: torch.Tensor,
|
|
mask_tensor: torch.Tensor) -> torch.Tensor:
|
|
"""인페인팅 시뮬레이션 (실제 구현에서는 제거)"""
|
|
# 비동기 처리 시뮬레이션
|
|
await asyncio.sleep(0.1)
|
|
|
|
# 마스크 영역을 이미지의 평균 색상으로 채우기
|
|
result = image_tensor.clone()
|
|
mask_bool = mask_tensor.bool()
|
|
|
|
# 각 채널별 평균 계산
|
|
for c in range(3):
|
|
channel_mean = image_tensor[0, c][~mask_bool[0, 0]].mean()
|
|
result[0, c][mask_bool[0, 0]] = channel_mean
|
|
|
|
return result
|
|
|
|
def get_model_info(self) -> dict:
|
|
"""모델 정보를 반환합니다."""
|
|
return {
|
|
"model_type": "simple_lama",
|
|
"device": self.device,
|
|
"fp16": self.fp16,
|
|
"loaded": self.loaded,
|
|
"model_path": self.model_path
|
|
}
|