291 lines
12 KiB
Python
291 lines
12 KiB
Python
"""
|
|
REMBG 배경 제거 모델 구현 (실제 rembg 라이브러리 사용)
|
|
"""
|
|
import os
|
|
import cv2
|
|
from PIL import Image
|
|
import logging
|
|
import numpy as np
|
|
import onnxruntime # ONNX 런타임 직접 사용을 위해 임포트
|
|
from typing import Union, Tuple, Optional
|
|
import asyncio
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RembgProcessor:
|
|
"""
|
|
rembg 기반 배경제거 모듈 (안전한 의존성 처리)
|
|
"""
|
|
|
|
# 사용하시려는 birefnet 모델을 지원 목록에 추가합니다.
|
|
SUPPORTED_MODELS = {
|
|
"u2net": "범용 배경제거 | 빠름 | 사람/사물 모두 양호 (기본값)",
|
|
"u2netp": "u2net 경량화 | 매우 빠름 | 실시간, 저사양PC",
|
|
"u2net_human_seg": "인물 전용 | 빠름 | 사람 경계 정밀",
|
|
"u2net_cloth_seg": "옷 전용 | 빠름 | 패션/의류 특화",
|
|
"isnet-general-use": "범용 고품질 | 느림 | 디테일 중시, 대용량",
|
|
"sam": "SAM 최고 품질 | 매우 느림 | 고성능PC 권장",
|
|
"sam-mobile": "SAM 경량화 | 보통 | 모바일, 중간성능",
|
|
"birefnet-general-lite": "BiRefNet 경량 모델 | 고품질 저용량 (로컬)"
|
|
}
|
|
|
|
# SUPPORTED_MODELS 키와 실제 rembg sessions 키 간의 매핑
|
|
MODEL_NAME_MAPPING = {
|
|
"u2net": "u2net",
|
|
"u2netp": "u2netp",
|
|
"u2net_human_seg": "u2net-human-seg",
|
|
"u2net_cloth_seg": "u2net-cloth-seg",
|
|
"isnet-general-use": "dis-general-use",
|
|
"sam": "sam",
|
|
"sam-mobile": "sam", # sam-mobile은 sam과 동일하게 처리
|
|
"birefnet-general-lite": "birefnet-general-lite"
|
|
}
|
|
|
|
def __init__(self, model_name: str = "u2net", device: str = "cuda", fp16: bool = True,
|
|
local_rembg_model_path: str = None):
|
|
self.model_name = model_name
|
|
self.device = device
|
|
self.fp16 = fp16
|
|
self.local_rembg_model_path = local_rembg_model_path
|
|
self.sessions = {}
|
|
self.loaded = False
|
|
self._rembg_available = None
|
|
self._init_error = None
|
|
self._cuda_providers_tested = False
|
|
|
|
def _check_rembg_availability(self):
|
|
"""rembg 모듈 사용 가능 여부를 확인하고 캐시.
|
|
세션을 생성하지 않아 모델 다운로드를 유발하지 않도록 함."""
|
|
if self._rembg_available is not None:
|
|
return self._rembg_available
|
|
|
|
try:
|
|
import rembg # noqa: F401
|
|
self._rembg_available = True
|
|
logger.info("rembg 모듈 임포트 성공 (세션 생성은 지연 로딩)")
|
|
return True
|
|
except ImportError as e:
|
|
self._init_error = f"rembg 모듈이 설치되지 않음: {e}"
|
|
self._rembg_available = False
|
|
except Exception as e:
|
|
self._init_error = f"rembg 모듈 초기화 실패 (의존성/하드웨어 문제): {e}"
|
|
self._rembg_available = False
|
|
|
|
logger.error(self._init_error)
|
|
return False
|
|
|
|
def get_session(self, model_name, timeout_seconds: int = 90):
|
|
"""
|
|
모델별 세션을 캐싱하여 반환 (로컬 모델 경로 및 CUDA 지원 포함)
|
|
"""
|
|
if not self._check_rembg_availability():
|
|
logger.error(f"rembg 사용 불가로 세션 생성 실패: {self._init_error}")
|
|
return None
|
|
|
|
# device 설정에 따라 CUDA 사용 여부 결정 (간소화)
|
|
cuda_enabled = self.device == "cuda"
|
|
# 실제 모델명을 세션 키에 사용
|
|
actual_model_name = self.MODEL_NAME_MAPPING.get(model_name, model_name)
|
|
session_key = f"{actual_model_name}_cuda_{cuda_enabled}"
|
|
|
|
if session_key not in self.sessions:
|
|
logger.info(f"🔧 rembg 새 세션 생성 필요: {session_key}")
|
|
try:
|
|
import rembg
|
|
try:
|
|
from rembg.sessions import sessions
|
|
except ImportError:
|
|
# rembg 버전에 따라 import 경로가 다를 수 있음
|
|
sessions = None
|
|
logger.warning("rembg.sessions import 실패, 기본 방식 사용")
|
|
|
|
# Jetson 환경에서 TensorRT 충돌을 피하기 위해 프로바이더 명시
|
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
|
logger.info(f"rembg 세션 생성 providers: {providers}")
|
|
|
|
session = rembg.new_session(
|
|
model_name=actual_model_name,
|
|
providers=providers
|
|
)
|
|
|
|
self.sessions[session_key] = session
|
|
|
|
# 실제 사용된 provider 확인 및 로깅 (가드 처리)
|
|
actual_providers = []
|
|
try:
|
|
inner = getattr(session, 'inner_session', None)
|
|
if inner and hasattr(inner, 'get_providers'):
|
|
actual_providers = inner.get_providers() or []
|
|
except Exception as prov_err:
|
|
logger.debug(f"rembg provider 확인 실패: {prov_err}")
|
|
|
|
is_gpu = any(('CUDA' in p) or ('Tensorrt' in p) for p in actual_providers)
|
|
status = "GPU 가속" if is_gpu else "CPU 모드"
|
|
logger.info(
|
|
f"✅ rembg '{actual_model_name}' {status}로 동작 (providers: {actual_providers or '알 수 없음'})"
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"rembg 세션 생성 실패 ('{actual_model_name}'): {e}", exc_info=True)
|
|
return None
|
|
else:
|
|
logger.debug(f"♻️ rembg 기존 세션 재사용: {session_key}")
|
|
|
|
return self.sessions.get(session_key)
|
|
|
|
async def load_model(self):
|
|
"""모델을 비동기적으로 로드합니다."""
|
|
if self.loaded:
|
|
return
|
|
|
|
try:
|
|
logger.info(f"Loading REMBG model ({self.model_name})...")
|
|
|
|
# rembg 사용 가능성 확인
|
|
if not self._check_rembg_availability():
|
|
raise RuntimeError(f"REMBG 사용 불가: {self._init_error}")
|
|
|
|
# 세션 생성
|
|
session = self.get_session(self.model_name)
|
|
if session is None:
|
|
raise RuntimeError(f"REMBG 세션 생성 실패: {self.model_name}")
|
|
|
|
self.loaded = True
|
|
logger.info(f"REMBG model ({self.model_name}) loaded successfully")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load REMBG model: {e}")
|
|
raise
|
|
|
|
def to_white_background(self, img: Image.Image) -> Image.Image:
|
|
"""RGBA 이미지를 흰 배경으로 변환"""
|
|
if img.mode in ("RGBA", "BGRA"):
|
|
bg = Image.new("RGB", img.size, (255, 255, 255))
|
|
bg.paste(img, mask=img.split()[-1])
|
|
return bg
|
|
else:
|
|
return img.convert("RGB")
|
|
|
|
async def remove_background(self, image: Union[str, Image.Image, np.ndarray],
|
|
model_name: str = None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
|
|
"""
|
|
배경을 제거하고 결과 이미지와 마스크를 반환합니다.
|
|
|
|
Args:
|
|
image: 입력 이미지 (파일 경로, PIL Image, 또는 numpy array)
|
|
model_name: 사용할 모델명 (없으면 기본 모델 사용)
|
|
**kwargs: 추가 옵션 (alpha_matting 등)
|
|
|
|
Returns:
|
|
(result_rgb, mask): 결과 이미지(RGB)와 마스크
|
|
"""
|
|
if not self.loaded:
|
|
await self.load_model()
|
|
|
|
try:
|
|
# 이미지 로드 및 변환
|
|
if isinstance(image, str):
|
|
if not os.path.exists(image):
|
|
logger.error(f"입력 이미지가 존재하지 않습니다: {image}")
|
|
return None, None
|
|
img = cv2.imread(image)
|
|
if img is None:
|
|
logger.error(f"이미지 로드 실패: {image}")
|
|
return None, None
|
|
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
elif isinstance(image, Image.Image):
|
|
img_rgb = np.array(image.convert('RGB'))
|
|
elif isinstance(image, np.ndarray):
|
|
if len(image.shape) == 3 and image.shape[2] == 3:
|
|
# BGR to RGB
|
|
img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
else:
|
|
img_rgb = image
|
|
else:
|
|
logger.error(f"지원하지 않는 이미지 타입: {type(image)}")
|
|
return None, None
|
|
|
|
# 사용할 모델명 결정
|
|
effective_model_name = model_name or self.model_name
|
|
|
|
if effective_model_name not in self.SUPPORTED_MODELS:
|
|
logger.warning(f"지원하지 않는 모델명: {effective_model_name}. u2net으로 대체 사용")
|
|
effective_model_name = "u2net"
|
|
|
|
session = self.get_session(effective_model_name)
|
|
if session is None:
|
|
return None, None
|
|
|
|
import rembg
|
|
import time
|
|
|
|
start_time = time.time()
|
|
result = rembg.remove(img_rgb, session=session, alpha_matting=kwargs.get("alpha_matting", False))
|
|
end_time = time.time()
|
|
|
|
if not isinstance(result, Image.Image):
|
|
result = Image.fromarray(result)
|
|
|
|
# RGBA 이미지에서 RGB와 마스크 분리
|
|
result_rgba = np.array(result)
|
|
if result_rgba.shape[2] == 4:
|
|
result_rgb = result_rgba[:, :, :3]
|
|
mask = result_rgba[:, :, 3]
|
|
else:
|
|
result_rgb = result_rgba
|
|
# 간단한 마스크 생성 (배경이 검은색이라고 가정)
|
|
gray = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2GRAY)
|
|
mask = (gray > 10).astype(np.uint8) * 255
|
|
|
|
processing_time = end_time - start_time
|
|
# provider 기반 상태 로깅 (세션에서 확인 시도)
|
|
try:
|
|
sess = self.sessions.get(f"{self.MODEL_NAME_MAPPING.get(effective_model_name, effective_model_name)}_cuda_{self.device == 'cuda'}")
|
|
providers = []
|
|
if sess and getattr(sess, 'inner_session', None) and hasattr(sess.inner_session, 'get_providers'):
|
|
providers = sess.inner_session.get_providers() or []
|
|
cuda_status = "CUDA" if any('CUDA' in p or 'Tensorrt' in p for p in providers) else "CPU"
|
|
except Exception:
|
|
cuda_status = "알 수 없음"
|
|
logger.info(f"✅ 배경 제거 성공: {effective_model_name} ({cuda_status}, {processing_time:.2f}초)")
|
|
|
|
return result_rgb, mask
|
|
|
|
except Exception as e:
|
|
logger.error(f"배경 제거 처리 중 오류 ({model_name}): {e}", exc_info=True)
|
|
return None, None
|
|
|
|
def set_default_model(self, model_name):
|
|
if model_name not in self.SUPPORTED_MODELS:
|
|
raise ValueError(f"지원하지 않는 모델명: {model_name}")
|
|
self.model_name = model_name
|
|
logger.info(f"rembg 기본 모델이 '{model_name}'(으)로 변경됨")
|
|
|
|
def get_default_model(self):
|
|
return self.model_name
|
|
|
|
def get_supported_models(self):
|
|
return self.SUPPORTED_MODELS.copy()
|
|
|
|
def get_model_description(self, model_name):
|
|
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
|
|
|
|
def is_available(self):
|
|
return self._check_rembg_availability()
|
|
|
|
def get_init_error(self):
|
|
return self._init_error
|
|
|
|
def get_model_info(self) -> dict:
|
|
"""모델 정보를 반환합니다."""
|
|
return {
|
|
"model_type": "rembg",
|
|
"model_name": self.model_name,
|
|
"device": self.device,
|
|
"fp16": self.fp16,
|
|
"loaded": self.loaded,
|
|
"available": self.is_available(),
|
|
"supported_models": list(self.SUPPORTED_MODELS.keys()),
|
|
"local_model_path": self.local_rembg_model_path
|
|
} |