import os import cv2 from PIL import Image import logging import onnxruntime # [추가] ONNX 런타임 직접 사용을 위해 임포트 from rembg.sessions import sessions # [수정] rembg의 모델별 세션 클래스 딕셔너리 import class BackgroundRemovalModule: """ rembg 기반 배경제거 모듈 (안전한 의존성 처리) """ # [수정] 사용하시려는 birefnet 모델을 지원 목록에 추가합니다. SUPPORTED_MODELS = { "u2net": "범용 배경제거 | 빠름 | 사람/사물 모두 양호 (기본값)", "u2netp": "u2net 경량화 | 매우 빠름 | 실시간, 저사양PC", "u2net_human_seg": "인물 전용 | 빠름 | 사람 경계 정밀", "u2net_cloth_seg": "옷 전용 | 빠름 | 패션/의류 특화", "isnet-general-use": "범용 고품질 | 느림 | 디테일 중시, 대용량", "sam": "SAM 최고 품질 | 매우 느림 | 고성능PC 권장", "sam-mobile": "SAM 경량화 | 보통 | 모바일, 중간성능", "birefnet-general-lite": "BiRefNet 경량 모델 | 고품질 저용량 (로컬)" } # [추가] SUPPORTED_MODELS 키와 실제 rembg sessions 키 간의 매핑 MODEL_NAME_MAPPING = { "u2net": "u2net", "u2netp": "u2netp", "u2net_human_seg": "u2net-human-seg", "u2net_cloth_seg": "u2net-cloth-seg", "isnet-general-use": "dis-general-use", "sam": "sam", "sam-mobile": "sam", # sam-mobile은 sam과 동일하게 처리 "birefnet-general-lite": "birefnet-general-lite" } def __init__(self, logger=None, default_model="u2net", gpu_manager=None, local_rembg_model_path: str | None = None): self.logger = logger self.default_model = default_model self.gpu_manager = gpu_manager self.local_rembg_model_path = local_rembg_model_path self.sessions = {} self._rembg_available = None self._init_error = None self._cuda_providers_tested = False # _check_rembg_availability 메서드는 변경할 필요 없습니다. (기존 코드 유지) def _check_rembg_availability(self): """rembg 모듈 사용 가능 여부를 확인하고 캐시""" if self._rembg_available is not None: return self._rembg_available try: import rembg providers_to_test = [] if self.gpu_manager: try: import onnxruntime as ort available_providers = ort.get_available_providers() if 'DmlExecutionProvider' in available_providers: providers_to_test = ['DmlExecutionProvider', 'CPUExecutionProvider'] if self.logger: self.logger.log("rembg DirectML provider 테스트 시작", level=logging.INFO) else: providers_to_test = ['CPUExecutionProvider'] if self.logger: self.logger.log("rembg DirectML 미지원, CPU 모드", level=logging.WARNING) except Exception as e: providers_to_test = ['CPUExecutionProvider'] if self.logger: self.logger.log(f"rembg DirectML 확인 실패: {e}, CPU 모드", level=logging.WARNING) else: providers_to_test = ['CPUExecutionProvider'] if self.logger: self.logger.log("rembg CPU-only 모드로 테스트", level=logging.INFO) model_arg = 'u2net' test_session = rembg.new_session(model_name=model_arg, providers=providers_to_test) self._rembg_available = True if hasattr(test_session, 'inner_session') and hasattr(test_session.inner_session, 'get_providers'): actual_providers = test_session.inner_session.get_providers() if self.logger: self.logger.log(f"rembg 세션 생성 성공, 사용된 providers: {actual_providers}", level=logging.INFO) if 'DmlExecutionProvider' in actual_providers: self.logger.log("✅ rembg DirectML 가속 활성화됨", level=logging.INFO) else: self.logger.log("rembg CPU 모드로 동작", level=logging.INFO) else: if self.logger: self.logger.log("rembg 모듈 사용 가능 확인됨", level=logging.INFO) return True except ImportError as e: self._init_error = f"rembg 모듈이 설치되지 않음: {e}" self._rembg_available = False except Exception as e: self._init_error = f"rembg 모듈 초기화 실패 (의존성/하드웨어 문제): {e}" self._rembg_available = False if self.logger: self.logger.log(self._init_error, level=logging.ERROR) return False # =================================================================================== # [핵심 수정] get_session 메서드를 아래 코드로 완전히 교체하세요. # =================================================================================== def get_session(self, model_name): """ 모델별 세션을 캐싱하여 반환 (로컬 모델 경로 및 CUDA 지원 포함) """ if not self._check_rembg_availability(): if self.logger: self.logger.log(f"rembg 사용 불가로 세션 생성 실패: {self._init_error}", level=logging.ERROR) return None cuda_enabled = self.gpu_manager and self.gpu_manager.can_use_cuda # 세션 키에 로컬 모델 사용 여부도 반영 is_local = bool(self.local_rembg_model_path and os.path.exists(self.local_rembg_model_path)) # 실제 모델명을 세션 키에 사용 actual_model_name = self.MODEL_NAME_MAPPING.get(model_name, model_name) session_key = f"{actual_model_name}_cuda_{cuda_enabled}_local_{is_local}" if session_key not in self.sessions: if self.logger: self.logger.log(f"🔧 rembg 새 세션 생성 필요: {session_key}", level=logging.INFO) try: import rembg # 안정적인 DirectML 기반 provider 설정 providers = [] if cuda_enabled: # cuda_enabled을 GPU 가속 플래그로 사용 try: import onnxruntime as ort available_providers = ort.get_available_providers() if 'DmlExecutionProvider' in available_providers: # DirectML 설정 (메모리 절약형) dml_options = { 'device_id': 0, # 기본 GPU 장치 사용 'disable_metacommands': False, # 메타커맨드 활성화 'enable_dynamic_graph_fusion': False, # 🔧 메모리 절약을 위해 비활성화 'memory_type': 'default' # 기본 메모리 타입 사용 } providers = [ ('DmlExecutionProvider', dml_options), ('CPUExecutionProvider', {}) ] if self.logger: self.logger.log(f"rembg 세션 생성 (DirectML 최적화): {model_name}", level=logging.INFO) else: providers = [('CPUExecutionProvider', {})] if self.logger: self.logger.log(f"rembg DirectML 미지원, CPU 모드: {model_name}", level=logging.WARNING) except Exception as e: providers = [('CPUExecutionProvider', {})] if self.logger: self.logger.log(f"rembg DirectML 확인 실패: {e}, CPU 모드: {model_name}", level=logging.WARNING) else: providers = [('CPUExecutionProvider', {})] if self.logger: self.logger.log(f"rembg 세션 생성 (CPU): {model_name}", level=logging.INFO) session = None # *** 로컬 모델 경로가 있으면 여기서 처리 *** if is_local: if self.logger: self.logger.log(f"로컬 모델로 세션 생성: {self.local_rembg_model_path}", level=logging.INFO) # 1. model_name에 맞는 rembg 세션 *클래스*를 가져옴 (e.g., BiRefNetSessionGeneral) # 모델명 매핑을 통해 실제 sessions 키로 변환 actual_model_name = self.MODEL_NAME_MAPPING.get(model_name, model_name) if actual_model_name not in sessions: raise ValueError(f"지원하지 않는 모델명: {model_name} (매핑된 이름: {actual_model_name}). 사용 가능한 모델: {list(sessions.keys())}") session_class = sessions[actual_model_name] # 2. 클래스의 인스턴스를 생성하되, __init__을 호출하지 않아 모델 다운로드를 방지 session = session_class.__new__(session_class) # 3. 로컬 ONNX 파일로 직접 onnxruntime 세션을 생성 (DirectML) ort_session = None if cuda_enabled: fallback_attempts = [ (providers, "DirectML 가속"), (["CPUExecutionProvider"], "CPU 폴백") ] else: fallback_attempts = [ (["CPUExecutionProvider"], "CPU 전용") ] for attempt_providers, attempt_name in fallback_attempts: if not attempt_providers: # 빈 리스트 건너뛰기 continue try: if self.logger: self.logger.log(f"rembg 로컬 모델 {attempt_name} 시도: {attempt_providers}", level=logging.DEBUG) # ONNX Runtime 세션 옵션 설정 (메모리 절약형) sess_options = onnxruntime.SessionOptions() sess_options.enable_mem_pattern = False # 🔧 메모리 절약을 위해 비활성화 sess_options.enable_cpu_mem_arena = True # CPU 메모리 아레나 (유지) sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC # 🔧 기본 최적화만 사용 # DirectML 사용시 메모리 절약 설정 if any('DmlExecutionProvider' in str(provider) for provider in attempt_providers): sess_options.enable_profiling = False # 프로파일링 비활성화 (안정성) sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL # 순차 실행 ort_session = onnxruntime.InferenceSession( self.local_rembg_model_path, providers=attempt_providers, sess_options=sess_options ) actual_providers = ort_session.get_providers() if self.logger: self.logger.log(f"✅ rembg 로컬 모델 {attempt_name} 성공! 실제 providers: {actual_providers} ", level=logging.INFO) break except Exception as e: if self.logger: self.logger.log(f"rembg 로컬 모델 {attempt_name} 실패: {e}", level=logging.WARNING) continue if ort_session is None: raise RuntimeError("rembg 로컬 모델 모든 폴백 시도 실패") session.inner_session = ort_session # 4. 세션에 필요한 다른 속성들을 수동으로 설정 session.model_name = model_name session.providers = providers # 로컬 모델 경로가 없으면 기존 방식으로 처리 else: if self.logger: self.logger.log(f"내장 모델로 세션 생성: {actual_model_name}", level=logging.INFO) session = rembg.new_session(model_name=actual_model_name, providers=providers) self.sessions[session_key] = session # 실제 사용된 provider 확인 및 로깅 (DirectML 지원 수정) actual_providers = session.inner_session.get_providers() if self.logger: is_gpu = any('CUDA' in p or 'Tensorrt' in p or 'DmlExecutionProvider' in p for p in actual_providers) status = "GPU 가속 활성화" if is_gpu else "CPU 모드로 동작" self.logger.log(f"✅ rembg '{actual_model_name}' {status} (실제 providers: {actual_providers})", level=logging.INFO) except Exception as e: if self.logger: self.logger.log(f"rembg 세션 생성 실패 ('{actual_model_name}'): {e}", level=logging.ERROR, exc_info=True) return None else: if self.logger: self.logger.log(f"♻️ rembg 기존 세션 재사용: {session_key}", level=logging.DEBUG) return self.sessions.get(session_key) # =================================================================================== # 이하 다른 메서드들은 수정할 필요 없습니다. (기존 코드 유지) # =================================================================================== def set_default_model(self, model_name): if model_name not in self.SUPPORTED_MODELS: raise ValueError(f"지원하지 않는 모델명: {model_name}") self.default_model = model_name if self.logger: self.logger.log(f"rembg 기본 모델이 '{model_name}'(으)로 변경됨", level=logging.INFO) def get_default_model(self): return self.default_model def get_supported_models(self): return self.SUPPORTED_MODELS.copy() def get_model_description(self, model_name): return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음") def is_available(self): return self._check_rembg_availability() def get_init_error(self): return self._init_error def to_white_background(self, img: Image.Image) -> Image.Image: if img.mode in ("RGBA", "BGRA"): bg = Image.new("RGB", img.size, (255, 255, 255)) bg.paste(img, mask=img.split()[-1]) return bg else: return img.convert("RGB") def remove_background(self, image_path, model_name=None, **kwargs): if not self._check_rembg_availability(): if self.logger: self.logger.log(f"rembg 사용 불가로 배경 제거 실패: {self._init_error}", level=logging.ERROR) return None if not os.path.exists(image_path): if self.logger: self.logger.log(f"입력 이미지가 존재하지 않습니다: {image_path}", level=logging.ERROR) return None try: img = cv2.imread(image_path) if img is None: if self.logger: self.logger.log(f"이미지 로드 실패: {image_path}", level=logging.ERROR) return None img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # [수정] 로컬 모델을 사용하려면, model_name을 파일과 맞는 이름으로 지정해야 합니다. # 예를 들어, birefnet-general-lite.onnx 파일을 사용하려면 model_name='birefnet-general-lite'로 호출해야 합니다. effective_model_name = model_name or self.default_model if effective_model_name not in self.SUPPORTED_MODELS: if self.logger: self.logger.log(f"지원하지 않는 모델명: {effective_model_name}. u2net으로 대체 사용", level=logging.WARNING) effective_model_name = "u2net" session = self.get_session(effective_model_name) if session is None: return None import rembg import time start_time = time.time() result = rembg.remove(img_rgb, session=session, alpha_matting=kwargs.get("alpha_matting", False)) end_time = time.time() if not isinstance(result, Image.Image): result = Image.fromarray(result) processing_time = end_time - start_time if self.logger: cuda_status = "CUDA" if (self.gpu_manager and self.gpu_manager.can_use_cuda) else "CPU" self.logger.log(f"✅ 배경 제거 성공: {effective_model_name} ({cuda_status}, {processing_time:.2f}초)", level=logging.INFO) if self.gpu_manager and self.gpu_manager.can_use_cuda: self.gpu_manager.log_gpu_memory_usage() return result except Exception as e: if self.logger: self.logger.log(f"배경 제거 처리 중 오류 ({model_name}): {e}", level=logging.ERROR, exc_info=True) return None def _preload_sessions(self): """자주 사용되는 rembg 세션들을 미리 로딩하여 첫 번째 요청 시간을 단축""" preload_models = ["u2net", "birefnet-general-lite"] for model_name in preload_models: try: if self.logger: self.logger.log(f"🔄 {model_name} 세션 미리 로딩 중...", level=logging.INFO) # 세션 생성하여 캐시에 저장 session = self.get_session(model_name) if session: if self.logger: self.logger.log(f"✅ {model_name} 세션 미리 로딩 완료", level=logging.INFO) else: if self.logger: self.logger.log(f"⚠️ {model_name} 세션 로딩 실패", level=logging.WARNING) except Exception as e: if self.logger: self.logger.log(f"❌ {model_name} 세션 미리 로딩 중 오류: {e}", level=logging.WARNING)