IMG_Worker/modules/bria_background_removal_mod...

777 lines
35 KiB
Python

"""
BriaAI RMBG 1.4 ONNXRuntime 기반 배경제거 모듈
기존 background_removal_module.py를 완전 대체할 수 있는 호환 인터페이스 제공
"""
import os
import cv2
from PIL import Image
import logging
import numpy as np
import time
from typing import Optional, Tuple
class BriaBackgroundRemovalModule:
"""
BriaAI RMBG 1.4 ONNX 모델 기반 배경제거 모듈
기존 BackgroundRemovalModule과 완전 호환되는 인터페이스 제공
"""
# 지원하는 모델 목록 (BriaAI 기반)
SUPPORTED_MODELS = {
"bria-rmbg-1.4": "BriaAI RMBG 1.4 | 고품질 | 빠름 | 범용 (기본값)",
"bria-rmbg-aggressive": "BriaAI RMBG 1.4 | 강력한 배경제거 | 빠름 | 깔끔한 결과",
"bria-rmbg-gentle": "BriaAI RMBG 1.4 | 부드러운 배경제거 | 빠름 | 세밀한 경계"
}
# 모델별 aggressiveness 매핑
MODEL_AGGRESSIVENESS = {
"bria-rmbg-1.4": 0.5, # 기본값
"bria-rmbg-aggressive": 0.8, # 강력한 배경제거
"bria-rmbg-gentle": 0.2 # 부드러운 배경제거
}
def __init__(self, logger=None, default_model="bria-rmbg-1.4", gpu_manager=None, local_rembg_model_path: str | None = None):
self.logger = logger
self.default_model = default_model
self.gpu_manager = gpu_manager
self.local_model_path = local_rembg_model_path # BriaAI ONNX 모델 경로
# ONNX 세션 관련
self._session: Optional = None
self._input_name: Optional[str] = None
self._output_name: Optional[str] = None
self._model_input_size: Tuple[int, int] = (1024, 1024) # (W, H)
self._model_loaded = False
self._init_error = None
if self.logger:
self.logger.log("BriaAI 배경제거 모듈 초기화 시작", level=logging.INFO)
# ONNX Runtime 사용 가능성 확인
self._check_onnxruntime_availability()
if self.logger:
self.logger.log("BriaAI 배경제거 모듈 초기화 완료", level=logging.INFO)
def _check_onnxruntime_availability(self):
"""ONNX Runtime 사용 가능 여부 확인"""
try:
import onnxruntime as ort
# 사용 가능한 프로바이더 확인
available_providers = ort.get_available_providers()
if self.logger:
self.logger.log(f"ONNX Runtime 사용 가능한 프로바이더: {available_providers}", level=logging.INFO)
# Override/캐시 우선 결정
provider_override = os.environ.get('IMGWK_REMBG_PROVIDER', 'auto').lower()
cached_provider = self._read_provider_cache()
self.providers = ['CPUExecutionProvider']
# GPU 가속 설정 (BriaAI는 가벼워서 1GB도 충분)
if self.gpu_manager and self.gpu_manager.can_use_cuda:
gpu_memory_mb = int(getattr(self.gpu_manager, 'gpu_memory_total', 0) or 0)
def can_use_dml():
return ('DmlExecutionProvider' in available_providers) and (gpu_memory_mb >= 1024)
def can_use_cuda():
return ('CUDAExecutionProvider' in available_providers) and (gpu_memory_mb >= 1024)
if provider_override == 'cpu':
self.providers = ['CPUExecutionProvider']
if self.logger:
self.logger.log("BriaAI provider override=cpu", level=logging.INFO)
elif provider_override == 'dml':
if can_use_dml():
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
else:
self.providers = ['CPUExecutionProvider']
if self.logger:
self.logger.log(f"BriaAI provider override=dml → {self.providers}", level=logging.INFO)
elif cached_provider == 'dml':
if can_use_dml():
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
if self.logger:
self.logger.log("BriaAI provider cache=dml 적용", level=logging.INFO)
else:
self.providers = ['CPUExecutionProvider']
else:
# auto: DML → CUDA → CPU 순으로 선택
if can_use_dml():
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
if self.logger:
self.logger.log(f"BriaAI DirectML 가속 사용 가능 (VRAM: {gpu_memory_mb}MB)", level=logging.INFO)
elif can_use_cuda():
self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
if self.logger:
self.logger.log(f"BriaAI CUDA 가속 사용 가능 (VRAM: {gpu_memory_mb}MB)", level=logging.INFO)
else:
self.providers = ['CPUExecutionProvider']
if self.logger:
self.logger.log(f"VRAM/Provider 조건 불충족으로 CPU 모드 사용 (VRAM: {gpu_memory_mb}MB)", level=logging.WARNING)
else:
self.providers = ['CPUExecutionProvider']
if self.logger:
self.logger.log("BriaAI CPU 모드로 설정", level=logging.INFO)
self._onnxruntime_available = True
return True
except ImportError as e:
self._init_error = f"ONNX Runtime이 설치되지 않음: {e}"
self._onnxruntime_available = False
if self.logger:
self.logger.log(self._init_error, level=logging.ERROR)
return False
except Exception as e:
self._init_error = f"ONNX Runtime 초기화 실패: {e}"
self._onnxruntime_available = False
if self.logger:
self.logger.log(self._init_error, level=logging.ERROR)
return False
def _load_model(self) -> bool:
"""BriaAI ONNX 모델을 로드합니다."""
if self._model_loaded:
return True
if not self._onnxruntime_available:
if self.logger:
self.logger.log("ONNX Runtime을 사용할 수 없어 모델 로드 실패", level=logging.ERROR)
return False
if not self.local_model_path or not os.path.exists(self.local_model_path):
self._init_error = f"BriaAI 모델 파일을 찾을 수 없음: {self.local_model_path}"
if self.logger:
self.logger.log(self._init_error, level=logging.ERROR)
return False
# 단계별 폴백 시도
fallback_providers = [
(self.providers, "원본 providers"),
([("CPUExecutionProvider", {})], "CPU 폴백")
]
for attempt_providers, attempt_name in fallback_providers:
try:
import onnxruntime as ort
if self.logger:
self.logger.log(f"BriaAI ONNX 모델 로딩 시도 ({attempt_name}): {self.local_model_path}", level=logging.INFO)
# 세션 옵션 설정 (안전성 우선)
sess_options = ort.SessionOptions()
# CPU 모드에서는 더 보수적인 설정
if attempt_providers == [("CPUExecutionProvider", {})]:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_options.enable_mem_pattern = False
sess_options.enable_cpu_mem_arena = False
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if self.logger:
self.logger.log("🔒 CPU 안전 모드: 모든 최적화 비활성화", level=logging.INFO)
else:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.enable_mem_pattern = True
sess_options.enable_cpu_mem_arena = True
# DirectML 사용시 추가 설정
if 'DmlExecutionProvider' in str(attempt_providers):
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if self.logger:
self.logger.log("DirectML 최적화 설정 적용", level=logging.DEBUG)
# ONNX 세션 생성 (타임아웃 설정)
import signal
def timeout_handler(signum, frame):
raise TimeoutError("모델 로딩 타임아웃")
# Windows에서는 signal.alarm이 작동하지 않으므로 threading 사용
import threading
import time
session_created = [False]
session_result = [None]
session_error = [None]
def create_session():
try:
session_result[0] = ort.InferenceSession(
self.local_model_path,
sess_options=sess_options,
providers=attempt_providers
)
session_created[0] = True
except Exception as e:
session_error[0] = e
# 세션 생성을 별도 스레드에서 실행 (30초 타임아웃)
session_thread = threading.Thread(target=create_session)
session_thread.daemon = True
session_thread.start()
session_thread.join(timeout=30)
if not session_created[0]:
if session_error[0]:
raise session_error[0]
else:
raise TimeoutError("모델 로딩이 30초 내에 완료되지 않음")
self._session = session_result[0]
# 입출력 정보 가져오기
inputs = self._session.get_inputs()
outputs = self._session.get_outputs()
if not inputs or not outputs:
raise RuntimeError("ONNX 모델의 입출력 정의를 찾을 수 없습니다")
# 입력/출력 이름 설정
self._input_name = inputs[0].name
self._output_name = outputs[0].name
# 실제 사용된 프로바이더 확인
actual_providers = self._session.get_providers()
if self.logger:
self.logger.log(
f"✅ BriaAI ONNX 모델 로딩 완료 ({attempt_name}) | "
f"Providers: {actual_providers} | "
f"Input: {self._input_name} | Output: {self._output_name}",
level=logging.INFO
)
self._model_loaded = True
self._providers_used = actual_providers
# 성공 프로바이더 캐시 기록
try:
prov = 'dml' if any('Dml' in p for p in actual_providers) else 'cpu'
self._write_provider_cache(prov)
except Exception:
pass
return True
except Exception as e:
error_msg = f"BriaAI ONNX 모델 로딩 실패 ({attempt_name}): {e}"
if self.logger:
self.logger.log(error_msg, level=logging.WARNING if attempt_name != "CPU 폴백" else logging.ERROR, exc_info=True)
# 마지막 시도가 아니면 계속 진행
if attempt_name != "CPU 폴백":
continue
# 모든 시도 실패
self._init_error = "모든 provider에서 BriaAI ONNX 모델 로딩 실패"
if self.logger:
self.logger.log(self._init_error, level=logging.ERROR)
return False
# ─────────────────────────────────────────────────────────────
# 프로바이더 캐시 유틸
# ─────────────────────────────────────────────────────────────
def _cache_path(self):
try:
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
user_data = os.path.join(root_dir, 'user_data')
os.makedirs(user_data, exist_ok=True)
return os.path.join(user_data, 'rembg_provider.json')
except Exception:
return None
def _read_provider_cache(self):
try:
path = self._cache_path()
if not path or not os.path.exists(path):
return None
import json
with open(path, 'r', encoding='utf-8') as f:
data = json.load(f)
prov = (data or {}).get('last_success_provider', '').lower()
return prov if prov in ('dml', 'cpu') else None
except Exception:
return None
def _write_provider_cache(self, provider: str):
try:
path = self._cache_path()
if not path:
return
import json
with open(path, 'w', encoding='utf-8') as f:
json.dump({"last_success_provider": provider}, f, ensure_ascii=False)
except Exception:
pass
def _preprocess(self, image_bgr: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
"""BGR uint8 이미지를 모델 입력(NCHW float32, 정규화)로 변환 (허깅페이스 호환)"""
orig_h, orig_w = image_bgr.shape[:2]
# 입력 검증
if len(image_bgr.shape) != 3 or image_bgr.shape[2] != 3:
raise ValueError(f"입력 이미지는 3채널 BGR이어야 합니다. 현재: {image_bgr.shape}")
# BGR -> RGB (허깅페이스는 RGB 입력 가정)
image_rgb = image_bgr[:, :, ::-1].copy()
# 리사이즈 (W,H) - 허깅페이스와 동일한 bilinear 보간
target_w, target_h = self._model_input_size
resized = cv2.resize(image_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
# 허깅페이스 방식: float32 변환 후 정규화
# 1. [0,255] -> [0,1]
tensor = resized.astype(np.float32) / 255.0
# 2. normalize(mean=[0.5,0.5,0.5], std=[1.0,1.0,1.0]) -> (x - 0.5) / 1.0
tensor = tensor - 0.5
# 3. HWC -> CHW, 배치 축 추가 (NCHW)
nchw = np.transpose(tensor, (2, 0, 1))[np.newaxis, ...]
# 디버그 정보
if self.logger:
self.logger.log(f"전처리 완료: {image_bgr.shape} -> {nchw.shape}, 값 범위: [{nchw.min():.3f}, {nchw.max():.3f}]", level=logging.DEBUG)
return nchw, (orig_h, orig_w)
def _infer(self, input_tensor: np.ndarray) -> np.ndarray:
"""ONNX 추론 수행 후 [H,W] 마스크 확률맵 반환 (BriaAI 모델 특화)"""
outputs = self._session.run(None, {self._input_name: input_tensor})
# BriaAI 모델은 여러 출력을 가질 수 있음 (6개 side outputs)
# 첫 번째 출력(d1)이 가장 정확한 결과
if isinstance(outputs, (list, tuple)) and len(outputs) > 0:
pred = outputs[0] # 첫 번째 출력 사용 (d1)
if self.logger:
self.logger.log(f"ONNX 모델 출력 개수: {len(outputs)}, 첫 번째 출력 shape: {pred.shape}", level=logging.DEBUG)
else:
pred = outputs
# numpy array로 변환
pred = np.array(pred)
# 차원 정리: [B, C, H, W] -> [H, W]
if pred.ndim == 4:
# [1, 1, H, W] -> [H, W]
pred = pred[0, 0]
elif pred.ndim == 3:
if pred.shape[0] == 1:
# [1, H, W] -> [H, W]
pred = pred[0]
else:
# [C, H, W] -> [H, W] (첫 번째 채널 사용)
pred = pred[0]
elif pred.ndim == 2:
# 이미 [H, W] 형태
pass
else:
raise ValueError(f"예상하지 못한 출력 차원: {pred.shape}")
# 확률값 범위 확인 및 정규화 (0~1 사이가 아닐 경우)
if pred.max() > 1.0 or pred.min() < 0.0:
if self.logger:
self.logger.log(f"출력 값 범위 이상: [{pred.min():.3f}, {pred.max():.3f}] -> sigmoid 적용", level=logging.WARNING)
# sigmoid 함수 적용 (모델에서 sigmoid가 적용되지 않은 경우)
pred = 1.0 / (1.0 + np.exp(-pred))
if self.logger:
self.logger.log(f"추론 완료: {pred.shape}, 값 범위: [{pred.min():.3f}, {pred.max():.3f}]", level=logging.DEBUG)
return pred
def _postprocess(self, mask_pred: np.ndarray, orig_size: Tuple[int, int], aggressiveness: float = 0.5) -> np.ndarray:
"""모델 출력 마스크를 원본 해상도로 보간하고 0..255 uint8로 변환 (허깅페이스 방식)"""
orig_h, orig_w = orig_size
# 모델 출력(H,W)을 원본 크기로 리사이즈 (W,H) - 허깅페이스와 동일하게 bilinear 보간
mask_resized = cv2.resize(mask_pred, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
# 허깅페이스 방식의 min-max 정규화 (임계값 없이)
ma = float(mask_resized.max())
mi = float(mask_resized.min())
denom = (ma - mi) if (ma - mi) != 0 else 1.0
mask_norm = (mask_resized - mi) / denom
# 허깅페이스 방식: 직접 0-255로 스케일링 (임계값 적용 안함)
mask_255 = (mask_norm * 255).astype(np.uint8)
# aggressiveness 파라미터 적용 (필요시에만)
if aggressiveness != 0.5:
# aggressiveness가 0.5가 아닐 때만 약간의 조정 적용
if aggressiveness > 0.5:
# 더 공격적: 밝기 증가
factor = 1.0 + (aggressiveness - 0.5) * 0.4
mask_255 = np.clip(mask_255 * factor, 0, 255).astype(np.uint8)
else:
# 더 부드럽게: 밝기 감소
factor = 0.6 + aggressiveness * 0.8
mask_255 = np.clip(mask_255 * factor, 0, 255).astype(np.uint8)
return mask_255
# ===================================================================================
# 기존 BackgroundRemovalModule 호환 인터페이스
# ===================================================================================
def _refine_mask(
self,
alpha_mask: np.ndarray,
*,
alpha_threshold: int | None = None,
keep_top_k_components: int | None = None,
min_component_area: int | None = None,
morph_kernel: int | None = 3,
morph_open_iters: int = 0,
morph_close_iters: int = 0,
dilate_iters: int = 0,
erode_iters: int = 0,
fill_holes: bool = False,
edge_feather: int = 0,
) -> np.ndarray:
"""알파 마스크를 후처리하여 과도 제거/잔여 잡음을 완화합니다.
매개변수 설명:
- alpha_threshold: 이진화 임계값(0~255). 설정 시 확실한 전경만 남깁니다.
- keep_top_k_components: 가장 큰 연결요소 K개만 유지(사람+상품=2 권장).
- min_component_area: 이 면적 미만의 작은 요소 제거.
- morph_*: 모폴로지 연산으로 노이즈 제거/홀 메움.
- dilate/erode: 경계 확장/축소 미세 조정.
- fill_holes: 내부 구멍 채우기(큰 구멍 포함).
- edge_feather: 가장자리 페더링(>0이면 가우시안 블러, 값은 커널 반경).
"""
try:
mask_uint8 = alpha_mask.astype(np.uint8)
# 1) 이진화 준비
if alpha_threshold is not None:
_, bin_mask = cv2.threshold(mask_uint8, int(alpha_threshold), 255, cv2.THRESH_BINARY)
else:
# 소프트 마스크라도 후단 처리를 위해 이진 마스크 생성
bin_mask = (mask_uint8 > 0).astype(np.uint8) * 255
# 2) 연결요소 기반 정리
if keep_top_k_components is not None or (min_component_area is not None and min_component_area > 1):
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(bin_mask, connectivity=8)
# 0은 배경
component_indices = list(range(1, num_labels))
if min_component_area is not None:
component_indices = [i for i in component_indices if int(stats[i, cv2.CC_STAT_AREA]) >= int(min_component_area)]
# 가장 큰 K개만 남기기
if keep_top_k_components is not None and keep_top_k_components > 0:
component_indices = sorted(
component_indices,
key=lambda i: int(stats[i, cv2.CC_STAT_AREA]),
reverse=True,
)[: int(keep_top_k_components)]
filtered = np.zeros_like(bin_mask)
for i in component_indices:
filtered[labels == i] = 255
bin_mask = filtered
# 3) 모폴로지 연산으로 노이즈 제거/홀 메움
k = 3 if morph_kernel is None or morph_kernel <= 0 else int(morph_kernel)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
if morph_open_iters > 0:
bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_OPEN, kernel, iterations=int(morph_open_iters))
if morph_close_iters > 0:
bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_CLOSE, kernel, iterations=int(morph_close_iters))
if dilate_iters > 0:
bin_mask = cv2.dilate(bin_mask, kernel, iterations=int(dilate_iters))
if erode_iters > 0:
bin_mask = cv2.erode(bin_mask, kernel, iterations=int(erode_iters))
# 4) 내부 큰 구멍 채우기
if fill_holes:
flood = bin_mask.copy()
h, w = flood.shape[:2]
flood_mask = np.zeros((h + 2, w + 2), np.uint8)
cv2.floodFill(flood, flood_mask, (0, 0), 255)
flood_inv = cv2.bitwise_not(flood)
bin_mask = cv2.bitwise_or(bin_mask, flood_inv)
# 5) 가장자리 페더링(부드럽게)
if edge_feather and edge_feather > 0:
r = int(edge_feather)
r = max(1, min(31, r * 2 + 1)) # 홀수 커널 보장, 과도한 값 방지
soft = cv2.GaussianBlur(bin_mask, (r, r), 0)
return soft.astype(np.uint8)
return bin_mask.astype(np.uint8)
except Exception as e:
if self.logger:
self.logger.log(f"마스크 후처리 오류: {e}", level=logging.ERROR, exc_info=True)
return alpha_mask
def is_available(self):
"""배경제거 모듈 사용 가능 여부 반환"""
return self._onnxruntime_available and (self.local_model_path and os.path.exists(self.local_model_path))
def get_init_error(self):
"""초기화 에러 메시지 반환"""
return self._init_error
def get_supported_models(self):
"""지원하는 모델 목록 반환"""
return self.SUPPORTED_MODELS.copy()
def get_default_model(self):
"""기본 모델명 반환"""
return self.default_model
def set_default_model(self, model_name):
"""기본 모델 설정"""
if model_name not in self.SUPPORTED_MODELS:
raise ValueError(f"지원하지 않는 모델명: {model_name}")
self.default_model = model_name
if self.logger:
self.logger.log(f"BriaAI 기본 모델이 '{model_name}'으로 변경됨", level=logging.INFO)
def get_model_description(self, model_name):
"""모델 설명 반환"""
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
def to_white_background(self, img: Image.Image) -> Image.Image:
"""RGBA 이미지를 흰 배경과 합성 (이미 RGB라면 그대로 반환)"""
if img.mode in ("RGBA", "BGRA"):
bg = Image.new("RGB", img.size, (255, 255, 255))
bg.paste(img, mask=img.split()[-1])
return bg
else:
# 이미 RGB이거나 다른 모드라면 RGB로 변환
return img.convert("RGB")
def remove_background(self, image_path, model_name=None, force_cpu=None, **kwargs):
"""
이미지에서 배경을 제거하여 PIL Image 반환
기존 BackgroundRemovalModule.remove_background와 동일한 인터페이스
"""
if not self.is_available():
if self.logger:
self.logger.log(f"BriaAI 모듈 사용 불가: {self._init_error}", level=logging.ERROR)
return None
if not os.path.exists(image_path):
if self.logger:
self.logger.log(f"입력 이미지가 존재하지 않습니다: {image_path}", level=logging.ERROR)
return None
# force_cpu 매개변수 처리
original_providers = self.providers
if force_cpu is True:
self.providers = ['CPUExecutionProvider']
if self.logger:
self.logger.log("⚠️ CPU 모드 강제 실행 (BriaAI)", level=logging.WARNING)
elif force_cpu is False:
# GPU 가속 강제 사용 (DirectML 테스트용)
if self.gpu_manager and self.gpu_manager.can_use_cuda:
if 'DmlExecutionProvider' in original_providers:
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
elif 'CUDAExecutionProvider' in original_providers:
self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
self.providers = ['CPUExecutionProvider']
if self.logger:
self.logger.log(f"🔥 GPU 모드 강제 실행 (BriaAI): {self.providers}", level=logging.WARNING)
# 모델 로드 (지연 로딩)
model_loaded = self._load_model()
# 원래 providers 복원
if force_cpu is not None:
self.providers = original_providers
if not model_loaded:
return None
try:
# 이미지 로드
img = cv2.imread(image_path)
if img is None:
if self.logger:
self.logger.log(f"이미지 로드 실패: {image_path}", level=logging.ERROR)
return None
# 모델명 결정 및 aggressiveness 설정
effective_model_name = model_name or self.default_model
if effective_model_name not in self.SUPPORTED_MODELS:
if self.logger:
self.logger.log(f"지원하지 않는 모델명: {effective_model_name}. bria-rmbg-1.4로 대체 사용", level=logging.WARNING)
effective_model_name = "bria-rmbg-1.4"
aggressiveness = self.MODEL_AGGRESSIVENESS.get(effective_model_name, 0.5)
custom_aggressiveness = kwargs.get("aggressiveness", aggressiveness)
if self.logger:
self.logger.log(f"BriaAI 배경제거 시작: {effective_model_name} (aggressiveness={custom_aggressiveness})", level=logging.DEBUG)
start_time = time.time()
# 전처리
input_tensor, (orig_h, orig_w) = self._preprocess(img)
# 추론
mask_pred = self._infer(input_tensor)
# 후처리
alpha_mask = self._postprocess(mask_pred, (orig_h, orig_w), aggressiveness=custom_aggressiveness)
# 선택적 마스크 보정
if any(k in kwargs for k in (
'alpha_threshold', 'keep_top_k_components', 'min_component_area',
'morph_kernel', 'morph_open_iters', 'morph_close_iters',
'dilate_iters', 'erode_iters', 'fill_holes', 'edge_feather'
)):
alpha_mask = self._refine_mask(
alpha_mask,
alpha_threshold=kwargs.get('alpha_threshold'),
keep_top_k_components=kwargs.get('keep_top_k_components'),
min_component_area=kwargs.get('min_component_area'),
morph_kernel=kwargs.get('morph_kernel', 3),
morph_open_iters=kwargs.get('morph_open_iters', 0),
morph_close_iters=kwargs.get('morph_close_iters', 0),
dilate_iters=kwargs.get('dilate_iters', 0),
erode_iters=kwargs.get('erode_iters', 0),
fill_holes=kwargs.get('fill_holes', False),
edge_feather=kwargs.get('edge_feather', 0),
)
# DirectML 알파 채널 이슈 회피: 바로 흰 배경 합성된 BGR로 처리
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# 알파 마스크를 이용해 흰 배경 합성 (DirectML 알파 채널 이슈 회피)
alpha_normalized = alpha_mask.astype(np.float32) / 255.0
alpha_3d = np.stack([alpha_normalized] * 3, axis=-1)
# 흰 배경과 합성
white_bg = np.full_like(img_rgb, 255, dtype=np.uint8)
blended_rgb = (
img_rgb.astype(np.float32) * alpha_3d +
white_bg.astype(np.float32) * (1.0 - alpha_3d)
).astype(np.uint8)
# PIL Image로 변환 (RGB 모드로 - 알파 채널 없음)
result = Image.fromarray(blended_rgb, 'RGB')
end_time = time.time()
processing_time = end_time - start_time
if self.logger:
provider_status = "GPU" if any('CUDA' in p or 'Dml' in p for p in self._session.get_providers()) else "CPU"
self.logger.log(f"✅ BriaAI 배경제거 성공: {effective_model_name} ({provider_status}, {processing_time:.2f}초)", level=logging.INFO)
# 마스크 통계 로깅
mask_stats = {
'min': int(alpha_mask.min()),
'max': int(alpha_mask.max()),
'mean': float(alpha_mask.mean()),
'nonzero_count': int(np.count_nonzero(alpha_mask))
}
self.logger.log(f"BriaAI 마스크 통계: {mask_stats}", level=logging.DEBUG)
# GPU 메모리 사용량 로깅
if self.gpu_manager and hasattr(self.gpu_manager, 'log_gpu_memory_usage'):
self.gpu_manager.log_gpu_memory_usage()
return result
except Exception as e:
if self.logger:
self.logger.log(f"BriaAI 배경제거 처리 중 오류: {e}", level=logging.ERROR, exc_info=True)
return None
def remove_background_array(self, image_bgr: np.ndarray, model_name=None, force_cpu=None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
"""
배경제거 결과를 numpy array로 직접 반환 (추가 편의 메서드)
Returns: (result_bgr, alpha_mask)
"""
# force_cpu 매개변수 처리
original_providers = self.providers
if force_cpu is True:
self.providers = ['CPUExecutionProvider']
if self.logger:
self.logger.log("⚠️ CPU 모드 강제 실행 (BriaAI array)", level=logging.WARNING)
elif force_cpu is False:
# GPU 가속 강제 사용
if self.gpu_manager and self.gpu_manager.can_use_cuda:
if 'DmlExecutionProvider' in original_providers:
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
elif 'CUDAExecutionProvider' in original_providers:
self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
if self.logger:
self.logger.log(f"🔥 GPU 모드 강제 실행 (BriaAI array): {self.providers}", level=logging.WARNING)
# 모델 로드
model_loaded = self.is_available() and self._load_model()
# 원래 providers 복원
if force_cpu is not None:
self.providers = original_providers
if not model_loaded:
return image_bgr, None
try:
effective_model_name = model_name or self.default_model
aggressiveness = self.MODEL_AGGRESSIVENESS.get(effective_model_name, 0.5)
custom_aggressiveness = kwargs.get("aggressiveness", aggressiveness)
# 전처리 -> 추론 -> 후처리
input_tensor, (orig_h, orig_w) = self._preprocess(image_bgr)
mask_pred = self._infer(input_tensor)
alpha_mask = self._postprocess(mask_pred, (orig_h, orig_w), aggressiveness=custom_aggressiveness)
# 선택적 마스크 보정
if any(k in kwargs for k in (
'alpha_threshold', 'keep_top_k_components', 'min_component_area',
'morph_kernel', 'morph_open_iters', 'morph_close_iters',
'dilate_iters', 'erode_iters', 'fill_holes', 'edge_feather'
)):
alpha_mask = self._refine_mask(
alpha_mask,
alpha_threshold=kwargs.get('alpha_threshold'),
keep_top_k_components=kwargs.get('keep_top_k_components'),
min_component_area=kwargs.get('min_component_area'),
morph_kernel=kwargs.get('morph_kernel', 3),
morph_open_iters=kwargs.get('morph_open_iters', 0),
morph_close_iters=kwargs.get('morph_close_iters', 0),
dilate_iters=kwargs.get('dilate_iters', 0),
erode_iters=kwargs.get('erode_iters', 0),
fill_holes=kwargs.get('fill_holes', False),
edge_feather=kwargs.get('edge_feather', 0),
)
# 흰색 배경 합성
mask_3d = np.stack([alpha_mask] * 3, axis=-1)
result_bgr = (
image_bgr.astype(np.float32) * (mask_3d.astype(np.float32) / 255.0)
+ 255.0 * (1.0 - (mask_3d.astype(np.float32) / 255.0))
).clip(0, 255).astype(np.uint8)
return result_bgr, alpha_mask
except Exception as e:
if self.logger:
self.logger.log(f"BriaAI 배경제거 array 처리 중 오류: {e}", level=logging.ERROR, exc_info=True)
return image_bgr, None
def _preload_sessions(self):
"""BriaAI 모델 미리 로딩"""
if self.logger:
self.logger.log("🔄 BriaAI 모델 미리 로딩 중...", level=logging.INFO)
if self._load_model():
if self.logger:
self.logger.log("✅ BriaAI 모델 미리 로딩 완료", level=logging.INFO)
else:
if self.logger:
self.logger.log("⚠️ BriaAI 모델 로딩 실패", level=logging.WARNING)