212 lines
7.9 KiB
Python
212 lines
7.9 KiB
Python
"""
|
|
Simple LAMA 인페인팅 모델 구현
|
|
"""
|
|
import torch
|
|
import numpy as np
|
|
import cv2
|
|
from PIL import Image
|
|
import logging
|
|
from typing import Union, Tuple, List
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from simple_lama_inpainting import SimpleLama
|
|
# 사용하지 않는 import 정리
|
|
# from ..utils.image_utils import (
|
|
# decode_base64_to_image,
|
|
# encode_image_to_base64,
|
|
# get_image_size,
|
|
# resize_image_if_needed,
|
|
# )
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SimpleLamaInpainter:
|
|
def __init__(self, model_path: str, device: str = "cpu", fp16: bool = False):
|
|
self.model_path = model_path
|
|
self._device = torch.device(device)
|
|
self._fp16 = fp16
|
|
self._model = None
|
|
self.loaded = False
|
|
|
|
async def load_model(self):
|
|
"""모델을 비동기적으로 로드합니다."""
|
|
if self.loaded:
|
|
return
|
|
|
|
try:
|
|
logger.info("Loading Simple LAMA model...")
|
|
|
|
# 실제 simple-lama-inpainting 라이브러리 사용
|
|
try:
|
|
self._model = SimpleLama(device=self._device)
|
|
logger.info("실제 SimpleLama 모델 로딩 완료")
|
|
except ImportError as e:
|
|
logger.warning(f"SimpleLama 라이브러리 import 실패: {e}")
|
|
logger.info("fallback 모드로 전환합니다...")
|
|
# fallback으로 시뮬레이션 모드 사용
|
|
self._model = {"type": "simple_lama_fallback", "device": self._device, "fp16": self._fp16}
|
|
except Exception as e:
|
|
logger.error(f"SimpleLama 모델 초기화 실패: {e}")
|
|
logger.info("fallback 모드로 전환합니다...")
|
|
self._model = {"type": "simple_lama_fallback", "device": self._device, "fp16": self._fp16}
|
|
|
|
self.loaded = True
|
|
logger.info("Simple LAMA model loaded successfully")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load Simple LAMA model: {e}")
|
|
raise
|
|
|
|
def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
|
|
"""이미지를 전처리합니다."""
|
|
if isinstance(image, Image.Image):
|
|
image = np.array(image)
|
|
|
|
# RGB로 변환
|
|
if image.shape[2] == 4: # RGBA
|
|
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
|
elif image.shape[2] == 3 and image.dtype == np.uint8:
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
|
# 정규화 (0-1)
|
|
image = image.astype(np.float32) / 255.0
|
|
|
|
# 텐서로 변환 (B, C, H, W)
|
|
tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
|
|
|
|
if self._fp16:
|
|
tensor = tensor.half()
|
|
|
|
return tensor.to(self._device)
|
|
|
|
def preprocess_mask(self, mask: Union[Image.Image, np.ndarray]) -> torch.Tensor:
|
|
"""마스크를 전처리합니다."""
|
|
if isinstance(mask, Image.Image):
|
|
mask = np.array(mask)
|
|
|
|
# 그레이스케일로 변환
|
|
if len(mask.shape) == 3:
|
|
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
|
|
|
|
# 이진화 (0 또는 1)
|
|
mask = (mask > 127).astype(np.float32)
|
|
|
|
# 텐서로 변환 (B, 1, H, W)
|
|
tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)
|
|
|
|
if self._fp16:
|
|
tensor = tensor.half()
|
|
|
|
return tensor.to(self._device)
|
|
|
|
def postprocess_result(self, tensor: torch.Tensor) -> np.ndarray:
|
|
"""결과를 후처리합니다."""
|
|
# CPU로 이동하고 numpy로 변환
|
|
if tensor.is_cuda:
|
|
tensor = tensor.cpu()
|
|
if tensor.dtype == torch.float16:
|
|
tensor = tensor.float()
|
|
|
|
result = tensor.squeeze(0).permute(1, 2, 0).numpy()
|
|
|
|
# 0-255 범위로 변환
|
|
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
|
|
|
|
return result
|
|
|
|
async def inpaint(
|
|
self,
|
|
images: List[np.ndarray],
|
|
masks: List[np.ndarray],
|
|
**kwargs,
|
|
) -> List[np.ndarray]:
|
|
if not self.loaded:
|
|
await self.load_model()
|
|
|
|
if not self.is_ready:
|
|
raise RuntimeError("SimpleLama model is not loaded yet.")
|
|
|
|
# 모델이 GPU에 있는지 확인
|
|
if self._device.type != 'cpu':
|
|
torch.cuda.empty_cache()
|
|
|
|
# 전처리
|
|
pil_images = [Image.fromarray(img) for img in images]
|
|
pil_masks = [Image.fromarray(mask) for mask in masks]
|
|
|
|
preprocessed_images = []
|
|
preprocessed_masks = []
|
|
for img, mask in zip(pil_images, pil_masks):
|
|
img_tensor, mask_tensor = self._preprocess(img, mask)
|
|
preprocessed_images.append(img_tensor)
|
|
preprocessed_masks.append(mask_tensor)
|
|
|
|
image_batch = torch.stack(preprocessed_images).to(self._device)
|
|
mask_batch = torch.stack(preprocessed_masks).to(self._device)
|
|
|
|
# 원본 이미지와 사이즈 저장
|
|
original_images_and_sizes = list(zip(pil_images, [img.size for img in pil_images]))
|
|
|
|
# 모델 호출
|
|
logger.info(f"실제 SimpleLama 모델로 {len(images)}개 이미지 인페인팅 수행")
|
|
with torch.no_grad():
|
|
# 라이브러리의 __call__ 대신 내부 torch 모델을 직접 호출
|
|
inpainted_batch = self._model.model(image_batch, mask_batch)
|
|
|
|
# 후처리
|
|
result_images = []
|
|
for i, inpainted_tensor in enumerate(inpainted_batch):
|
|
original_image, original_size = original_images_and_sizes[i]
|
|
original_mask = pil_masks[i]
|
|
result_pil = self._postprocess(inpainted_tensor, original_size, original_image, original_mask)
|
|
result_images.append(np.array(result_pil))
|
|
|
|
return result_images
|
|
|
|
def _preprocess(self, image: Image.Image, mask: Image.Image):
|
|
"""단일 이미지를 모델 입력 텐서로 전처리합니다."""
|
|
# simple_lama_inpainting.models.lama.py의 전처리 로직 참고
|
|
image = image.convert("RGB")
|
|
mask = mask.convert("L")
|
|
|
|
# 이미지 리사이즈 (모델 요구사항에 맞게)
|
|
resized_image = image.resize((512, 512), Image.Resampling.LANCZOS)
|
|
resized_mask = mask.resize((512, 512), Image.Resampling.NEAREST)
|
|
|
|
image_tensor = torch.from_numpy(np.array(resized_image, dtype=np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).squeeze(0)
|
|
mask_tensor = torch.from_numpy(np.array(resized_mask, dtype=np.float32) / 255.0).unsqueeze(0).unsqueeze(0).squeeze(0)
|
|
|
|
return image_tensor, mask_tensor
|
|
|
|
def _postprocess(self, tensor: torch.Tensor, original_size: Tuple[int, int], original_image: Image.Image, original_mask: Image.Image) -> Image.Image:
|
|
"""모델 출력 텐서를 PIL 이미지로 후처리하고 원본에 합성합니다."""
|
|
# 텐서를 PIL 이미지로 변환
|
|
result_np = tensor.permute(1, 2, 0).cpu().numpy()
|
|
result_np = np.clip(result_np * 255, 0, 255).astype(np.uint8)
|
|
inpainted_image_512 = Image.fromarray(result_np)
|
|
|
|
# 원본 크기로 리사이즈
|
|
resized_inpainted_image = inpainted_image_512.resize(original_size, Image.Resampling.LANCZOS)
|
|
|
|
# 원본 마스크를 사용하여 원본 이미지와 합성
|
|
original_mask = original_mask.convert("L")
|
|
final_image = Image.composite(resized_inpainted_image, original_image, original_mask)
|
|
|
|
return final_image
|
|
|
|
|
|
@property
|
|
def is_ready(self) -> bool:
|
|
return self._model is not None
|
|
|
|
def get_model_info(self) -> dict:
|
|
"""모델 정보를 반환합니다."""
|
|
return {
|
|
"model_type": "simple_lama",
|
|
"device": self._device,
|
|
"fp16": self._fp16,
|
|
"loaded": self.loaded,
|
|
"model_path": self.model_path
|
|
}
|