""" BriaAI RMBG 1.4 ONNXRuntime 기반 배경제거 프로세서 rembg를 사용하지 않고 ONNX 모델을 직접 로드하여 추론합니다. """ import logging import time from typing import Optional, Tuple import numpy as np import cv2 from PIL import Image try: import onnxruntime as ort except Exception as e: # pragma: no cover ort = None # 런타임에서 에러 메시지로 안내 from ..core.config import settings logger = logging.getLogger(__name__) class BriaRMBGOnnxProcessor: """BriaAI RMBG 1.4 ONNX 모델을 사용하는 배경 제거 프로세서""" def __init__(self, *args, **kwargs): self._session: Optional["ort.InferenceSession"] = None self._input_name: Optional[str] = None self._output_name: Optional[str] = None self._model_input_size: Tuple[int, int] = (1024, 1024) # (W, H) logger.info("BriaRMBGOnnxProcessor 초기화 완료") async def load_model(self) -> bool: """ONNX 세션을 로드합니다.""" if ort is None: logger.error("onnxruntime 모듈을 불러오지 못했습니다. onnxruntime 패키지를 설치하세요.") return False model_path = settings.REMBG_MODEL_PATH logger.info(f"Bria RMBG ONNX 세션 생성 중... path={model_path}") try: sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL # Jetson에서는 TensorRT EP 호환 이슈가 있을 수 있어 CUDA, CPU 우선 사용 if settings.USE_CUDA: providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] self._session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers) inputs = self._session.get_inputs() outputs = self._session.get_outputs() if not inputs or not outputs: raise RuntimeError("ONNX 모델의 입출력 정의를 찾을 수 없습니다.") self._input_name = inputs[0].name self._output_name = outputs[0].name logger.info( f"Bria RMBG ONNX 세션 생성 완료, Providers: {self._session.get_providers()} | " f"Input: {self._input_name}, Output: {self._output_name}" ) return True except Exception as e: logger.error(f"Bria RMBG ONNX 세션 생성 실패: {e}") return False def _preprocess(self, image_bgr: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]: """BGR uint8 이미지를 모델 입력(NCHW float32, 정규화)로 변환""" orig_h, orig_w = image_bgr.shape[:2] # BGR -> RGB image_rgb = image_bgr[:, :, ::-1] # 리사이즈 (W,H) target_w, target_h = self._model_input_size resized = cv2.resize(image_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR) # float32 [0,1] -> normalize(mean=0.5, std=1.0) == x - 0.5 tensor = resized.astype(np.float32) / 255.0 tensor = tensor - 0.5 # HWC -> CHW, 배치 축 추가 nchw = np.transpose(tensor, (2, 0, 1))[np.newaxis, ...] return nchw, (orig_h, orig_w) def _infer(self, input_tensor: np.ndarray) -> np.ndarray: """ONNX 추론 수행 후 [H,W] 마스크 확률맵 반환(0~1 범위 전처리 전)""" outputs = self._session.run([self._output_name], {self._input_name: input_tensor}) pred = outputs[0] # 예상 출력: [1, 1, H, W] 또는 [1, H, W] pred = np.array(pred) if pred.ndim == 4: pred = pred[0, 0] elif pred.ndim == 3: pred = pred[0] # 이제 pred는 [H, W] return pred def _postprocess(self, mask_pred: np.ndarray, orig_size: Tuple[int, int]) -> np.ndarray: """모델 출력 마스크를 원본 해상도로 보간하고 0..255 uint8로 변환""" orig_h, orig_w = orig_size # 모델 출력(H,W)을 원본 크기로 리사이즈 (W,H) mask_resized = cv2.resize(mask_pred, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR) # min-max 정규화 안전 처리 ma = float(mask_resized.max()) mi = float(mask_resized.min()) denom = (ma - mi) if (ma - mi) != 0 else 1.0 mask_norm = (mask_resized - mi) / denom mask_u8 = (mask_norm * 255.0).clip(0, 255).astype(np.uint8) return mask_u8 async def remove_background(self, image: np.ndarray, model_name: str = None) -> tuple: """이미지에서 배경을 제거하여 (배경제거된 BGR 이미지, 알파 마스크) 반환""" try: start_time = time.time() logger.info(f"배경제거 시작(Bria ONNX): image.shape={image.shape}, model_name={model_name}") if self._session is None: raise RuntimeError("Bria RMBG ONNX 세션이 로드되지 않았습니다.") input_tensor, (orig_h, orig_w) = self._preprocess(image) mask_pred = self._infer(input_tensor) alpha_mask = self._postprocess(mask_pred, (orig_h, orig_w)) # 흰색 배경 합성 (BGR) if image.ndim == 3 and image.shape[2] == 3: mask_3 = np.stack([alpha_mask] * 3, axis=-1) result_bgr = ( image.astype(np.float32) * (mask_3.astype(np.float32) / 255.0) + 255.0 * (1.0 - (mask_3.astype(np.float32) / 255.0)) ).clip(0, 255).astype(np.uint8) else: # 비정상 입력 대비 result_bgr = image duration = time.time() - start_time try: logger.info( f"Bria ONNX mask stats: min={int(alpha_mask.min())}, max={int(alpha_mask.max())}, " f"mean={float(alpha_mask.mean()):.3f}" ) except Exception: pass logger.info(f"'bria-rmbg' processed in {duration:.3f}s") return result_bgr, alpha_mask except Exception as e: logger.error(f"배경 제거 처리 실패(Bria ONNX): {e}", exc_info=True) return image, None