777 lines
35 KiB
Python
777 lines
35 KiB
Python
"""
|
|
BriaAI RMBG 1.4 ONNXRuntime 기반 배경제거 모듈
|
|
기존 background_removal_module.py를 완전 대체할 수 있는 호환 인터페이스 제공
|
|
"""
|
|
import os
|
|
import cv2
|
|
from PIL import Image
|
|
import logging
|
|
import numpy as np
|
|
import time
|
|
from typing import Optional, Tuple
|
|
|
|
|
|
class BriaBackgroundRemovalModule:
|
|
"""
|
|
BriaAI RMBG 1.4 ONNX 모델 기반 배경제거 모듈
|
|
기존 BackgroundRemovalModule과 완전 호환되는 인터페이스 제공
|
|
"""
|
|
|
|
# 지원하는 모델 목록 (BriaAI 기반)
|
|
SUPPORTED_MODELS = {
|
|
"bria-rmbg-1.4": "BriaAI RMBG 1.4 | 고품질 | 빠름 | 범용 (기본값)",
|
|
"bria-rmbg-aggressive": "BriaAI RMBG 1.4 | 강력한 배경제거 | 빠름 | 깔끔한 결과",
|
|
"bria-rmbg-gentle": "BriaAI RMBG 1.4 | 부드러운 배경제거 | 빠름 | 세밀한 경계"
|
|
}
|
|
|
|
# 모델별 aggressiveness 매핑
|
|
MODEL_AGGRESSIVENESS = {
|
|
"bria-rmbg-1.4": 0.5, # 기본값
|
|
"bria-rmbg-aggressive": 0.8, # 강력한 배경제거
|
|
"bria-rmbg-gentle": 0.2 # 부드러운 배경제거
|
|
}
|
|
|
|
def __init__(self, logger=None, default_model="bria-rmbg-1.4", gpu_manager=None, local_rembg_model_path: str | None = None):
|
|
self.logger = logger
|
|
self.default_model = default_model
|
|
self.gpu_manager = gpu_manager
|
|
self.local_model_path = local_rembg_model_path # BriaAI ONNX 모델 경로
|
|
|
|
# ONNX 세션 관련
|
|
self._session: Optional = None
|
|
self._input_name: Optional[str] = None
|
|
self._output_name: Optional[str] = None
|
|
self._model_input_size: Tuple[int, int] = (1024, 1024) # (W, H)
|
|
self._model_loaded = False
|
|
self._init_error = None
|
|
|
|
if self.logger:
|
|
self.logger.log("BriaAI 배경제거 모듈 초기화 시작", level=logging.INFO)
|
|
|
|
# ONNX Runtime 사용 가능성 확인
|
|
self._check_onnxruntime_availability()
|
|
|
|
if self.logger:
|
|
self.logger.log("BriaAI 배경제거 모듈 초기화 완료", level=logging.INFO)
|
|
|
|
def _check_onnxruntime_availability(self):
|
|
"""ONNX Runtime 사용 가능 여부 확인"""
|
|
try:
|
|
import onnxruntime as ort
|
|
|
|
# 사용 가능한 프로바이더 확인
|
|
available_providers = ort.get_available_providers()
|
|
if self.logger:
|
|
self.logger.log(f"ONNX Runtime 사용 가능한 프로바이더: {available_providers}", level=logging.INFO)
|
|
|
|
# Override/캐시 우선 결정
|
|
provider_override = os.environ.get('IMGWK_REMBG_PROVIDER', 'auto').lower()
|
|
cached_provider = self._read_provider_cache()
|
|
self.providers = ['CPUExecutionProvider']
|
|
|
|
# GPU 가속 설정 (BriaAI는 가벼워서 1GB도 충분)
|
|
if self.gpu_manager and self.gpu_manager.can_use_cuda:
|
|
gpu_memory_mb = int(getattr(self.gpu_manager, 'gpu_memory_total', 0) or 0)
|
|
|
|
def can_use_dml():
|
|
return ('DmlExecutionProvider' in available_providers) and (gpu_memory_mb >= 1024)
|
|
|
|
def can_use_cuda():
|
|
return ('CUDAExecutionProvider' in available_providers) and (gpu_memory_mb >= 1024)
|
|
|
|
if provider_override == 'cpu':
|
|
self.providers = ['CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log("BriaAI provider override=cpu", level=logging.INFO)
|
|
elif provider_override == 'dml':
|
|
if can_use_dml():
|
|
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
|
else:
|
|
self.providers = ['CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log(f"BriaAI provider override=dml → {self.providers}", level=logging.INFO)
|
|
elif cached_provider == 'dml':
|
|
if can_use_dml():
|
|
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log("BriaAI provider cache=dml 적용", level=logging.INFO)
|
|
else:
|
|
self.providers = ['CPUExecutionProvider']
|
|
else:
|
|
# auto: DML → CUDA → CPU 순으로 선택
|
|
if can_use_dml():
|
|
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log(f"BriaAI DirectML 가속 사용 가능 (VRAM: {gpu_memory_mb}MB)", level=logging.INFO)
|
|
elif can_use_cuda():
|
|
self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log(f"BriaAI CUDA 가속 사용 가능 (VRAM: {gpu_memory_mb}MB)", level=logging.INFO)
|
|
else:
|
|
self.providers = ['CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log(f"VRAM/Provider 조건 불충족으로 CPU 모드 사용 (VRAM: {gpu_memory_mb}MB)", level=logging.WARNING)
|
|
else:
|
|
self.providers = ['CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log("BriaAI CPU 모드로 설정", level=logging.INFO)
|
|
|
|
self._onnxruntime_available = True
|
|
return True
|
|
|
|
except ImportError as e:
|
|
self._init_error = f"ONNX Runtime이 설치되지 않음: {e}"
|
|
self._onnxruntime_available = False
|
|
if self.logger:
|
|
self.logger.log(self._init_error, level=logging.ERROR)
|
|
return False
|
|
except Exception as e:
|
|
self._init_error = f"ONNX Runtime 초기화 실패: {e}"
|
|
self._onnxruntime_available = False
|
|
if self.logger:
|
|
self.logger.log(self._init_error, level=logging.ERROR)
|
|
return False
|
|
|
|
def _load_model(self) -> bool:
|
|
"""BriaAI ONNX 모델을 로드합니다."""
|
|
if self._model_loaded:
|
|
return True
|
|
|
|
if not self._onnxruntime_available:
|
|
if self.logger:
|
|
self.logger.log("ONNX Runtime을 사용할 수 없어 모델 로드 실패", level=logging.ERROR)
|
|
return False
|
|
|
|
if not self.local_model_path or not os.path.exists(self.local_model_path):
|
|
self._init_error = f"BriaAI 모델 파일을 찾을 수 없음: {self.local_model_path}"
|
|
if self.logger:
|
|
self.logger.log(self._init_error, level=logging.ERROR)
|
|
return False
|
|
|
|
# 단계별 폴백 시도
|
|
fallback_providers = [
|
|
(self.providers, "원본 providers"),
|
|
([("CPUExecutionProvider", {})], "CPU 폴백")
|
|
]
|
|
|
|
for attempt_providers, attempt_name in fallback_providers:
|
|
try:
|
|
import onnxruntime as ort
|
|
|
|
if self.logger:
|
|
self.logger.log(f"BriaAI ONNX 모델 로딩 시도 ({attempt_name}): {self.local_model_path}", level=logging.INFO)
|
|
|
|
# 세션 옵션 설정 (안전성 우선)
|
|
sess_options = ort.SessionOptions()
|
|
|
|
# CPU 모드에서는 더 보수적인 설정
|
|
if attempt_providers == [("CPUExecutionProvider", {})]:
|
|
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
|
|
sess_options.enable_mem_pattern = False
|
|
sess_options.enable_cpu_mem_arena = False
|
|
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
|
if self.logger:
|
|
self.logger.log("🔒 CPU 안전 모드: 모든 최적화 비활성화", level=logging.INFO)
|
|
else:
|
|
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
|
sess_options.enable_mem_pattern = True
|
|
sess_options.enable_cpu_mem_arena = True
|
|
|
|
# DirectML 사용시 추가 설정
|
|
if 'DmlExecutionProvider' in str(attempt_providers):
|
|
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
|
if self.logger:
|
|
self.logger.log("DirectML 최적화 설정 적용", level=logging.DEBUG)
|
|
|
|
# ONNX 세션 생성 (타임아웃 설정)
|
|
import signal
|
|
|
|
def timeout_handler(signum, frame):
|
|
raise TimeoutError("모델 로딩 타임아웃")
|
|
|
|
# Windows에서는 signal.alarm이 작동하지 않으므로 threading 사용
|
|
import threading
|
|
import time
|
|
|
|
session_created = [False]
|
|
session_result = [None]
|
|
session_error = [None]
|
|
|
|
def create_session():
|
|
try:
|
|
session_result[0] = ort.InferenceSession(
|
|
self.local_model_path,
|
|
sess_options=sess_options,
|
|
providers=attempt_providers
|
|
)
|
|
session_created[0] = True
|
|
except Exception as e:
|
|
session_error[0] = e
|
|
|
|
# 세션 생성을 별도 스레드에서 실행 (30초 타임아웃)
|
|
session_thread = threading.Thread(target=create_session)
|
|
session_thread.daemon = True
|
|
session_thread.start()
|
|
session_thread.join(timeout=30)
|
|
|
|
if not session_created[0]:
|
|
if session_error[0]:
|
|
raise session_error[0]
|
|
else:
|
|
raise TimeoutError("모델 로딩이 30초 내에 완료되지 않음")
|
|
|
|
self._session = session_result[0]
|
|
|
|
# 입출력 정보 가져오기
|
|
inputs = self._session.get_inputs()
|
|
outputs = self._session.get_outputs()
|
|
|
|
if not inputs or not outputs:
|
|
raise RuntimeError("ONNX 모델의 입출력 정의를 찾을 수 없습니다")
|
|
|
|
# 입력/출력 이름 설정
|
|
self._input_name = inputs[0].name
|
|
self._output_name = outputs[0].name
|
|
|
|
# 실제 사용된 프로바이더 확인
|
|
actual_providers = self._session.get_providers()
|
|
|
|
if self.logger:
|
|
self.logger.log(
|
|
f"✅ BriaAI ONNX 모델 로딩 완료 ({attempt_name}) | "
|
|
f"Providers: {actual_providers} | "
|
|
f"Input: {self._input_name} | Output: {self._output_name}",
|
|
level=logging.INFO
|
|
)
|
|
|
|
self._model_loaded = True
|
|
self._providers_used = actual_providers
|
|
# 성공 프로바이더 캐시 기록
|
|
try:
|
|
prov = 'dml' if any('Dml' in p for p in actual_providers) else 'cpu'
|
|
self._write_provider_cache(prov)
|
|
except Exception:
|
|
pass
|
|
return True
|
|
|
|
except Exception as e:
|
|
error_msg = f"BriaAI ONNX 모델 로딩 실패 ({attempt_name}): {e}"
|
|
if self.logger:
|
|
self.logger.log(error_msg, level=logging.WARNING if attempt_name != "CPU 폴백" else logging.ERROR, exc_info=True)
|
|
|
|
# 마지막 시도가 아니면 계속 진행
|
|
if attempt_name != "CPU 폴백":
|
|
continue
|
|
|
|
# 모든 시도 실패
|
|
self._init_error = "모든 provider에서 BriaAI ONNX 모델 로딩 실패"
|
|
if self.logger:
|
|
self.logger.log(self._init_error, level=logging.ERROR)
|
|
return False
|
|
|
|
# ─────────────────────────────────────────────────────────────
|
|
# 프로바이더 캐시 유틸
|
|
# ─────────────────────────────────────────────────────────────
|
|
def _cache_path(self):
|
|
try:
|
|
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
user_data = os.path.join(root_dir, 'user_data')
|
|
os.makedirs(user_data, exist_ok=True)
|
|
return os.path.join(user_data, 'rembg_provider.json')
|
|
except Exception:
|
|
return None
|
|
|
|
def _read_provider_cache(self):
|
|
try:
|
|
path = self._cache_path()
|
|
if not path or not os.path.exists(path):
|
|
return None
|
|
import json
|
|
with open(path, 'r', encoding='utf-8') as f:
|
|
data = json.load(f)
|
|
prov = (data or {}).get('last_success_provider', '').lower()
|
|
return prov if prov in ('dml', 'cpu') else None
|
|
except Exception:
|
|
return None
|
|
|
|
def _write_provider_cache(self, provider: str):
|
|
try:
|
|
path = self._cache_path()
|
|
if not path:
|
|
return
|
|
import json
|
|
with open(path, 'w', encoding='utf-8') as f:
|
|
json.dump({"last_success_provider": provider}, f, ensure_ascii=False)
|
|
except Exception:
|
|
pass
|
|
|
|
def _preprocess(self, image_bgr: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
|
|
"""BGR uint8 이미지를 모델 입력(NCHW float32, 정규화)로 변환 (허깅페이스 호환)"""
|
|
orig_h, orig_w = image_bgr.shape[:2]
|
|
|
|
# 입력 검증
|
|
if len(image_bgr.shape) != 3 or image_bgr.shape[2] != 3:
|
|
raise ValueError(f"입력 이미지는 3채널 BGR이어야 합니다. 현재: {image_bgr.shape}")
|
|
|
|
# BGR -> RGB (허깅페이스는 RGB 입력 가정)
|
|
image_rgb = image_bgr[:, :, ::-1].copy()
|
|
|
|
# 리사이즈 (W,H) - 허깅페이스와 동일한 bilinear 보간
|
|
target_w, target_h = self._model_input_size
|
|
resized = cv2.resize(image_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
|
|
|
|
# 허깅페이스 방식: float32 변환 후 정규화
|
|
# 1. [0,255] -> [0,1]
|
|
tensor = resized.astype(np.float32) / 255.0
|
|
|
|
# 2. normalize(mean=[0.5,0.5,0.5], std=[1.0,1.0,1.0]) -> (x - 0.5) / 1.0
|
|
tensor = tensor - 0.5
|
|
|
|
# 3. HWC -> CHW, 배치 축 추가 (NCHW)
|
|
nchw = np.transpose(tensor, (2, 0, 1))[np.newaxis, ...]
|
|
|
|
# 디버그 정보
|
|
if self.logger:
|
|
self.logger.log(f"전처리 완료: {image_bgr.shape} -> {nchw.shape}, 값 범위: [{nchw.min():.3f}, {nchw.max():.3f}]", level=logging.DEBUG)
|
|
|
|
return nchw, (orig_h, orig_w)
|
|
|
|
def _infer(self, input_tensor: np.ndarray) -> np.ndarray:
|
|
"""ONNX 추론 수행 후 [H,W] 마스크 확률맵 반환 (BriaAI 모델 특화)"""
|
|
outputs = self._session.run(None, {self._input_name: input_tensor})
|
|
|
|
# BriaAI 모델은 여러 출력을 가질 수 있음 (6개 side outputs)
|
|
# 첫 번째 출력(d1)이 가장 정확한 결과
|
|
if isinstance(outputs, (list, tuple)) and len(outputs) > 0:
|
|
pred = outputs[0] # 첫 번째 출력 사용 (d1)
|
|
if self.logger:
|
|
self.logger.log(f"ONNX 모델 출력 개수: {len(outputs)}, 첫 번째 출력 shape: {pred.shape}", level=logging.DEBUG)
|
|
else:
|
|
pred = outputs
|
|
|
|
# numpy array로 변환
|
|
pred = np.array(pred)
|
|
|
|
# 차원 정리: [B, C, H, W] -> [H, W]
|
|
if pred.ndim == 4:
|
|
# [1, 1, H, W] -> [H, W]
|
|
pred = pred[0, 0]
|
|
elif pred.ndim == 3:
|
|
if pred.shape[0] == 1:
|
|
# [1, H, W] -> [H, W]
|
|
pred = pred[0]
|
|
else:
|
|
# [C, H, W] -> [H, W] (첫 번째 채널 사용)
|
|
pred = pred[0]
|
|
elif pred.ndim == 2:
|
|
# 이미 [H, W] 형태
|
|
pass
|
|
else:
|
|
raise ValueError(f"예상하지 못한 출력 차원: {pred.shape}")
|
|
|
|
# 확률값 범위 확인 및 정규화 (0~1 사이가 아닐 경우)
|
|
if pred.max() > 1.0 or pred.min() < 0.0:
|
|
if self.logger:
|
|
self.logger.log(f"출력 값 범위 이상: [{pred.min():.3f}, {pred.max():.3f}] -> sigmoid 적용", level=logging.WARNING)
|
|
# sigmoid 함수 적용 (모델에서 sigmoid가 적용되지 않은 경우)
|
|
pred = 1.0 / (1.0 + np.exp(-pred))
|
|
|
|
if self.logger:
|
|
self.logger.log(f"추론 완료: {pred.shape}, 값 범위: [{pred.min():.3f}, {pred.max():.3f}]", level=logging.DEBUG)
|
|
|
|
return pred
|
|
|
|
def _postprocess(self, mask_pred: np.ndarray, orig_size: Tuple[int, int], aggressiveness: float = 0.5) -> np.ndarray:
|
|
"""모델 출력 마스크를 원본 해상도로 보간하고 0..255 uint8로 변환 (허깅페이스 방식)"""
|
|
orig_h, orig_w = orig_size
|
|
|
|
# 모델 출력(H,W)을 원본 크기로 리사이즈 (W,H) - 허깅페이스와 동일하게 bilinear 보간
|
|
mask_resized = cv2.resize(mask_pred, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
|
|
|
|
# 허깅페이스 방식의 min-max 정규화 (임계값 없이)
|
|
ma = float(mask_resized.max())
|
|
mi = float(mask_resized.min())
|
|
denom = (ma - mi) if (ma - mi) != 0 else 1.0
|
|
mask_norm = (mask_resized - mi) / denom
|
|
|
|
# 허깅페이스 방식: 직접 0-255로 스케일링 (임계값 적용 안함)
|
|
mask_255 = (mask_norm * 255).astype(np.uint8)
|
|
|
|
# aggressiveness 파라미터 적용 (필요시에만)
|
|
if aggressiveness != 0.5:
|
|
# aggressiveness가 0.5가 아닐 때만 약간의 조정 적용
|
|
if aggressiveness > 0.5:
|
|
# 더 공격적: 밝기 증가
|
|
factor = 1.0 + (aggressiveness - 0.5) * 0.4
|
|
mask_255 = np.clip(mask_255 * factor, 0, 255).astype(np.uint8)
|
|
else:
|
|
# 더 부드럽게: 밝기 감소
|
|
factor = 0.6 + aggressiveness * 0.8
|
|
mask_255 = np.clip(mask_255 * factor, 0, 255).astype(np.uint8)
|
|
|
|
return mask_255
|
|
|
|
# ===================================================================================
|
|
# 기존 BackgroundRemovalModule 호환 인터페이스
|
|
# ===================================================================================
|
|
|
|
def _refine_mask(
|
|
self,
|
|
alpha_mask: np.ndarray,
|
|
*,
|
|
alpha_threshold: int | None = None,
|
|
keep_top_k_components: int | None = None,
|
|
min_component_area: int | None = None,
|
|
morph_kernel: int | None = 3,
|
|
morph_open_iters: int = 0,
|
|
morph_close_iters: int = 0,
|
|
dilate_iters: int = 0,
|
|
erode_iters: int = 0,
|
|
fill_holes: bool = False,
|
|
edge_feather: int = 0,
|
|
) -> np.ndarray:
|
|
"""알파 마스크를 후처리하여 과도 제거/잔여 잡음을 완화합니다.
|
|
|
|
매개변수 설명:
|
|
- alpha_threshold: 이진화 임계값(0~255). 설정 시 확실한 전경만 남깁니다.
|
|
- keep_top_k_components: 가장 큰 연결요소 K개만 유지(사람+상품=2 권장).
|
|
- min_component_area: 이 면적 미만의 작은 요소 제거.
|
|
- morph_*: 모폴로지 연산으로 노이즈 제거/홀 메움.
|
|
- dilate/erode: 경계 확장/축소 미세 조정.
|
|
- fill_holes: 내부 구멍 채우기(큰 구멍 포함).
|
|
- edge_feather: 가장자리 페더링(>0이면 가우시안 블러, 값은 커널 반경).
|
|
"""
|
|
|
|
try:
|
|
mask_uint8 = alpha_mask.astype(np.uint8)
|
|
|
|
# 1) 이진화 준비
|
|
if alpha_threshold is not None:
|
|
_, bin_mask = cv2.threshold(mask_uint8, int(alpha_threshold), 255, cv2.THRESH_BINARY)
|
|
else:
|
|
# 소프트 마스크라도 후단 처리를 위해 이진 마스크 생성
|
|
bin_mask = (mask_uint8 > 0).astype(np.uint8) * 255
|
|
|
|
# 2) 연결요소 기반 정리
|
|
if keep_top_k_components is not None or (min_component_area is not None and min_component_area > 1):
|
|
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(bin_mask, connectivity=8)
|
|
# 0은 배경
|
|
component_indices = list(range(1, num_labels))
|
|
if min_component_area is not None:
|
|
component_indices = [i for i in component_indices if int(stats[i, cv2.CC_STAT_AREA]) >= int(min_component_area)]
|
|
|
|
# 가장 큰 K개만 남기기
|
|
if keep_top_k_components is not None and keep_top_k_components > 0:
|
|
component_indices = sorted(
|
|
component_indices,
|
|
key=lambda i: int(stats[i, cv2.CC_STAT_AREA]),
|
|
reverse=True,
|
|
)[: int(keep_top_k_components)]
|
|
|
|
filtered = np.zeros_like(bin_mask)
|
|
for i in component_indices:
|
|
filtered[labels == i] = 255
|
|
bin_mask = filtered
|
|
|
|
# 3) 모폴로지 연산으로 노이즈 제거/홀 메움
|
|
k = 3 if morph_kernel is None or morph_kernel <= 0 else int(morph_kernel)
|
|
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
|
|
|
if morph_open_iters > 0:
|
|
bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_OPEN, kernel, iterations=int(morph_open_iters))
|
|
|
|
if morph_close_iters > 0:
|
|
bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_CLOSE, kernel, iterations=int(morph_close_iters))
|
|
|
|
if dilate_iters > 0:
|
|
bin_mask = cv2.dilate(bin_mask, kernel, iterations=int(dilate_iters))
|
|
if erode_iters > 0:
|
|
bin_mask = cv2.erode(bin_mask, kernel, iterations=int(erode_iters))
|
|
|
|
# 4) 내부 큰 구멍 채우기
|
|
if fill_holes:
|
|
flood = bin_mask.copy()
|
|
h, w = flood.shape[:2]
|
|
flood_mask = np.zeros((h + 2, w + 2), np.uint8)
|
|
cv2.floodFill(flood, flood_mask, (0, 0), 255)
|
|
flood_inv = cv2.bitwise_not(flood)
|
|
bin_mask = cv2.bitwise_or(bin_mask, flood_inv)
|
|
|
|
# 5) 가장자리 페더링(부드럽게)
|
|
if edge_feather and edge_feather > 0:
|
|
r = int(edge_feather)
|
|
r = max(1, min(31, r * 2 + 1)) # 홀수 커널 보장, 과도한 값 방지
|
|
soft = cv2.GaussianBlur(bin_mask, (r, r), 0)
|
|
return soft.astype(np.uint8)
|
|
|
|
return bin_mask.astype(np.uint8)
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"마스크 후처리 오류: {e}", level=logging.ERROR, exc_info=True)
|
|
return alpha_mask
|
|
|
|
def is_available(self):
|
|
"""배경제거 모듈 사용 가능 여부 반환"""
|
|
return self._onnxruntime_available and (self.local_model_path and os.path.exists(self.local_model_path))
|
|
|
|
def get_init_error(self):
|
|
"""초기화 에러 메시지 반환"""
|
|
return self._init_error
|
|
|
|
def get_supported_models(self):
|
|
"""지원하는 모델 목록 반환"""
|
|
return self.SUPPORTED_MODELS.copy()
|
|
|
|
def get_default_model(self):
|
|
"""기본 모델명 반환"""
|
|
return self.default_model
|
|
|
|
def set_default_model(self, model_name):
|
|
"""기본 모델 설정"""
|
|
if model_name not in self.SUPPORTED_MODELS:
|
|
raise ValueError(f"지원하지 않는 모델명: {model_name}")
|
|
self.default_model = model_name
|
|
if self.logger:
|
|
self.logger.log(f"BriaAI 기본 모델이 '{model_name}'으로 변경됨", level=logging.INFO)
|
|
|
|
def get_model_description(self, model_name):
|
|
"""모델 설명 반환"""
|
|
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
|
|
|
|
def to_white_background(self, img: Image.Image) -> Image.Image:
|
|
"""RGBA 이미지를 흰 배경과 합성 (이미 RGB라면 그대로 반환)"""
|
|
if img.mode in ("RGBA", "BGRA"):
|
|
bg = Image.new("RGB", img.size, (255, 255, 255))
|
|
bg.paste(img, mask=img.split()[-1])
|
|
return bg
|
|
else:
|
|
# 이미 RGB이거나 다른 모드라면 RGB로 변환
|
|
return img.convert("RGB")
|
|
|
|
def remove_background(self, image_path, model_name=None, force_cpu=None, **kwargs):
|
|
"""
|
|
이미지에서 배경을 제거하여 PIL Image 반환
|
|
기존 BackgroundRemovalModule.remove_background와 동일한 인터페이스
|
|
"""
|
|
if not self.is_available():
|
|
if self.logger:
|
|
self.logger.log(f"BriaAI 모듈 사용 불가: {self._init_error}", level=logging.ERROR)
|
|
return None
|
|
|
|
if not os.path.exists(image_path):
|
|
if self.logger:
|
|
self.logger.log(f"입력 이미지가 존재하지 않습니다: {image_path}", level=logging.ERROR)
|
|
return None
|
|
|
|
# force_cpu 매개변수 처리
|
|
original_providers = self.providers
|
|
if force_cpu is True:
|
|
self.providers = ['CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log("⚠️ CPU 모드 강제 실행 (BriaAI)", level=logging.WARNING)
|
|
elif force_cpu is False:
|
|
# GPU 가속 강제 사용 (DirectML 테스트용)
|
|
if self.gpu_manager and self.gpu_manager.can_use_cuda:
|
|
if 'DmlExecutionProvider' in original_providers:
|
|
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
|
elif 'CUDAExecutionProvider' in original_providers:
|
|
self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
|
else:
|
|
self.providers = ['CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log(f"🔥 GPU 모드 강제 실행 (BriaAI): {self.providers}", level=logging.WARNING)
|
|
|
|
# 모델 로드 (지연 로딩)
|
|
model_loaded = self._load_model()
|
|
|
|
# 원래 providers 복원
|
|
if force_cpu is not None:
|
|
self.providers = original_providers
|
|
|
|
if not model_loaded:
|
|
return None
|
|
|
|
try:
|
|
# 이미지 로드
|
|
img = cv2.imread(image_path)
|
|
if img is None:
|
|
if self.logger:
|
|
self.logger.log(f"이미지 로드 실패: {image_path}", level=logging.ERROR)
|
|
return None
|
|
|
|
# 모델명 결정 및 aggressiveness 설정
|
|
effective_model_name = model_name or self.default_model
|
|
if effective_model_name not in self.SUPPORTED_MODELS:
|
|
if self.logger:
|
|
self.logger.log(f"지원하지 않는 모델명: {effective_model_name}. bria-rmbg-1.4로 대체 사용", level=logging.WARNING)
|
|
effective_model_name = "bria-rmbg-1.4"
|
|
|
|
aggressiveness = self.MODEL_AGGRESSIVENESS.get(effective_model_name, 0.5)
|
|
custom_aggressiveness = kwargs.get("aggressiveness", aggressiveness)
|
|
|
|
if self.logger:
|
|
self.logger.log(f"BriaAI 배경제거 시작: {effective_model_name} (aggressiveness={custom_aggressiveness})", level=logging.DEBUG)
|
|
|
|
start_time = time.time()
|
|
|
|
# 전처리
|
|
input_tensor, (orig_h, orig_w) = self._preprocess(img)
|
|
|
|
# 추론
|
|
mask_pred = self._infer(input_tensor)
|
|
|
|
# 후처리
|
|
alpha_mask = self._postprocess(mask_pred, (orig_h, orig_w), aggressiveness=custom_aggressiveness)
|
|
|
|
# 선택적 마스크 보정
|
|
if any(k in kwargs for k in (
|
|
'alpha_threshold', 'keep_top_k_components', 'min_component_area',
|
|
'morph_kernel', 'morph_open_iters', 'morph_close_iters',
|
|
'dilate_iters', 'erode_iters', 'fill_holes', 'edge_feather'
|
|
)):
|
|
alpha_mask = self._refine_mask(
|
|
alpha_mask,
|
|
alpha_threshold=kwargs.get('alpha_threshold'),
|
|
keep_top_k_components=kwargs.get('keep_top_k_components'),
|
|
min_component_area=kwargs.get('min_component_area'),
|
|
morph_kernel=kwargs.get('morph_kernel', 3),
|
|
morph_open_iters=kwargs.get('morph_open_iters', 0),
|
|
morph_close_iters=kwargs.get('morph_close_iters', 0),
|
|
dilate_iters=kwargs.get('dilate_iters', 0),
|
|
erode_iters=kwargs.get('erode_iters', 0),
|
|
fill_holes=kwargs.get('fill_holes', False),
|
|
edge_feather=kwargs.get('edge_feather', 0),
|
|
)
|
|
|
|
# DirectML 알파 채널 이슈 회피: 바로 흰 배경 합성된 BGR로 처리
|
|
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
|
# 알파 마스크를 이용해 흰 배경 합성 (DirectML 알파 채널 이슈 회피)
|
|
alpha_normalized = alpha_mask.astype(np.float32) / 255.0
|
|
alpha_3d = np.stack([alpha_normalized] * 3, axis=-1)
|
|
|
|
# 흰 배경과 합성
|
|
white_bg = np.full_like(img_rgb, 255, dtype=np.uint8)
|
|
blended_rgb = (
|
|
img_rgb.astype(np.float32) * alpha_3d +
|
|
white_bg.astype(np.float32) * (1.0 - alpha_3d)
|
|
).astype(np.uint8)
|
|
|
|
# PIL Image로 변환 (RGB 모드로 - 알파 채널 없음)
|
|
result = Image.fromarray(blended_rgb, 'RGB')
|
|
|
|
end_time = time.time()
|
|
processing_time = end_time - start_time
|
|
|
|
if self.logger:
|
|
provider_status = "GPU" if any('CUDA' in p or 'Dml' in p for p in self._session.get_providers()) else "CPU"
|
|
self.logger.log(f"✅ BriaAI 배경제거 성공: {effective_model_name} ({provider_status}, {processing_time:.2f}초)", level=logging.INFO)
|
|
|
|
# 마스크 통계 로깅
|
|
mask_stats = {
|
|
'min': int(alpha_mask.min()),
|
|
'max': int(alpha_mask.max()),
|
|
'mean': float(alpha_mask.mean()),
|
|
'nonzero_count': int(np.count_nonzero(alpha_mask))
|
|
}
|
|
self.logger.log(f"BriaAI 마스크 통계: {mask_stats}", level=logging.DEBUG)
|
|
|
|
# GPU 메모리 사용량 로깅
|
|
if self.gpu_manager and hasattr(self.gpu_manager, 'log_gpu_memory_usage'):
|
|
self.gpu_manager.log_gpu_memory_usage()
|
|
|
|
return result
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"BriaAI 배경제거 처리 중 오류: {e}", level=logging.ERROR, exc_info=True)
|
|
return None
|
|
|
|
def remove_background_array(self, image_bgr: np.ndarray, model_name=None, force_cpu=None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
|
|
"""
|
|
배경제거 결과를 numpy array로 직접 반환 (추가 편의 메서드)
|
|
Returns: (result_bgr, alpha_mask)
|
|
"""
|
|
# force_cpu 매개변수 처리
|
|
original_providers = self.providers
|
|
if force_cpu is True:
|
|
self.providers = ['CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log("⚠️ CPU 모드 강제 실행 (BriaAI array)", level=logging.WARNING)
|
|
elif force_cpu is False:
|
|
# GPU 가속 강제 사용
|
|
if self.gpu_manager and self.gpu_manager.can_use_cuda:
|
|
if 'DmlExecutionProvider' in original_providers:
|
|
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
|
elif 'CUDAExecutionProvider' in original_providers:
|
|
self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
|
if self.logger:
|
|
self.logger.log(f"🔥 GPU 모드 강제 실행 (BriaAI array): {self.providers}", level=logging.WARNING)
|
|
|
|
# 모델 로드
|
|
model_loaded = self.is_available() and self._load_model()
|
|
|
|
# 원래 providers 복원
|
|
if force_cpu is not None:
|
|
self.providers = original_providers
|
|
|
|
if not model_loaded:
|
|
return image_bgr, None
|
|
|
|
try:
|
|
effective_model_name = model_name or self.default_model
|
|
aggressiveness = self.MODEL_AGGRESSIVENESS.get(effective_model_name, 0.5)
|
|
custom_aggressiveness = kwargs.get("aggressiveness", aggressiveness)
|
|
|
|
# 전처리 -> 추론 -> 후처리
|
|
input_tensor, (orig_h, orig_w) = self._preprocess(image_bgr)
|
|
mask_pred = self._infer(input_tensor)
|
|
alpha_mask = self._postprocess(mask_pred, (orig_h, orig_w), aggressiveness=custom_aggressiveness)
|
|
|
|
# 선택적 마스크 보정
|
|
if any(k in kwargs for k in (
|
|
'alpha_threshold', 'keep_top_k_components', 'min_component_area',
|
|
'morph_kernel', 'morph_open_iters', 'morph_close_iters',
|
|
'dilate_iters', 'erode_iters', 'fill_holes', 'edge_feather'
|
|
)):
|
|
alpha_mask = self._refine_mask(
|
|
alpha_mask,
|
|
alpha_threshold=kwargs.get('alpha_threshold'),
|
|
keep_top_k_components=kwargs.get('keep_top_k_components'),
|
|
min_component_area=kwargs.get('min_component_area'),
|
|
morph_kernel=kwargs.get('morph_kernel', 3),
|
|
morph_open_iters=kwargs.get('morph_open_iters', 0),
|
|
morph_close_iters=kwargs.get('morph_close_iters', 0),
|
|
dilate_iters=kwargs.get('dilate_iters', 0),
|
|
erode_iters=kwargs.get('erode_iters', 0),
|
|
fill_holes=kwargs.get('fill_holes', False),
|
|
edge_feather=kwargs.get('edge_feather', 0),
|
|
)
|
|
|
|
# 흰색 배경 합성
|
|
mask_3d = np.stack([alpha_mask] * 3, axis=-1)
|
|
result_bgr = (
|
|
image_bgr.astype(np.float32) * (mask_3d.astype(np.float32) / 255.0)
|
|
+ 255.0 * (1.0 - (mask_3d.astype(np.float32) / 255.0))
|
|
).clip(0, 255).astype(np.uint8)
|
|
|
|
return result_bgr, alpha_mask
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"BriaAI 배경제거 array 처리 중 오류: {e}", level=logging.ERROR, exc_info=True)
|
|
return image_bgr, None
|
|
|
|
def _preload_sessions(self):
|
|
"""BriaAI 모델 미리 로딩"""
|
|
if self.logger:
|
|
self.logger.log("🔄 BriaAI 모델 미리 로딩 중...", level=logging.INFO)
|
|
|
|
if self._load_model():
|
|
if self.logger:
|
|
self.logger.log("✅ BriaAI 모델 미리 로딩 완료", level=logging.INFO)
|
|
else:
|
|
if self.logger:
|
|
self.logger.log("⚠️ BriaAI 모델 로딩 실패", level=logging.WARNING)
|