inpaintServer/app/models/migan.py

191 lines
7.0 KiB
Python

"""
MIGAN 인페인팅 모델 구현
"""
import torch
import numpy as np
import cv2
from PIL import Image
import logging
from typing import Union, Tuple
import asyncio
logger = logging.getLogger(__name__)
class MiganInpainter:
def __init__(self, model_path: str = None, device: str = "cuda", fp16: bool = True):
self.model_path = model_path
self.device = device
self.fp16 = fp16
self.model = None
self.loaded = False
async def load_model(self):
"""모델을 비동기적으로 로드합니다."""
if self.loaded:
return
try:
logger.info("Loading MIGAN model...")
# 실제 구현에서는 MIGAN 모델을 로드
# 여기서는 플레이스홀더로 구현
await asyncio.sleep(0.1) # 모델 로딩 시뮬레이션
# TODO: 실제 모델 로딩 로직
# self.model = load_migan_model(self.model_path, device=self.device)
self.model = {"type": "migan", "device": self.device, "fp16": self.fp16}
self.loaded = True
logger.info("MIGAN model loaded successfully")
except Exception as e:
logger.error(f"Failed to load MIGAN model: {e}")
raise
def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
"""이미지를 전처리합니다."""
if isinstance(image, Image.Image):
image = np.array(image)
# RGB로 변환
if image.shape[2] == 4: # RGBA
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
elif image.shape[2] == 3 and image.dtype == np.uint8:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 크기 조정 (MIGAN은 특정 크기를 선호할 수 있음)
height, width = image.shape[:2]
if height != 512 or width != 512:
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LANCZOS4)
# 정규화 (-1 to 1, MIGAN 스타일)
image = image.astype(np.float32) / 127.5 - 1.0
# 텐서로 변환 (B, C, H, W)
tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
if self.fp16:
tensor = tensor.half()
return tensor.to(self.device)
def preprocess_mask(self, mask: Union[Image.Image, np.ndarray]) -> torch.Tensor:
"""마스크를 전처리합니다."""
if isinstance(mask, Image.Image):
mask = np.array(mask)
# 그레이스케일로 변환
if len(mask.shape) == 3:
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
# 크기 조정
if mask.shape[0] != 512 or mask.shape[1] != 512:
mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
# 이진화 (0 또는 1)
mask = (mask > 127).astype(np.float32)
# 텐서로 변환 (B, 1, H, W)
tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)
if self.fp16:
tensor = tensor.half()
return tensor.to(self.device)
def postprocess_result(self, tensor: torch.Tensor, original_size: Tuple[int, int]) -> np.ndarray:
"""결과를 후처리합니다."""
# CPU로 이동하고 numpy로 변환
if tensor.is_cuda:
tensor = tensor.cpu()
if tensor.dtype == torch.float16:
tensor = tensor.float()
result = tensor.squeeze(0).permute(1, 2, 0).numpy()
# -1 to 1 범위에서 0-255로 변환
result = ((result + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
# 원본 크기로 복원
if result.shape[:2] != original_size:
result = cv2.resize(result, (original_size[1], original_size[0]),
interpolation=cv2.INTER_LANCZOS4)
return result
async def inpaint(self, image: Union[Image.Image, np.ndarray],
mask: Union[Image.Image, np.ndarray]) -> np.ndarray:
"""인페인팅을 수행합니다."""
if not self.loaded:
await self.load_model()
try:
# 원본 크기 저장
if isinstance(image, Image.Image):
original_size = image.size[::-1] # (height, width)
else:
original_size = image.shape[:2]
# 전처리
image_tensor = self.preprocess_image(image)
mask_tensor = self.preprocess_mask(mask)
# 추론 (실제 구현에서는 모델 추론)
with torch.no_grad():
# TODO: 실제 모델 추론 로직
# result = self.model(image_tensor, mask_tensor)
# 플레이스홀더: 더 정교한 인페인팅 시뮬레이션
result = await self._simulate_advanced_inpainting(image_tensor, mask_tensor)
# 후처리
result_np = self.postprocess_result(result, original_size)
return result_np
except Exception as e:
logger.error(f"MIGAN inpainting failed: {e}")
raise
async def _simulate_advanced_inpainting(self, image_tensor: torch.Tensor,
mask_tensor: torch.Tensor) -> torch.Tensor:
"""고급 인페인팅 시뮬레이션 (실제 구현에서는 제거)"""
# 비동기 처리 시뮬레이션
await asyncio.sleep(0.15) # MIGAN은 더 오래 걸린다고 가정
result = image_tensor.clone()
mask_bool = mask_tensor.bool()
# 더 정교한 인페인팅 시뮬레이션: 주변 픽셀의 가중 평균
if mask_bool.any():
# 간단한 inpainting 시뮬레이션
for c in range(3):
channel = result[0, c]
mask_2d = mask_bool[0, 0]
# 마스크 영역의 경계에서 값을 가져와서 보간
kernel = torch.ones(3, 3, device=self.device) / 9.0
if self.fp16:
kernel = kernel.half()
# 간단한 convolution 기반 인페인팅
padded_channel = torch.nn.functional.pad(channel.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='replicate')
smoothed = torch.nn.functional.conv2d(padded_channel, kernel.unsqueeze(0).unsqueeze(0), padding=0)
result[0, c][mask_2d] = smoothed[0, 0][mask_2d]
return result
def get_model_info(self) -> dict:
"""모델 정보를 반환합니다."""
return {
"model_type": "migan",
"device": self.device,
"fp16": self.fp16,
"loaded": self.loaded,
"model_path": self.model_path,
"input_size": (512, 512)
}