370 lines
19 KiB
Python
370 lines
19 KiB
Python
import os
|
|
import cv2
|
|
from PIL import Image
|
|
import logging
|
|
import onnxruntime # [추가] ONNX 런타임 직접 사용을 위해 임포트
|
|
from rembg.sessions import sessions # [수정] rembg의 모델별 세션 클래스 딕셔너리 import
|
|
|
|
|
|
class BackgroundRemovalModule:
|
|
"""
|
|
rembg 기반 배경제거 모듈 (안전한 의존성 처리)
|
|
"""
|
|
|
|
# [수정] 사용하시려는 birefnet 모델을 지원 목록에 추가합니다.
|
|
SUPPORTED_MODELS = {
|
|
"u2net": "범용 배경제거 | 빠름 | 사람/사물 모두 양호 (기본값)",
|
|
"u2netp": "u2net 경량화 | 매우 빠름 | 실시간, 저사양PC",
|
|
"u2net_human_seg": "인물 전용 | 빠름 | 사람 경계 정밀",
|
|
"u2net_cloth_seg": "옷 전용 | 빠름 | 패션/의류 특화",
|
|
"isnet-general-use": "범용 고품질 | 느림 | 디테일 중시, 대용량",
|
|
"sam": "SAM 최고 품질 | 매우 느림 | 고성능PC 권장",
|
|
"sam-mobile": "SAM 경량화 | 보통 | 모바일, 중간성능",
|
|
"birefnet-general-lite": "BiRefNet 경량 모델 | 고품질 저용량 (로컬)"
|
|
}
|
|
|
|
# [추가] SUPPORTED_MODELS 키와 실제 rembg sessions 키 간의 매핑
|
|
MODEL_NAME_MAPPING = {
|
|
"u2net": "u2net",
|
|
"u2netp": "u2netp",
|
|
"u2net_human_seg": "u2net-human-seg",
|
|
"u2net_cloth_seg": "u2net-cloth-seg",
|
|
"isnet-general-use": "dis-general-use",
|
|
"sam": "sam",
|
|
"sam-mobile": "sam", # sam-mobile은 sam과 동일하게 처리
|
|
"birefnet-general-lite": "birefnet-general-lite"
|
|
}
|
|
|
|
def __init__(self, logger=None, default_model="u2net", gpu_manager=None, local_rembg_model_path: str | None = None):
|
|
self.logger = logger
|
|
self.default_model = default_model
|
|
self.gpu_manager = gpu_manager
|
|
self.local_rembg_model_path = local_rembg_model_path
|
|
self.sessions = {}
|
|
self._rembg_available = None
|
|
self._init_error = None
|
|
self._cuda_providers_tested = False
|
|
|
|
# _check_rembg_availability 메서드는 변경할 필요 없습니다. (기존 코드 유지)
|
|
def _check_rembg_availability(self):
|
|
"""rembg 모듈 사용 가능 여부를 확인하고 캐시"""
|
|
if self._rembg_available is not None:
|
|
return self._rembg_available
|
|
|
|
try:
|
|
import rembg
|
|
|
|
providers_to_test = []
|
|
if self.gpu_manager:
|
|
try:
|
|
import onnxruntime as ort
|
|
available_providers = ort.get_available_providers()
|
|
if 'DmlExecutionProvider' in available_providers:
|
|
providers_to_test = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log("rembg DirectML provider 테스트 시작", level=logging.INFO)
|
|
else:
|
|
providers_to_test = ['CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log("rembg DirectML 미지원, CPU 모드", level=logging.WARNING)
|
|
except Exception as e:
|
|
providers_to_test = ['CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log(f"rembg DirectML 확인 실패: {e}, CPU 모드", level=logging.WARNING)
|
|
else:
|
|
providers_to_test = ['CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log("rembg CPU-only 모드로 테스트", level=logging.INFO)
|
|
|
|
model_arg = 'u2net'
|
|
|
|
test_session = rembg.new_session(model_name=model_arg, providers=providers_to_test)
|
|
self._rembg_available = True
|
|
|
|
if hasattr(test_session, 'inner_session') and hasattr(test_session.inner_session, 'get_providers'):
|
|
actual_providers = test_session.inner_session.get_providers()
|
|
if self.logger:
|
|
self.logger.log(f"rembg 세션 생성 성공, 사용된 providers: {actual_providers}", level=logging.INFO)
|
|
if 'DmlExecutionProvider' in actual_providers:
|
|
self.logger.log("✅ rembg DirectML 가속 활성화됨", level=logging.INFO)
|
|
else:
|
|
self.logger.log("rembg CPU 모드로 동작", level=logging.INFO)
|
|
else:
|
|
if self.logger:
|
|
self.logger.log("rembg 모듈 사용 가능 확인됨", level=logging.INFO)
|
|
|
|
return True
|
|
|
|
except ImportError as e:
|
|
self._init_error = f"rembg 모듈이 설치되지 않음: {e}"
|
|
self._rembg_available = False
|
|
except Exception as e:
|
|
self._init_error = f"rembg 모듈 초기화 실패 (의존성/하드웨어 문제): {e}"
|
|
self._rembg_available = False
|
|
|
|
if self.logger:
|
|
self.logger.log(self._init_error, level=logging.ERROR)
|
|
return False
|
|
|
|
# ===================================================================================
|
|
# [핵심 수정] get_session 메서드를 아래 코드로 완전히 교체하세요.
|
|
# ===================================================================================
|
|
def get_session(self, model_name):
|
|
"""
|
|
모델별 세션을 캐싱하여 반환 (로컬 모델 경로 및 CUDA 지원 포함)
|
|
"""
|
|
if not self._check_rembg_availability():
|
|
if self.logger:
|
|
self.logger.log(f"rembg 사용 불가로 세션 생성 실패: {self._init_error}", level=logging.ERROR)
|
|
return None
|
|
|
|
cuda_enabled = self.gpu_manager and self.gpu_manager.can_use_cuda
|
|
# 세션 키에 로컬 모델 사용 여부도 반영
|
|
is_local = bool(self.local_rembg_model_path and os.path.exists(self.local_rembg_model_path))
|
|
# 실제 모델명을 세션 키에 사용
|
|
actual_model_name = self.MODEL_NAME_MAPPING.get(model_name, model_name)
|
|
session_key = f"{actual_model_name}_cuda_{cuda_enabled}_local_{is_local}"
|
|
|
|
if session_key not in self.sessions:
|
|
if self.logger:
|
|
self.logger.log(f"🔧 rembg 새 세션 생성 필요: {session_key}", level=logging.INFO)
|
|
try:
|
|
import rembg
|
|
|
|
# 안정적인 DirectML 기반 provider 설정
|
|
providers = []
|
|
if cuda_enabled: # cuda_enabled을 GPU 가속 플래그로 사용
|
|
try:
|
|
import onnxruntime as ort
|
|
available_providers = ort.get_available_providers()
|
|
if 'DmlExecutionProvider' in available_providers:
|
|
# DirectML 설정 (메모리 절약형)
|
|
dml_options = {
|
|
'device_id': 0, # 기본 GPU 장치 사용
|
|
'disable_metacommands': False, # 메타커맨드 활성화
|
|
'enable_dynamic_graph_fusion': False, # 🔧 메모리 절약을 위해 비활성화
|
|
'memory_type': 'default' # 기본 메모리 타입 사용
|
|
}
|
|
providers = [
|
|
('DmlExecutionProvider', dml_options),
|
|
('CPUExecutionProvider', {})
|
|
]
|
|
if self.logger: self.logger.log(f"rembg 세션 생성 (DirectML 최적화): {model_name}", level=logging.INFO)
|
|
else:
|
|
providers = [('CPUExecutionProvider', {})]
|
|
if self.logger: self.logger.log(f"rembg DirectML 미지원, CPU 모드: {model_name}", level=logging.WARNING)
|
|
except Exception as e:
|
|
providers = [('CPUExecutionProvider', {})]
|
|
if self.logger: self.logger.log(f"rembg DirectML 확인 실패: {e}, CPU 모드: {model_name}", level=logging.WARNING)
|
|
else:
|
|
providers = [('CPUExecutionProvider', {})]
|
|
if self.logger: self.logger.log(f"rembg 세션 생성 (CPU): {model_name}", level=logging.INFO)
|
|
|
|
session = None
|
|
|
|
# *** 로컬 모델 경로가 있으면 여기서 처리 ***
|
|
if is_local:
|
|
if self.logger:
|
|
self.logger.log(f"로컬 모델로 세션 생성: {self.local_rembg_model_path}", level=logging.INFO)
|
|
|
|
# 1. model_name에 맞는 rembg 세션 *클래스*를 가져옴 (e.g., BiRefNetSessionGeneral)
|
|
# 모델명 매핑을 통해 실제 sessions 키로 변환
|
|
actual_model_name = self.MODEL_NAME_MAPPING.get(model_name, model_name)
|
|
if actual_model_name not in sessions:
|
|
raise ValueError(f"지원하지 않는 모델명: {model_name} (매핑된 이름: {actual_model_name}). 사용 가능한 모델: {list(sessions.keys())}")
|
|
session_class = sessions[actual_model_name]
|
|
|
|
# 2. 클래스의 인스턴스를 생성하되, __init__을 호출하지 않아 모델 다운로드를 방지
|
|
session = session_class.__new__(session_class)
|
|
|
|
# 3. 로컬 ONNX 파일로 직접 onnxruntime 세션을 생성 (DirectML)
|
|
ort_session = None
|
|
if cuda_enabled:
|
|
fallback_attempts = [
|
|
(providers, "DirectML 가속"),
|
|
(["CPUExecutionProvider"], "CPU 폴백")
|
|
]
|
|
else:
|
|
fallback_attempts = [
|
|
(["CPUExecutionProvider"], "CPU 전용")
|
|
]
|
|
|
|
for attempt_providers, attempt_name in fallback_attempts:
|
|
if not attempt_providers: # 빈 리스트 건너뛰기
|
|
continue
|
|
try:
|
|
if self.logger:
|
|
self.logger.log(f"rembg 로컬 모델 {attempt_name} 시도: {attempt_providers}", level=logging.DEBUG)
|
|
|
|
# ONNX Runtime 세션 옵션 설정 (메모리 절약형)
|
|
sess_options = onnxruntime.SessionOptions()
|
|
sess_options.enable_mem_pattern = False # 🔧 메모리 절약을 위해 비활성화
|
|
sess_options.enable_cpu_mem_arena = True # CPU 메모리 아레나 (유지)
|
|
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC # 🔧 기본 최적화만 사용
|
|
|
|
# DirectML 사용시 메모리 절약 설정
|
|
if any('DmlExecutionProvider' in str(provider) for provider in attempt_providers):
|
|
sess_options.enable_profiling = False # 프로파일링 비활성화 (안정성)
|
|
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL # 순차 실행
|
|
|
|
ort_session = onnxruntime.InferenceSession(
|
|
self.local_rembg_model_path,
|
|
providers=attempt_providers,
|
|
sess_options=sess_options
|
|
)
|
|
actual_providers = ort_session.get_providers()
|
|
if self.logger:
|
|
self.logger.log(f"✅ rembg 로컬 모델 {attempt_name} 성공! 실제 providers: {actual_providers} ", level=logging.INFO)
|
|
break
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"rembg 로컬 모델 {attempt_name} 실패: {e}", level=logging.WARNING)
|
|
continue
|
|
|
|
if ort_session is None:
|
|
raise RuntimeError("rembg 로컬 모델 모든 폴백 시도 실패")
|
|
|
|
session.inner_session = ort_session
|
|
|
|
# 4. 세션에 필요한 다른 속성들을 수동으로 설정
|
|
session.model_name = model_name
|
|
session.providers = providers
|
|
|
|
# 로컬 모델 경로가 없으면 기존 방식으로 처리
|
|
else:
|
|
if self.logger: self.logger.log(f"내장 모델로 세션 생성: {actual_model_name}", level=logging.INFO)
|
|
session = rembg.new_session(model_name=actual_model_name, providers=providers)
|
|
|
|
self.sessions[session_key] = session
|
|
|
|
# 실제 사용된 provider 확인 및 로깅 (DirectML 지원 수정)
|
|
actual_providers = session.inner_session.get_providers()
|
|
if self.logger:
|
|
is_gpu = any('CUDA' in p or 'Tensorrt' in p or 'DmlExecutionProvider' in p for p in actual_providers)
|
|
status = "GPU 가속 활성화" if is_gpu else "CPU 모드로 동작"
|
|
self.logger.log(f"✅ rembg '{actual_model_name}' {status} (실제 providers: {actual_providers})", level=logging.INFO)
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"rembg 세션 생성 실패 ('{actual_model_name}'): {e}", level=logging.ERROR, exc_info=True)
|
|
return None
|
|
else:
|
|
if self.logger:
|
|
self.logger.log(f"♻️ rembg 기존 세션 재사용: {session_key}", level=logging.DEBUG)
|
|
|
|
return self.sessions.get(session_key)
|
|
|
|
# ===================================================================================
|
|
# 이하 다른 메서드들은 수정할 필요 없습니다. (기존 코드 유지)
|
|
# ===================================================================================
|
|
|
|
def set_default_model(self, model_name):
|
|
if model_name not in self.SUPPORTED_MODELS:
|
|
raise ValueError(f"지원하지 않는 모델명: {model_name}")
|
|
self.default_model = model_name
|
|
if self.logger:
|
|
self.logger.log(f"rembg 기본 모델이 '{model_name}'(으)로 변경됨", level=logging.INFO)
|
|
|
|
def get_default_model(self):
|
|
return self.default_model
|
|
|
|
def get_supported_models(self):
|
|
return self.SUPPORTED_MODELS.copy()
|
|
|
|
def get_model_description(self, model_name):
|
|
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
|
|
|
|
def is_available(self):
|
|
return self._check_rembg_availability()
|
|
|
|
def get_init_error(self):
|
|
return self._init_error
|
|
|
|
def to_white_background(self, img: Image.Image) -> Image.Image:
|
|
if img.mode in ("RGBA", "BGRA"):
|
|
bg = Image.new("RGB", img.size, (255, 255, 255))
|
|
bg.paste(img, mask=img.split()[-1])
|
|
return bg
|
|
else:
|
|
return img.convert("RGB")
|
|
|
|
def remove_background(self, image_path, model_name=None, **kwargs):
|
|
if not self._check_rembg_availability():
|
|
if self.logger:
|
|
self.logger.log(f"rembg 사용 불가로 배경 제거 실패: {self._init_error}", level=logging.ERROR)
|
|
return None
|
|
|
|
if not os.path.exists(image_path):
|
|
if self.logger:
|
|
self.logger.log(f"입력 이미지가 존재하지 않습니다: {image_path}", level=logging.ERROR)
|
|
return None
|
|
|
|
try:
|
|
img = cv2.imread(image_path)
|
|
if img is None:
|
|
if self.logger:
|
|
self.logger.log(f"이미지 로드 실패: {image_path}", level=logging.ERROR)
|
|
return None
|
|
|
|
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
|
# [수정] 로컬 모델을 사용하려면, model_name을 파일과 맞는 이름으로 지정해야 합니다.
|
|
# 예를 들어, birefnet-general-lite.onnx 파일을 사용하려면 model_name='birefnet-general-lite'로 호출해야 합니다.
|
|
effective_model_name = model_name or self.default_model
|
|
|
|
if effective_model_name not in self.SUPPORTED_MODELS:
|
|
if self.logger:
|
|
self.logger.log(f"지원하지 않는 모델명: {effective_model_name}. u2net으로 대체 사용", level=logging.WARNING)
|
|
effective_model_name = "u2net"
|
|
|
|
session = self.get_session(effective_model_name)
|
|
if session is None:
|
|
return None
|
|
|
|
import rembg
|
|
import time
|
|
|
|
start_time = time.time()
|
|
result = rembg.remove(img_rgb, session=session, alpha_matting=kwargs.get("alpha_matting", False))
|
|
end_time = time.time()
|
|
|
|
if not isinstance(result, Image.Image):
|
|
result = Image.fromarray(result)
|
|
|
|
processing_time = end_time - start_time
|
|
if self.logger:
|
|
cuda_status = "CUDA" if (self.gpu_manager and self.gpu_manager.can_use_cuda) else "CPU"
|
|
self.logger.log(f"✅ 배경 제거 성공: {effective_model_name} ({cuda_status}, {processing_time:.2f}초)", level=logging.INFO)
|
|
|
|
if self.gpu_manager and self.gpu_manager.can_use_cuda:
|
|
self.gpu_manager.log_gpu_memory_usage()
|
|
|
|
return result
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"배경 제거 처리 중 오류 ({model_name}): {e}", level=logging.ERROR, exc_info=True)
|
|
return None
|
|
|
|
def _preload_sessions(self):
|
|
"""자주 사용되는 rembg 세션들을 미리 로딩하여 첫 번째 요청 시간을 단축"""
|
|
preload_models = ["u2net", "birefnet-general-lite"]
|
|
|
|
for model_name in preload_models:
|
|
try:
|
|
if self.logger:
|
|
self.logger.log(f"🔄 {model_name} 세션 미리 로딩 중...", level=logging.INFO)
|
|
|
|
# 세션 생성하여 캐시에 저장
|
|
session = self.get_session(model_name)
|
|
if session:
|
|
if self.logger:
|
|
self.logger.log(f"✅ {model_name} 세션 미리 로딩 완료", level=logging.INFO)
|
|
else:
|
|
if self.logger:
|
|
self.logger.log(f"⚠️ {model_name} 세션 로딩 실패", level=logging.WARNING)
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"❌ {model_name} 세션 미리 로딩 중 오류: {e}", level=logging.WARNING)
|