92 lines
4.0 KiB
Python
92 lines
4.0 KiB
Python
"""
|
|
REMBG 배경 제거 모델 구현 (rembg 라이브러리 사용)
|
|
"""
|
|
import logging
|
|
import numpy as np
|
|
from PIL import Image
|
|
import rembg
|
|
from ..core.config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RembgProcessor:
|
|
"""Rembg 라이브러리를 사용한 배경 제거 프로세서"""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
"""초기화 시 임의의 인수를 받아 무시 (session_pool 호환성)"""
|
|
self._session = None
|
|
logger.info("RembgProcessor 초기화 완료")
|
|
|
|
async def load_model(self):
|
|
"""Rembg 세션을 로드합니다."""
|
|
try:
|
|
logger.info("Rembg 세션 생성 중...")
|
|
|
|
# Jetson에서 TensorRT 프로바이더 제외하여 세션 생성
|
|
if settings.IS_JETSON:
|
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
|
self._session = rembg.new_session(settings.REMBG_MODEL_NAME, providers=providers)
|
|
else:
|
|
self._session = rembg.new_session(settings.REMBG_MODEL_NAME)
|
|
|
|
logger.info(f"Rembg 세션 생성 완료, 프로바이더: {self._session.providers if hasattr(self._session, 'providers') else 'Unknown'}")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"Rembg 세션 생성 실패: {e}")
|
|
return False
|
|
|
|
async def remove_background(self, image: np.ndarray, model_name: str = None) -> tuple:
|
|
"""이미지에서 배경을 제거합니다."""
|
|
try:
|
|
logger.info(f"배경제거 시작: image.shape={image.shape}, model_name={model_name}")
|
|
|
|
if self._session is None:
|
|
logger.error("Rembg 세션이 None입니다!")
|
|
raise RuntimeError("Rembg 세션이 로드되지 않았습니다.")
|
|
|
|
logger.info(f"Rembg 세션 확인 완료: {type(self._session)}")
|
|
|
|
# numpy 배열을 PIL Image로 변환
|
|
if len(image.shape) == 3 and image.shape[2] == 3:
|
|
# BGR to RGB 변환 (OpenCV 기본값)
|
|
rgb_image = image[:, :, ::-1]
|
|
else:
|
|
rgb_image = image
|
|
|
|
pil_image = Image.fromarray(rgb_image.astype(np.uint8))
|
|
logger.info(f"PIL 이미지 변환 완료: {pil_image.size}, {pil_image.mode}")
|
|
|
|
# rembg로 배경 제거 (RGBA 반환)
|
|
logger.info("rembg.remove() 호출 중...")
|
|
result_pil = rembg.remove(pil_image, session=self._session)
|
|
logger.info(f"rembg.remove() 완료: {result_pil.size}, {result_pil.mode}")
|
|
|
|
# 마스크 추출 및 통계 로깅
|
|
try:
|
|
alpha_channel = np.array(result_pil)[:, :, 3]
|
|
logger.info(f"RMBG mask stats: min={int(alpha_channel.min())}, max={int(alpha_channel.max())}, mean={float(alpha_channel.mean()):.3f}")
|
|
except Exception:
|
|
pass
|
|
|
|
# RGBA를 RGB로 변환 (흰색 배경 추가)
|
|
if result_pil.mode == 'RGBA':
|
|
# 흰색 배경 생성
|
|
white_bg = Image.new('RGB', result_pil.size, (255, 255, 255))
|
|
white_bg.paste(result_pil, mask=result_pil.split()[-1]) # 알파 채널을 마스크로 사용
|
|
result_pil = white_bg
|
|
|
|
# PIL을 numpy 배열로 변환 후 BGR로 변환 (OpenCV 호환)
|
|
result_array = np.array(result_pil)
|
|
if len(result_array.shape) == 3 and result_array.shape[2] == 3:
|
|
result_array = result_array[:, :, ::-1] # RGB to BGR
|
|
|
|
# 마스크 생성 (간단히 알파 채널을 그레이스케일로)
|
|
mask = alpha_channel if 'alpha_channel' in locals() else np.ones((result_array.shape[0], result_array.shape[1]), dtype=np.uint8) * 255
|
|
|
|
return result_array, mask
|
|
|
|
except Exception as e:
|
|
logger.error(f"배경 제거 처리 실패: {e}", exc_info=True)
|
|
return image, None |