191 lines
7.0 KiB
Python
191 lines
7.0 KiB
Python
"""
|
|
MIGAN 인페인팅 모델 구현
|
|
"""
|
|
import torch
|
|
import numpy as np
|
|
import cv2
|
|
from PIL import Image
|
|
import logging
|
|
from typing import Union, Tuple
|
|
import asyncio
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MiganInpainter:
|
|
def __init__(self, model_path: str = None, device: str = "cuda", fp16: bool = True):
|
|
self.model_path = model_path
|
|
self.device = device
|
|
self.fp16 = fp16
|
|
self.model = None
|
|
self.loaded = False
|
|
|
|
async def load_model(self):
|
|
"""모델을 비동기적으로 로드합니다."""
|
|
if self.loaded:
|
|
return
|
|
|
|
try:
|
|
logger.info("Loading MIGAN model...")
|
|
|
|
# 실제 구현에서는 MIGAN 모델을 로드
|
|
# 여기서는 플레이스홀더로 구현
|
|
await asyncio.sleep(0.1) # 모델 로딩 시뮬레이션
|
|
|
|
# TODO: 실제 모델 로딩 로직
|
|
# self.model = load_migan_model(self.model_path, device=self.device)
|
|
|
|
self.model = {"type": "migan", "device": self.device, "fp16": self.fp16}
|
|
self.loaded = True
|
|
|
|
logger.info("MIGAN model loaded successfully")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load MIGAN model: {e}")
|
|
raise
|
|
|
|
def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
|
|
"""이미지를 전처리합니다."""
|
|
if isinstance(image, Image.Image):
|
|
image = np.array(image)
|
|
|
|
# RGB로 변환
|
|
if image.shape[2] == 4: # RGBA
|
|
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
|
elif image.shape[2] == 3 and image.dtype == np.uint8:
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
|
# 크기 조정 (MIGAN은 특정 크기를 선호할 수 있음)
|
|
height, width = image.shape[:2]
|
|
if height != 512 or width != 512:
|
|
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LANCZOS4)
|
|
|
|
# 정규화 (-1 to 1, MIGAN 스타일)
|
|
image = image.astype(np.float32) / 127.5 - 1.0
|
|
|
|
# 텐서로 변환 (B, C, H, W)
|
|
tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
|
|
|
|
if self.fp16:
|
|
tensor = tensor.half()
|
|
|
|
return tensor.to(self.device)
|
|
|
|
def preprocess_mask(self, mask: Union[Image.Image, np.ndarray]) -> torch.Tensor:
|
|
"""마스크를 전처리합니다."""
|
|
if isinstance(mask, Image.Image):
|
|
mask = np.array(mask)
|
|
|
|
# 그레이스케일로 변환
|
|
if len(mask.shape) == 3:
|
|
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
|
|
|
|
# 크기 조정
|
|
if mask.shape[0] != 512 or mask.shape[1] != 512:
|
|
mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
|
|
|
|
# 이진화 (0 또는 1)
|
|
mask = (mask > 127).astype(np.float32)
|
|
|
|
# 텐서로 변환 (B, 1, H, W)
|
|
tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)
|
|
|
|
if self.fp16:
|
|
tensor = tensor.half()
|
|
|
|
return tensor.to(self.device)
|
|
|
|
def postprocess_result(self, tensor: torch.Tensor, original_size: Tuple[int, int]) -> np.ndarray:
|
|
"""결과를 후처리합니다."""
|
|
# CPU로 이동하고 numpy로 변환
|
|
if tensor.is_cuda:
|
|
tensor = tensor.cpu()
|
|
if tensor.dtype == torch.float16:
|
|
tensor = tensor.float()
|
|
|
|
result = tensor.squeeze(0).permute(1, 2, 0).numpy()
|
|
|
|
# -1 to 1 범위에서 0-255로 변환
|
|
result = ((result + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
|
|
|
|
# 원본 크기로 복원
|
|
if result.shape[:2] != original_size:
|
|
result = cv2.resize(result, (original_size[1], original_size[0]),
|
|
interpolation=cv2.INTER_LANCZOS4)
|
|
|
|
return result
|
|
|
|
async def inpaint(self, image: Union[Image.Image, np.ndarray],
|
|
mask: Union[Image.Image, np.ndarray]) -> np.ndarray:
|
|
"""인페인팅을 수행합니다."""
|
|
if not self.loaded:
|
|
await self.load_model()
|
|
|
|
try:
|
|
# 원본 크기 저장
|
|
if isinstance(image, Image.Image):
|
|
original_size = image.size[::-1] # (height, width)
|
|
else:
|
|
original_size = image.shape[:2]
|
|
|
|
# 전처리
|
|
image_tensor = self.preprocess_image(image)
|
|
mask_tensor = self.preprocess_mask(mask)
|
|
|
|
# 추론 (실제 구현에서는 모델 추론)
|
|
with torch.no_grad():
|
|
# TODO: 실제 모델 추론 로직
|
|
# result = self.model(image_tensor, mask_tensor)
|
|
|
|
# 플레이스홀더: 더 정교한 인페인팅 시뮬레이션
|
|
result = await self._simulate_advanced_inpainting(image_tensor, mask_tensor)
|
|
|
|
# 후처리
|
|
result_np = self.postprocess_result(result, original_size)
|
|
|
|
return result_np
|
|
|
|
except Exception as e:
|
|
logger.error(f"MIGAN inpainting failed: {e}")
|
|
raise
|
|
|
|
async def _simulate_advanced_inpainting(self, image_tensor: torch.Tensor,
|
|
mask_tensor: torch.Tensor) -> torch.Tensor:
|
|
"""고급 인페인팅 시뮬레이션 (실제 구현에서는 제거)"""
|
|
# 비동기 처리 시뮬레이션
|
|
await asyncio.sleep(0.15) # MIGAN은 더 오래 걸린다고 가정
|
|
|
|
result = image_tensor.clone()
|
|
mask_bool = mask_tensor.bool()
|
|
|
|
# 더 정교한 인페인팅 시뮬레이션: 주변 픽셀의 가중 평균
|
|
if mask_bool.any():
|
|
# 간단한 inpainting 시뮬레이션
|
|
for c in range(3):
|
|
channel = result[0, c]
|
|
mask_2d = mask_bool[0, 0]
|
|
|
|
# 마스크 영역의 경계에서 값을 가져와서 보간
|
|
kernel = torch.ones(3, 3, device=self.device) / 9.0
|
|
if self.fp16:
|
|
kernel = kernel.half()
|
|
|
|
# 간단한 convolution 기반 인페인팅
|
|
padded_channel = torch.nn.functional.pad(channel.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='replicate')
|
|
smoothed = torch.nn.functional.conv2d(padded_channel, kernel.unsqueeze(0).unsqueeze(0), padding=0)
|
|
|
|
result[0, c][mask_2d] = smoothed[0, 0][mask_2d]
|
|
|
|
return result
|
|
|
|
def get_model_info(self) -> dict:
|
|
"""모델 정보를 반환합니다."""
|
|
return {
|
|
"model_type": "migan",
|
|
"device": self.device,
|
|
"fp16": self.fp16,
|
|
"loaded": self.loaded,
|
|
"model_path": self.model_path,
|
|
"input_size": (512, 512)
|
|
}
|