IMG_Worker/modules/old_modules/background_removal_module.py

370 lines
19 KiB
Python

import os
import cv2
from PIL import Image
import logging
import onnxruntime # [추가] ONNX 런타임 직접 사용을 위해 임포트
from rembg.sessions import sessions # [수정] rembg의 모델별 세션 클래스 딕셔너리 import
class BackgroundRemovalModule:
"""
rembg 기반 배경제거 모듈 (안전한 의존성 처리)
"""
# [수정] 사용하시려는 birefnet 모델을 지원 목록에 추가합니다.
SUPPORTED_MODELS = {
"u2net": "범용 배경제거 | 빠름 | 사람/사물 모두 양호 (기본값)",
"u2netp": "u2net 경량화 | 매우 빠름 | 실시간, 저사양PC",
"u2net_human_seg": "인물 전용 | 빠름 | 사람 경계 정밀",
"u2net_cloth_seg": "옷 전용 | 빠름 | 패션/의류 특화",
"isnet-general-use": "범용 고품질 | 느림 | 디테일 중시, 대용량",
"sam": "SAM 최고 품질 | 매우 느림 | 고성능PC 권장",
"sam-mobile": "SAM 경량화 | 보통 | 모바일, 중간성능",
"birefnet-general-lite": "BiRefNet 경량 모델 | 고품질 저용량 (로컬)"
}
# [추가] SUPPORTED_MODELS 키와 실제 rembg sessions 키 간의 매핑
MODEL_NAME_MAPPING = {
"u2net": "u2net",
"u2netp": "u2netp",
"u2net_human_seg": "u2net-human-seg",
"u2net_cloth_seg": "u2net-cloth-seg",
"isnet-general-use": "dis-general-use",
"sam": "sam",
"sam-mobile": "sam", # sam-mobile은 sam과 동일하게 처리
"birefnet-general-lite": "birefnet-general-lite"
}
def __init__(self, logger=None, default_model="u2net", gpu_manager=None, local_rembg_model_path: str | None = None):
self.logger = logger
self.default_model = default_model
self.gpu_manager = gpu_manager
self.local_rembg_model_path = local_rembg_model_path
self.sessions = {}
self._rembg_available = None
self._init_error = None
self._cuda_providers_tested = False
# _check_rembg_availability 메서드는 변경할 필요 없습니다. (기존 코드 유지)
def _check_rembg_availability(self):
"""rembg 모듈 사용 가능 여부를 확인하고 캐시"""
if self._rembg_available is not None:
return self._rembg_available
try:
import rembg
providers_to_test = []
if self.gpu_manager:
try:
import onnxruntime as ort
available_providers = ort.get_available_providers()
if 'DmlExecutionProvider' in available_providers:
providers_to_test = ['DmlExecutionProvider', 'CPUExecutionProvider']
if self.logger:
self.logger.log("rembg DirectML provider 테스트 시작", level=logging.INFO)
else:
providers_to_test = ['CPUExecutionProvider']
if self.logger:
self.logger.log("rembg DirectML 미지원, CPU 모드", level=logging.WARNING)
except Exception as e:
providers_to_test = ['CPUExecutionProvider']
if self.logger:
self.logger.log(f"rembg DirectML 확인 실패: {e}, CPU 모드", level=logging.WARNING)
else:
providers_to_test = ['CPUExecutionProvider']
if self.logger:
self.logger.log("rembg CPU-only 모드로 테스트", level=logging.INFO)
model_arg = 'u2net'
test_session = rembg.new_session(model_name=model_arg, providers=providers_to_test)
self._rembg_available = True
if hasattr(test_session, 'inner_session') and hasattr(test_session.inner_session, 'get_providers'):
actual_providers = test_session.inner_session.get_providers()
if self.logger:
self.logger.log(f"rembg 세션 생성 성공, 사용된 providers: {actual_providers}", level=logging.INFO)
if 'DmlExecutionProvider' in actual_providers:
self.logger.log("✅ rembg DirectML 가속 활성화됨", level=logging.INFO)
else:
self.logger.log("rembg CPU 모드로 동작", level=logging.INFO)
else:
if self.logger:
self.logger.log("rembg 모듈 사용 가능 확인됨", level=logging.INFO)
return True
except ImportError as e:
self._init_error = f"rembg 모듈이 설치되지 않음: {e}"
self._rembg_available = False
except Exception as e:
self._init_error = f"rembg 모듈 초기화 실패 (의존성/하드웨어 문제): {e}"
self._rembg_available = False
if self.logger:
self.logger.log(self._init_error, level=logging.ERROR)
return False
# ===================================================================================
# [핵심 수정] get_session 메서드를 아래 코드로 완전히 교체하세요.
# ===================================================================================
def get_session(self, model_name):
"""
모델별 세션을 캐싱하여 반환 (로컬 모델 경로 및 CUDA 지원 포함)
"""
if not self._check_rembg_availability():
if self.logger:
self.logger.log(f"rembg 사용 불가로 세션 생성 실패: {self._init_error}", level=logging.ERROR)
return None
cuda_enabled = self.gpu_manager and self.gpu_manager.can_use_cuda
# 세션 키에 로컬 모델 사용 여부도 반영
is_local = bool(self.local_rembg_model_path and os.path.exists(self.local_rembg_model_path))
# 실제 모델명을 세션 키에 사용
actual_model_name = self.MODEL_NAME_MAPPING.get(model_name, model_name)
session_key = f"{actual_model_name}_cuda_{cuda_enabled}_local_{is_local}"
if session_key not in self.sessions:
if self.logger:
self.logger.log(f"🔧 rembg 새 세션 생성 필요: {session_key}", level=logging.INFO)
try:
import rembg
# 안정적인 DirectML 기반 provider 설정
providers = []
if cuda_enabled: # cuda_enabled을 GPU 가속 플래그로 사용
try:
import onnxruntime as ort
available_providers = ort.get_available_providers()
if 'DmlExecutionProvider' in available_providers:
# DirectML 설정 (메모리 절약형)
dml_options = {
'device_id': 0, # 기본 GPU 장치 사용
'disable_metacommands': False, # 메타커맨드 활성화
'enable_dynamic_graph_fusion': False, # 🔧 메모리 절약을 위해 비활성화
'memory_type': 'default' # 기본 메모리 타입 사용
}
providers = [
('DmlExecutionProvider', dml_options),
('CPUExecutionProvider', {})
]
if self.logger: self.logger.log(f"rembg 세션 생성 (DirectML 최적화): {model_name}", level=logging.INFO)
else:
providers = [('CPUExecutionProvider', {})]
if self.logger: self.logger.log(f"rembg DirectML 미지원, CPU 모드: {model_name}", level=logging.WARNING)
except Exception as e:
providers = [('CPUExecutionProvider', {})]
if self.logger: self.logger.log(f"rembg DirectML 확인 실패: {e}, CPU 모드: {model_name}", level=logging.WARNING)
else:
providers = [('CPUExecutionProvider', {})]
if self.logger: self.logger.log(f"rembg 세션 생성 (CPU): {model_name}", level=logging.INFO)
session = None
# *** 로컬 모델 경로가 있으면 여기서 처리 ***
if is_local:
if self.logger:
self.logger.log(f"로컬 모델로 세션 생성: {self.local_rembg_model_path}", level=logging.INFO)
# 1. model_name에 맞는 rembg 세션 *클래스*를 가져옴 (e.g., BiRefNetSessionGeneral)
# 모델명 매핑을 통해 실제 sessions 키로 변환
actual_model_name = self.MODEL_NAME_MAPPING.get(model_name, model_name)
if actual_model_name not in sessions:
raise ValueError(f"지원하지 않는 모델명: {model_name} (매핑된 이름: {actual_model_name}). 사용 가능한 모델: {list(sessions.keys())}")
session_class = sessions[actual_model_name]
# 2. 클래스의 인스턴스를 생성하되, __init__을 호출하지 않아 모델 다운로드를 방지
session = session_class.__new__(session_class)
# 3. 로컬 ONNX 파일로 직접 onnxruntime 세션을 생성 (DirectML)
ort_session = None
if cuda_enabled:
fallback_attempts = [
(providers, "DirectML 가속"),
(["CPUExecutionProvider"], "CPU 폴백")
]
else:
fallback_attempts = [
(["CPUExecutionProvider"], "CPU 전용")
]
for attempt_providers, attempt_name in fallback_attempts:
if not attempt_providers: # 빈 리스트 건너뛰기
continue
try:
if self.logger:
self.logger.log(f"rembg 로컬 모델 {attempt_name} 시도: {attempt_providers}", level=logging.DEBUG)
# ONNX Runtime 세션 옵션 설정 (메모리 절약형)
sess_options = onnxruntime.SessionOptions()
sess_options.enable_mem_pattern = False # 🔧 메모리 절약을 위해 비활성화
sess_options.enable_cpu_mem_arena = True # CPU 메모리 아레나 (유지)
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC # 🔧 기본 최적화만 사용
# DirectML 사용시 메모리 절약 설정
if any('DmlExecutionProvider' in str(provider) for provider in attempt_providers):
sess_options.enable_profiling = False # 프로파일링 비활성화 (안정성)
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL # 순차 실행
ort_session = onnxruntime.InferenceSession(
self.local_rembg_model_path,
providers=attempt_providers,
sess_options=sess_options
)
actual_providers = ort_session.get_providers()
if self.logger:
self.logger.log(f"✅ rembg 로컬 모델 {attempt_name} 성공! 실제 providers: {actual_providers} ", level=logging.INFO)
break
except Exception as e:
if self.logger:
self.logger.log(f"rembg 로컬 모델 {attempt_name} 실패: {e}", level=logging.WARNING)
continue
if ort_session is None:
raise RuntimeError("rembg 로컬 모델 모든 폴백 시도 실패")
session.inner_session = ort_session
# 4. 세션에 필요한 다른 속성들을 수동으로 설정
session.model_name = model_name
session.providers = providers
# 로컬 모델 경로가 없으면 기존 방식으로 처리
else:
if self.logger: self.logger.log(f"내장 모델로 세션 생성: {actual_model_name}", level=logging.INFO)
session = rembg.new_session(model_name=actual_model_name, providers=providers)
self.sessions[session_key] = session
# 실제 사용된 provider 확인 및 로깅 (DirectML 지원 수정)
actual_providers = session.inner_session.get_providers()
if self.logger:
is_gpu = any('CUDA' in p or 'Tensorrt' in p or 'DmlExecutionProvider' in p for p in actual_providers)
status = "GPU 가속 활성화" if is_gpu else "CPU 모드로 동작"
self.logger.log(f"✅ rembg '{actual_model_name}' {status} (실제 providers: {actual_providers})", level=logging.INFO)
except Exception as e:
if self.logger:
self.logger.log(f"rembg 세션 생성 실패 ('{actual_model_name}'): {e}", level=logging.ERROR, exc_info=True)
return None
else:
if self.logger:
self.logger.log(f"♻️ rembg 기존 세션 재사용: {session_key}", level=logging.DEBUG)
return self.sessions.get(session_key)
# ===================================================================================
# 이하 다른 메서드들은 수정할 필요 없습니다. (기존 코드 유지)
# ===================================================================================
def set_default_model(self, model_name):
if model_name not in self.SUPPORTED_MODELS:
raise ValueError(f"지원하지 않는 모델명: {model_name}")
self.default_model = model_name
if self.logger:
self.logger.log(f"rembg 기본 모델이 '{model_name}'(으)로 변경됨", level=logging.INFO)
def get_default_model(self):
return self.default_model
def get_supported_models(self):
return self.SUPPORTED_MODELS.copy()
def get_model_description(self, model_name):
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
def is_available(self):
return self._check_rembg_availability()
def get_init_error(self):
return self._init_error
def to_white_background(self, img: Image.Image) -> Image.Image:
if img.mode in ("RGBA", "BGRA"):
bg = Image.new("RGB", img.size, (255, 255, 255))
bg.paste(img, mask=img.split()[-1])
return bg
else:
return img.convert("RGB")
def remove_background(self, image_path, model_name=None, **kwargs):
if not self._check_rembg_availability():
if self.logger:
self.logger.log(f"rembg 사용 불가로 배경 제거 실패: {self._init_error}", level=logging.ERROR)
return None
if not os.path.exists(image_path):
if self.logger:
self.logger.log(f"입력 이미지가 존재하지 않습니다: {image_path}", level=logging.ERROR)
return None
try:
img = cv2.imread(image_path)
if img is None:
if self.logger:
self.logger.log(f"이미지 로드 실패: {image_path}", level=logging.ERROR)
return None
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# [수정] 로컬 모델을 사용하려면, model_name을 파일과 맞는 이름으로 지정해야 합니다.
# 예를 들어, birefnet-general-lite.onnx 파일을 사용하려면 model_name='birefnet-general-lite'로 호출해야 합니다.
effective_model_name = model_name or self.default_model
if effective_model_name not in self.SUPPORTED_MODELS:
if self.logger:
self.logger.log(f"지원하지 않는 모델명: {effective_model_name}. u2net으로 대체 사용", level=logging.WARNING)
effective_model_name = "u2net"
session = self.get_session(effective_model_name)
if session is None:
return None
import rembg
import time
start_time = time.time()
result = rembg.remove(img_rgb, session=session, alpha_matting=kwargs.get("alpha_matting", False))
end_time = time.time()
if not isinstance(result, Image.Image):
result = Image.fromarray(result)
processing_time = end_time - start_time
if self.logger:
cuda_status = "CUDA" if (self.gpu_manager and self.gpu_manager.can_use_cuda) else "CPU"
self.logger.log(f"✅ 배경 제거 성공: {effective_model_name} ({cuda_status}, {processing_time:.2f}초)", level=logging.INFO)
if self.gpu_manager and self.gpu_manager.can_use_cuda:
self.gpu_manager.log_gpu_memory_usage()
return result
except Exception as e:
if self.logger:
self.logger.log(f"배경 제거 처리 중 오류 ({model_name}): {e}", level=logging.ERROR, exc_info=True)
return None
def _preload_sessions(self):
"""자주 사용되는 rembg 세션들을 미리 로딩하여 첫 번째 요청 시간을 단축"""
preload_models = ["u2net", "birefnet-general-lite"]
for model_name in preload_models:
try:
if self.logger:
self.logger.log(f"🔄 {model_name} 세션 미리 로딩 중...", level=logging.INFO)
# 세션 생성하여 캐시에 저장
session = self.get_session(model_name)
if session:
if self.logger:
self.logger.log(f"{model_name} 세션 미리 로딩 완료", level=logging.INFO)
else:
if self.logger:
self.logger.log(f"⚠️ {model_name} 세션 로딩 실패", level=logging.WARNING)
except Exception as e:
if self.logger:
self.logger.log(f"{model_name} 세션 미리 로딩 중 오류: {e}", level=logging.WARNING)