175 lines
6.5 KiB
Python
175 lines
6.5 KiB
Python
"""
|
|
REMBG 배경 제거 모델 구현
|
|
"""
|
|
import torch
|
|
import numpy as np
|
|
import cv2
|
|
from PIL import Image
|
|
import logging
|
|
from typing import Union, Tuple
|
|
import asyncio
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RembgProcessor:
|
|
def __init__(self, model_name: str = "u2net", device: str = "cuda", fp16: bool = True):
|
|
self.model_name = model_name
|
|
self.device = device
|
|
self.fp16 = fp16
|
|
self.model = None
|
|
self.loaded = False
|
|
|
|
async def load_model(self):
|
|
"""모델을 비동기적으로 로드합니다."""
|
|
if self.loaded:
|
|
return
|
|
|
|
try:
|
|
logger.info(f"Loading REMBG model ({self.model_name})...")
|
|
|
|
# 실제 구현에서는 rembg 라이브러리를 사용
|
|
# 여기서는 플레이스홀더로 구현
|
|
await asyncio.sleep(0.1) # 모델 로딩 시뮬레이션
|
|
|
|
# TODO: 실제 모델 로딩 로직
|
|
# from rembg import new_session
|
|
# self.model = new_session(self.model_name)
|
|
|
|
self.model = {
|
|
"type": "rembg",
|
|
"model_name": self.model_name,
|
|
"device": self.device,
|
|
"fp16": self.fp16
|
|
}
|
|
self.loaded = True
|
|
|
|
logger.info(f"REMBG model ({self.model_name}) loaded successfully")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load REMBG model: {e}")
|
|
raise
|
|
|
|
def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> np.ndarray:
|
|
"""이미지를 전처리합니다."""
|
|
if isinstance(image, Image.Image):
|
|
image = np.array(image)
|
|
|
|
# RGB로 변환
|
|
if image.shape[2] == 4: # RGBA
|
|
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
|
elif len(image.shape) == 3 and image.shape[2] == 3:
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
|
return image
|
|
|
|
def create_mask_from_alpha(self, rgba_image: np.ndarray) -> np.ndarray:
|
|
"""RGBA 이미지에서 알파 채널을 마스크로 변환합니다."""
|
|
if rgba_image.shape[2] != 4:
|
|
raise ValueError("Input image must have 4 channels (RGBA)")
|
|
|
|
# 알파 채널을 마스크로 사용
|
|
alpha_channel = rgba_image[:, :, 3]
|
|
|
|
# 0-255 범위의 마스크 생성
|
|
mask = alpha_channel.astype(np.uint8)
|
|
|
|
return mask
|
|
|
|
async def remove_background(self, image: Union[Image.Image, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
|
"""배경을 제거하고 결과 이미지와 마스크를 반환합니다."""
|
|
if not self.loaded:
|
|
await self.load_model()
|
|
|
|
try:
|
|
# 전처리
|
|
processed_image = self.preprocess_image(image)
|
|
original_shape = processed_image.shape
|
|
|
|
# 배경 제거 (실제 구현에서는 rembg 사용)
|
|
# TODO: 실제 모델 추론 로직
|
|
# from rembg import remove
|
|
# result_rgba = remove(self.model, processed_image)
|
|
|
|
# 플레이스홀더: 배경 제거 시뮬레이션
|
|
result_rgba = await self._simulate_background_removal(processed_image)
|
|
|
|
# 결과에서 RGB 이미지와 마스크 분리
|
|
result_rgb = result_rgba[:, :, :3]
|
|
mask = self.create_mask_from_alpha(result_rgba)
|
|
|
|
return result_rgb, mask
|
|
|
|
except Exception as e:
|
|
logger.error(f"Background removal failed: {e}")
|
|
raise
|
|
|
|
async def _simulate_background_removal(self, image: np.ndarray) -> np.ndarray:
|
|
"""배경 제거 시뮬레이션 (실제 구현에서는 제거)"""
|
|
# 비동기 처리 시뮬레이션
|
|
await asyncio.sleep(0.08) # REMBG는 상대적으로 빠르다고 가정
|
|
|
|
height, width = image.shape[:2]
|
|
|
|
# 간단한 전경/배경 분리 시뮬레이션
|
|
# 중앙 영역을 전경으로, 가장자리를 배경으로 가정
|
|
center_x, center_y = width // 2, height // 2
|
|
|
|
# 타원형 마스크 생성
|
|
y, x = np.ogrid[:height, :width]
|
|
mask = ((x - center_x) ** 2 / (width * 0.3) ** 2 +
|
|
(y - center_y) ** 2 / (height * 0.4) ** 2) <= 1
|
|
|
|
# 부드러운 가장자리를 위한 가우시안 블러
|
|
mask_float = mask.astype(np.float32)
|
|
mask_blurred = cv2.GaussianBlur(mask_float, (51, 51), 20)
|
|
|
|
# RGBA 이미지 생성
|
|
result_rgba = np.zeros((height, width, 4), dtype=np.uint8)
|
|
result_rgba[:, :, :3] = image # RGB 채널
|
|
result_rgba[:, :, 3] = (mask_blurred * 255).astype(np.uint8) # 알파 채널
|
|
|
|
return result_rgba
|
|
|
|
async def apply_new_background(self, foreground: np.ndarray, mask: np.ndarray,
|
|
background: Union[np.ndarray, tuple]) -> np.ndarray:
|
|
"""새로운 배경을 적용합니다."""
|
|
try:
|
|
height, width = foreground.shape[:2]
|
|
|
|
# 배경 준비
|
|
if isinstance(background, tuple):
|
|
# 단색 배경
|
|
bg = np.full((height, width, 3), background, dtype=np.uint8)
|
|
else:
|
|
# 이미지 배경
|
|
if isinstance(background, Image.Image):
|
|
background = np.array(background)
|
|
bg = cv2.resize(background, (width, height))
|
|
if len(bg.shape) == 3 and bg.shape[2] == 4:
|
|
bg = bg[:, :, :3] # RGBA에서 RGB로
|
|
|
|
# 마스크를 0-1 범위로 정규화
|
|
mask_norm = mask.astype(np.float32) / 255.0
|
|
mask_3ch = np.stack([mask_norm] * 3, axis=-1)
|
|
|
|
# 알파 블렌딩
|
|
result = (foreground.astype(np.float32) * mask_3ch +
|
|
bg.astype(np.float32) * (1 - mask_3ch))
|
|
|
|
return result.astype(np.uint8)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Background application failed: {e}")
|
|
raise
|
|
|
|
def get_model_info(self) -> dict:
|
|
"""모델 정보를 반환합니다."""
|
|
return {
|
|
"model_type": "rembg",
|
|
"model_name": self.model_name,
|
|
"device": self.device,
|
|
"fp16": self.fp16,
|
|
"loaded": self.loaded
|
|
}
|