""" REMBG 배경 제거 모델 구현 """ import torch import numpy as np import cv2 from PIL import Image import logging from typing import Union, Tuple import asyncio logger = logging.getLogger(__name__) class RembgProcessor: def __init__(self, model_name: str = "u2net", device: str = "cuda", fp16: bool = True): self.model_name = model_name self.device = device self.fp16 = fp16 self.model = None self.loaded = False async def load_model(self): """모델을 비동기적으로 로드합니다.""" if self.loaded: return try: logger.info(f"Loading REMBG model ({self.model_name})...") # 실제 구현에서는 rembg 라이브러리를 사용 # 여기서는 플레이스홀더로 구현 await asyncio.sleep(0.1) # 모델 로딩 시뮬레이션 # TODO: 실제 모델 로딩 로직 # from rembg import new_session # self.model = new_session(self.model_name) self.model = { "type": "rembg", "model_name": self.model_name, "device": self.device, "fp16": self.fp16 } self.loaded = True logger.info(f"REMBG model ({self.model_name}) loaded successfully") except Exception as e: logger.error(f"Failed to load REMBG model: {e}") raise def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> np.ndarray: """이미지를 전처리합니다.""" if isinstance(image, Image.Image): image = np.array(image) # RGB로 변환 if image.shape[2] == 4: # RGBA image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) elif len(image.shape) == 3 and image.shape[2] == 3: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) return image def create_mask_from_alpha(self, rgba_image: np.ndarray) -> np.ndarray: """RGBA 이미지에서 알파 채널을 마스크로 변환합니다.""" if rgba_image.shape[2] != 4: raise ValueError("Input image must have 4 channels (RGBA)") # 알파 채널을 마스크로 사용 alpha_channel = rgba_image[:, :, 3] # 0-255 범위의 마스크 생성 mask = alpha_channel.astype(np.uint8) return mask async def remove_background(self, image: Union[Image.Image, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: """배경을 제거하고 결과 이미지와 마스크를 반환합니다.""" if not self.loaded: await self.load_model() try: # 전처리 processed_image = self.preprocess_image(image) original_shape = processed_image.shape # 배경 제거 (실제 구현에서는 rembg 사용) # TODO: 실제 모델 추론 로직 # from rembg import remove # result_rgba = remove(self.model, processed_image) # 플레이스홀더: 배경 제거 시뮬레이션 result_rgba = await self._simulate_background_removal(processed_image) # 결과에서 RGB 이미지와 마스크 분리 result_rgb = result_rgba[:, :, :3] mask = self.create_mask_from_alpha(result_rgba) return result_rgb, mask except Exception as e: logger.error(f"Background removal failed: {e}") raise async def _simulate_background_removal(self, image: np.ndarray) -> np.ndarray: """배경 제거 시뮬레이션 (실제 구현에서는 제거)""" # 비동기 처리 시뮬레이션 await asyncio.sleep(0.08) # REMBG는 상대적으로 빠르다고 가정 height, width = image.shape[:2] # 간단한 전경/배경 분리 시뮬레이션 # 중앙 영역을 전경으로, 가장자리를 배경으로 가정 center_x, center_y = width // 2, height // 2 # 타원형 마스크 생성 y, x = np.ogrid[:height, :width] mask = ((x - center_x) ** 2 / (width * 0.3) ** 2 + (y - center_y) ** 2 / (height * 0.4) ** 2) <= 1 # 부드러운 가장자리를 위한 가우시안 블러 mask_float = mask.astype(np.float32) mask_blurred = cv2.GaussianBlur(mask_float, (51, 51), 20) # RGBA 이미지 생성 result_rgba = np.zeros((height, width, 4), dtype=np.uint8) result_rgba[:, :, :3] = image # RGB 채널 result_rgba[:, :, 3] = (mask_blurred * 255).astype(np.uint8) # 알파 채널 return result_rgba async def apply_new_background(self, foreground: np.ndarray, mask: np.ndarray, background: Union[np.ndarray, tuple]) -> np.ndarray: """새로운 배경을 적용합니다.""" try: height, width = foreground.shape[:2] # 배경 준비 if isinstance(background, tuple): # 단색 배경 bg = np.full((height, width, 3), background, dtype=np.uint8) else: # 이미지 배경 if isinstance(background, Image.Image): background = np.array(background) bg = cv2.resize(background, (width, height)) if len(bg.shape) == 3 and bg.shape[2] == 4: bg = bg[:, :, :3] # RGBA에서 RGB로 # 마스크를 0-1 범위로 정규화 mask_norm = mask.astype(np.float32) / 255.0 mask_3ch = np.stack([mask_norm] * 3, axis=-1) # 알파 블렌딩 result = (foreground.astype(np.float32) * mask_3ch + bg.astype(np.float32) * (1 - mask_3ch)) return result.astype(np.uint8) except Exception as e: logger.error(f"Background application failed: {e}") raise def get_model_info(self) -> dict: """모델 정보를 반환합니다.""" return { "model_type": "rembg", "model_name": self.model_name, "device": self.device, "fp16": self.fp16, "loaded": self.loaded }