inpaintServer/app/models/rembg_model.py

275 lines
11 KiB
Python

"""
REMBG 배경 제거 모델 구현 (실제 rembg 라이브러리 사용)
"""
import os
import cv2
from PIL import Image
import logging
import numpy as np
import onnxruntime # ONNX 런타임 직접 사용을 위해 임포트
from typing import Union, Tuple, Optional
import asyncio
from ..core.config import settings
from ..utils.gpu_monitor import gpu_monitor
from ..utils.image_utils import fill_transparent_background_with_white
logger = logging.getLogger(__name__)
class RembgProcessor:
"""
rembg 기반 배경제거 모듈 (안전한 의존성 처리)
"""
# 사용하시려는 birefnet 모델을 지원 목록에 추가합니다.
SUPPORTED_MODELS = {
"u2net": "범용 배경제거 | 빠름 | 사람/사물 모두 양호 (기본값)",
"u2netp": "u2net 경량화 | 매우 빠름 | 실시간, 저사양PC",
"u2net_human_seg": "인물 전용 | 빠름 | 사람 경계 정밀",
"u2net_cloth_seg": "옷 전용 | 빠름 | 패션/의류 특화",
"isnet-general-use": "범용 고품질 | 느림 | 디테일 중시, 대용량",
"sam": "SAM 최고 품질 | 매우 느림 | 고성능PC 권장",
"sam-mobile": "SAM 경량화 | 보통 | 모바일, 중간성능",
"birefnet-general-lite": "BiRefNet 경량 모델 | 고품질 저용량 (로컬)"
}
# SUPPORTED_MODELS 키와 실제 rembg sessions 키 간의 매핑
MODEL_NAME_MAPPING = {
"u2net": "u2net",
"u2netp": "u2netp",
"u2net_human_seg": "u2net-human-seg",
"u2net_cloth_seg": "u2net-cloth-seg",
"isnet-general-use": "dis-general-use",
"sam": "sam",
"sam-mobile": "sam", # sam-mobile은 sam과 동일하게 처리
"birefnet-general-lite": "birefnet-general-lite"
}
def __init__(self, model_name: str = "u2net", device: str = "cuda", fp16: bool = True,
local_rembg_model_path: str = None):
self.model_name = model_name
self.device = device
self.fp16 = fp16
self.local_rembg_model_path = local_rembg_model_path
self.sessions = {}
self.loaded = False
self._rembg_available = None
self._init_error = None
self._cuda_providers_tested = False
def _check_rembg_availability(self):
"""rembg 모듈 사용 가능 여부를 확인하고 캐시.
세션을 생성하지 않아 모델 다운로드를 유발하지 않도록 함."""
if self._rembg_available is not None:
return self._rembg_available
try:
import rembg # noqa: F401
self._rembg_available = True
logger.info("rembg 모듈 임포트 성공 (세션 생성은 지연 로딩)")
return True
except ImportError as e:
self._init_error = f"rembg 모듈이 설치되지 않음: {e}"
self._rembg_available = False
except Exception as e:
self._init_error = f"rembg 모듈 초기화 실패 (의존성/하드웨어 문제): {e}"
self._rembg_available = False
logger.error(self._init_error)
return False
def get_session(self, model_name, timeout_seconds: int = 90):
"""
모델별 세션을 캐싱하여 반환 (로컬 모델 경로 및 CUDA 지원 포함)
"""
if not self._check_rembg_availability():
logger.error(f"rembg 사용 불가로 세션 생성 실패: {self._init_error}")
return None
# device 설정에 따라 CUDA 사용 여부 결정 (간소화)
cuda_enabled = self.device == "cuda"
# 실제 모델명을 세션 키에 사용
actual_model_name = self.MODEL_NAME_MAPPING.get(model_name, model_name)
session_key = f"{actual_model_name}_cuda_{cuda_enabled}"
if session_key not in self.sessions:
logger.info(f"🔧 rembg 새 세션 생성 필요: {session_key}")
try:
import rembg
try:
from rembg.sessions import sessions
except ImportError:
# rembg 버전에 따라 import 경로가 다를 수 있음
sessions = None
logger.warning("rembg.sessions import 실패, 기본 방식 사용")
# Jetson 환경에서 TensorRT 충돌을 피하기 위해 프로바이더 명시
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
logger.info(f"rembg 세션 생성 providers: {providers}")
session = rembg.new_session(
model_name=actual_model_name,
providers=providers
)
self.sessions[session_key] = session
# 실제 사용된 provider 확인 및 로깅 (가드 처리)
actual_providers = []
try:
inner = getattr(session, 'inner_session', None)
if inner and hasattr(inner, 'get_providers'):
actual_providers = inner.get_providers() or []
except Exception as prov_err:
logger.debug(f"rembg provider 확인 실패: {prov_err}")
is_gpu = any(('CUDA' in p) or ('Tensorrt' in p) for p in actual_providers)
status = "GPU 가속" if is_gpu else "CPU 모드"
logger.info(
f"✅ rembg '{actual_model_name}' {status}로 동작 (providers: {actual_providers or '알 수 없음'})"
)
except Exception as e:
logger.error(f"rembg 세션 생성 실패 ('{actual_model_name}'): {e}", exc_info=True)
return None
else:
logger.debug(f"♻️ rembg 기존 세션 재사용: {session_key}")
return self.sessions.get(session_key)
async def load_model(self):
"""모델을 비동기적으로 로드합니다."""
if self.loaded:
return
try:
logger.info(f"Loading REMBG model ({self.model_name})...")
# rembg 사용 가능성 확인
if not self._check_rembg_availability():
raise RuntimeError(f"REMBG 사용 불가: {self._init_error}")
# 세션 생성
session = self.get_session(self.model_name)
if session is None:
raise RuntimeError(f"REMBG 세션 생성 실패: {self.model_name}")
self.loaded = True
logger.info(f"REMBG model ({self.model_name}) loaded successfully")
except Exception as e:
logger.error(f"Failed to load REMBG model: {e}")
raise
def to_white_background(self, img: Image.Image) -> Image.Image:
"""RGBA 이미지를 흰 배경으로 변환"""
if img.mode in ("RGBA", "BGRA"):
bg = Image.new("RGB", img.size, (255, 255, 255))
bg.paste(img, mask=img.split()[-1])
return bg
else:
return img.convert("RGB")
async def remove_background(self, image: Union[str, Image.Image, np.ndarray],
model_name: str = None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
"""
배경을 제거하고 결과 이미지와 마스크를 반환합니다.
Args:
image: 입력 이미지 (파일 경로, PIL Image, 또는 numpy array)
model_name: 사용할 모델명 (없으면 기본 모델 사용)
**kwargs: 추가 옵션 (alpha_matting 등)
Returns:
(result_rgb, mask): 결과 이미지(RGB)와 마스크
"""
if not self.loaded:
await self.load_model()
try:
# 이미지 로드 및 변환
if isinstance(image, str):
if not os.path.exists(image):
logger.error(f"입력 이미지가 존재하지 않습니다: {image}")
return None, None
img = cv2.imread(image)
if img is None:
logger.error(f"이미지 로드 실패: {image}")
return None, None
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
elif isinstance(image, Image.Image):
img_rgb = np.array(image.convert('RGB'))
elif isinstance(image, np.ndarray):
if len(image.shape) == 3 and image.shape[2] == 3:
# BGR to RGB
img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
else:
img_rgb = image
else:
logger.error(f"지원하지 않는 이미지 타입: {type(image)}")
return None, None
# 사용할 모델명 결정
effective_model_name = model_name or self.model_name
if effective_model_name not in self.SUPPORTED_MODELS:
logger.warning(f"지원하지 않는 모델명: {effective_model_name}. u2net으로 대체 사용")
effective_model_name = "u2net"
session = self.get_session(effective_model_name)
if session is None:
return None, None
import rembg
import time
start_time = time.time()
# rembg.remove는 RGBA 이미지를 반환
output_image_rgba = rembg.remove(image, session=session)
# 투명 배경을 흰색으로 채우기
output_image_rgb = fill_transparent_background_with_white(output_image_rgba)
# 마스크 생성 (알파 채널 사용)
mask = output_image_rgba[:, :, 3]
logger.debug("Background removal and white filling successful.")
return output_image_rgb, mask
except Exception as e:
logger.error(f"Error during rembg processing: {e}", exc_info=True)
return None, None
def set_default_model(self, model_name):
if model_name not in self.SUPPORTED_MODELS:
raise ValueError(f"지원하지 않는 모델명: {model_name}")
self.model_name = model_name
logger.info(f"rembg 기본 모델이 '{model_name}'(으)로 변경됨")
def get_default_model(self):
return self.model_name
def get_supported_models(self):
return self.SUPPORTED_MODELS.copy()
def get_model_description(self, model_name):
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
def is_available(self):
return self._check_rembg_availability()
def get_init_error(self):
return self._init_error
def get_model_info(self) -> dict:
"""모델 정보를 반환합니다."""
return {
"model_type": "rembg",
"model_name": self.model_name,
"device": self.device,
"fp16": self.fp16,
"loaded": self.loaded,
"available": self.is_available(),
"supported_models": list(self.SUPPORTED_MODELS.keys()),
"local_model_path": self.local_rembg_model_path
}