inpaintServer/app/models/rembg_model.py

175 lines
6.5 KiB
Python

"""
REMBG 배경 제거 모델 구현
"""
import torch
import numpy as np
import cv2
from PIL import Image
import logging
from typing import Union, Tuple
import asyncio
logger = logging.getLogger(__name__)
class RembgProcessor:
def __init__(self, model_name: str = "u2net", device: str = "cuda", fp16: bool = True):
self.model_name = model_name
self.device = device
self.fp16 = fp16
self.model = None
self.loaded = False
async def load_model(self):
"""모델을 비동기적으로 로드합니다."""
if self.loaded:
return
try:
logger.info(f"Loading REMBG model ({self.model_name})...")
# 실제 구현에서는 rembg 라이브러리를 사용
# 여기서는 플레이스홀더로 구현
await asyncio.sleep(0.1) # 모델 로딩 시뮬레이션
# TODO: 실제 모델 로딩 로직
# from rembg import new_session
# self.model = new_session(self.model_name)
self.model = {
"type": "rembg",
"model_name": self.model_name,
"device": self.device,
"fp16": self.fp16
}
self.loaded = True
logger.info(f"REMBG model ({self.model_name}) loaded successfully")
except Exception as e:
logger.error(f"Failed to load REMBG model: {e}")
raise
def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> np.ndarray:
"""이미지를 전처리합니다."""
if isinstance(image, Image.Image):
image = np.array(image)
# RGB로 변환
if image.shape[2] == 4: # RGBA
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
elif len(image.shape) == 3 and image.shape[2] == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image
def create_mask_from_alpha(self, rgba_image: np.ndarray) -> np.ndarray:
"""RGBA 이미지에서 알파 채널을 마스크로 변환합니다."""
if rgba_image.shape[2] != 4:
raise ValueError("Input image must have 4 channels (RGBA)")
# 알파 채널을 마스크로 사용
alpha_channel = rgba_image[:, :, 3]
# 0-255 범위의 마스크 생성
mask = alpha_channel.astype(np.uint8)
return mask
async def remove_background(self, image: Union[Image.Image, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
"""배경을 제거하고 결과 이미지와 마스크를 반환합니다."""
if not self.loaded:
await self.load_model()
try:
# 전처리
processed_image = self.preprocess_image(image)
original_shape = processed_image.shape
# 배경 제거 (실제 구현에서는 rembg 사용)
# TODO: 실제 모델 추론 로직
# from rembg import remove
# result_rgba = remove(self.model, processed_image)
# 플레이스홀더: 배경 제거 시뮬레이션
result_rgba = await self._simulate_background_removal(processed_image)
# 결과에서 RGB 이미지와 마스크 분리
result_rgb = result_rgba[:, :, :3]
mask = self.create_mask_from_alpha(result_rgba)
return result_rgb, mask
except Exception as e:
logger.error(f"Background removal failed: {e}")
raise
async def _simulate_background_removal(self, image: np.ndarray) -> np.ndarray:
"""배경 제거 시뮬레이션 (실제 구현에서는 제거)"""
# 비동기 처리 시뮬레이션
await asyncio.sleep(0.08) # REMBG는 상대적으로 빠르다고 가정
height, width = image.shape[:2]
# 간단한 전경/배경 분리 시뮬레이션
# 중앙 영역을 전경으로, 가장자리를 배경으로 가정
center_x, center_y = width // 2, height // 2
# 타원형 마스크 생성
y, x = np.ogrid[:height, :width]
mask = ((x - center_x) ** 2 / (width * 0.3) ** 2 +
(y - center_y) ** 2 / (height * 0.4) ** 2) <= 1
# 부드러운 가장자리를 위한 가우시안 블러
mask_float = mask.astype(np.float32)
mask_blurred = cv2.GaussianBlur(mask_float, (51, 51), 20)
# RGBA 이미지 생성
result_rgba = np.zeros((height, width, 4), dtype=np.uint8)
result_rgba[:, :, :3] = image # RGB 채널
result_rgba[:, :, 3] = (mask_blurred * 255).astype(np.uint8) # 알파 채널
return result_rgba
async def apply_new_background(self, foreground: np.ndarray, mask: np.ndarray,
background: Union[np.ndarray, tuple]) -> np.ndarray:
"""새로운 배경을 적용합니다."""
try:
height, width = foreground.shape[:2]
# 배경 준비
if isinstance(background, tuple):
# 단색 배경
bg = np.full((height, width, 3), background, dtype=np.uint8)
else:
# 이미지 배경
if isinstance(background, Image.Image):
background = np.array(background)
bg = cv2.resize(background, (width, height))
if len(bg.shape) == 3 and bg.shape[2] == 4:
bg = bg[:, :, :3] # RGBA에서 RGB로
# 마스크를 0-1 범위로 정규화
mask_norm = mask.astype(np.float32) / 255.0
mask_3ch = np.stack([mask_norm] * 3, axis=-1)
# 알파 블렌딩
result = (foreground.astype(np.float32) * mask_3ch +
bg.astype(np.float32) * (1 - mask_3ch))
return result.astype(np.uint8)
except Exception as e:
logger.error(f"Background application failed: {e}")
raise
def get_model_info(self) -> dict:
"""모델 정보를 반환합니다."""
return {
"model_type": "rembg",
"model_name": self.model_name,
"device": self.device,
"fp16": self.fp16,
"loaded": self.loaded
}