159 lines
6.2 KiB
Python
159 lines
6.2 KiB
Python
"""
|
|
BriaAI RMBG 1.4 ONNXRuntime 기반 배경제거 프로세서
|
|
rembg를 사용하지 않고 ONNX 모델을 직접 로드하여 추론합니다.
|
|
"""
|
|
import logging
|
|
import time
|
|
from typing import Optional, Tuple
|
|
|
|
import numpy as np
|
|
import cv2
|
|
from PIL import Image
|
|
|
|
try:
|
|
import onnxruntime as ort
|
|
except Exception as e: # pragma: no cover
|
|
ort = None # 런타임에서 에러 메시지로 안내
|
|
|
|
from ..core.config import settings
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BriaRMBGOnnxProcessor:
|
|
"""BriaAI RMBG 1.4 ONNX 모델을 사용하는 배경 제거 프로세서"""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
self._session: Optional["ort.InferenceSession"] = None
|
|
self._input_name: Optional[str] = None
|
|
self._output_name: Optional[str] = None
|
|
self._model_input_size: Tuple[int, int] = (1024, 1024) # (W, H)
|
|
logger.info("BriaRMBGOnnxProcessor 초기화 완료")
|
|
|
|
async def load_model(self) -> bool:
|
|
"""ONNX 세션을 로드합니다."""
|
|
if ort is None:
|
|
logger.error("onnxruntime 모듈을 불러오지 못했습니다. onnxruntime 패키지를 설치하세요.")
|
|
return False
|
|
|
|
model_path = settings.REMBG_MODEL_PATH
|
|
logger.info(f"Bria RMBG ONNX 세션 생성 중... path={model_path}")
|
|
|
|
try:
|
|
sess_options = ort.SessionOptions()
|
|
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
|
|
# Jetson에서는 TensorRT EP 호환 이슈가 있을 수 있어 CUDA, CPU 우선 사용
|
|
if settings.USE_CUDA:
|
|
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
|
else:
|
|
providers = ["CPUExecutionProvider"]
|
|
|
|
self._session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
|
|
|
|
inputs = self._session.get_inputs()
|
|
outputs = self._session.get_outputs()
|
|
if not inputs or not outputs:
|
|
raise RuntimeError("ONNX 모델의 입출력 정의를 찾을 수 없습니다.")
|
|
|
|
self._input_name = inputs[0].name
|
|
self._output_name = outputs[0].name
|
|
|
|
logger.info(
|
|
f"Bria RMBG ONNX 세션 생성 완료, Providers: {self._session.get_providers()} | "
|
|
f"Input: {self._input_name}, Output: {self._output_name}"
|
|
)
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Bria RMBG ONNX 세션 생성 실패: {e}")
|
|
return False
|
|
|
|
def _preprocess(self, image_bgr: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
|
|
"""BGR uint8 이미지를 모델 입력(NCHW float32, 정규화)로 변환"""
|
|
orig_h, orig_w = image_bgr.shape[:2]
|
|
# BGR -> RGB
|
|
image_rgb = image_bgr[:, :, ::-1]
|
|
|
|
# 리사이즈 (W,H)
|
|
target_w, target_h = self._model_input_size
|
|
resized = cv2.resize(image_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
|
|
|
|
# float32 [0,1] -> normalize(mean=0.5, std=1.0) == x - 0.5
|
|
tensor = resized.astype(np.float32) / 255.0
|
|
tensor = tensor - 0.5
|
|
|
|
# HWC -> CHW, 배치 축 추가
|
|
nchw = np.transpose(tensor, (2, 0, 1))[np.newaxis, ...]
|
|
return nchw, (orig_h, orig_w)
|
|
|
|
def _infer(self, input_tensor: np.ndarray) -> np.ndarray:
|
|
"""ONNX 추론 수행 후 [H,W] 마스크 확률맵 반환(0~1 범위 전처리 전)"""
|
|
outputs = self._session.run([self._output_name], {self._input_name: input_tensor})
|
|
pred = outputs[0]
|
|
# 예상 출력: [1, 1, H, W] 또는 [1, H, W]
|
|
pred = np.array(pred)
|
|
if pred.ndim == 4:
|
|
pred = pred[0, 0]
|
|
elif pred.ndim == 3:
|
|
pred = pred[0]
|
|
# 이제 pred는 [H, W]
|
|
return pred
|
|
|
|
def _postprocess(self, mask_pred: np.ndarray, orig_size: Tuple[int, int]) -> np.ndarray:
|
|
"""모델 출력 마스크를 원본 해상도로 보간하고 0..255 uint8로 변환"""
|
|
orig_h, orig_w = orig_size
|
|
# 모델 출력(H,W)을 원본 크기로 리사이즈 (W,H)
|
|
mask_resized = cv2.resize(mask_pred, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
|
|
|
|
# min-max 정규화 안전 처리
|
|
ma = float(mask_resized.max())
|
|
mi = float(mask_resized.min())
|
|
denom = (ma - mi) if (ma - mi) != 0 else 1.0
|
|
mask_norm = (mask_resized - mi) / denom
|
|
|
|
mask_u8 = (mask_norm * 255.0).clip(0, 255).astype(np.uint8)
|
|
return mask_u8
|
|
|
|
async def remove_background(self, image: np.ndarray, model_name: str = None) -> tuple:
|
|
"""이미지에서 배경을 제거하여 (배경제거된 BGR 이미지, 알파 마스크) 반환"""
|
|
try:
|
|
start_time = time.time()
|
|
logger.info(f"배경제거 시작(Bria ONNX): image.shape={image.shape}, model_name={model_name}")
|
|
|
|
if self._session is None:
|
|
raise RuntimeError("Bria RMBG ONNX 세션이 로드되지 않았습니다.")
|
|
|
|
input_tensor, (orig_h, orig_w) = self._preprocess(image)
|
|
mask_pred = self._infer(input_tensor)
|
|
alpha_mask = self._postprocess(mask_pred, (orig_h, orig_w))
|
|
|
|
# 흰색 배경 합성 (BGR)
|
|
if image.ndim == 3 and image.shape[2] == 3:
|
|
mask_3 = np.stack([alpha_mask] * 3, axis=-1)
|
|
result_bgr = (
|
|
image.astype(np.float32) * (mask_3.astype(np.float32) / 255.0)
|
|
+ 255.0 * (1.0 - (mask_3.astype(np.float32) / 255.0))
|
|
).clip(0, 255).astype(np.uint8)
|
|
else:
|
|
# 비정상 입력 대비
|
|
result_bgr = image
|
|
|
|
duration = time.time() - start_time
|
|
try:
|
|
logger.info(
|
|
f"Bria ONNX mask stats: min={int(alpha_mask.min())}, max={int(alpha_mask.max())}, "
|
|
f"mean={float(alpha_mask.mean()):.3f}"
|
|
)
|
|
except Exception:
|
|
pass
|
|
logger.info(f"'bria-rmbg' processed in {duration:.3f}s")
|
|
|
|
return result_bgr, alpha_mask
|
|
|
|
except Exception as e:
|
|
logger.error(f"배경 제거 처리 실패(Bria ONNX): {e}", exc_info=True)
|
|
return image, None
|
|
|
|
|