inpaintServer/app/models/bria_rmbg_onnx.py

159 lines
6.2 KiB
Python

"""
BriaAI RMBG 1.4 ONNXRuntime 기반 배경제거 프로세서
rembg를 사용하지 않고 ONNX 모델을 직접 로드하여 추론합니다.
"""
import logging
import time
from typing import Optional, Tuple
import numpy as np
import cv2
from PIL import Image
try:
import onnxruntime as ort
except Exception as e: # pragma: no cover
ort = None # 런타임에서 에러 메시지로 안내
from ..core.config import settings
logger = logging.getLogger(__name__)
class BriaRMBGOnnxProcessor:
"""BriaAI RMBG 1.4 ONNX 모델을 사용하는 배경 제거 프로세서"""
def __init__(self, *args, **kwargs):
self._session: Optional["ort.InferenceSession"] = None
self._input_name: Optional[str] = None
self._output_name: Optional[str] = None
self._model_input_size: Tuple[int, int] = (1024, 1024) # (W, H)
logger.info("BriaRMBGOnnxProcessor 초기화 완료")
async def load_model(self) -> bool:
"""ONNX 세션을 로드합니다."""
if ort is None:
logger.error("onnxruntime 모듈을 불러오지 못했습니다. onnxruntime 패키지를 설치하세요.")
return False
model_path = settings.REMBG_MODEL_PATH
logger.info(f"Bria RMBG ONNX 세션 생성 중... path={model_path}")
try:
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Jetson에서는 TensorRT EP 호환 이슈가 있을 수 있어 CUDA, CPU 우선 사용
if settings.USE_CUDA:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
self._session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
inputs = self._session.get_inputs()
outputs = self._session.get_outputs()
if not inputs or not outputs:
raise RuntimeError("ONNX 모델의 입출력 정의를 찾을 수 없습니다.")
self._input_name = inputs[0].name
self._output_name = outputs[0].name
logger.info(
f"Bria RMBG ONNX 세션 생성 완료, Providers: {self._session.get_providers()} | "
f"Input: {self._input_name}, Output: {self._output_name}"
)
return True
except Exception as e:
logger.error(f"Bria RMBG ONNX 세션 생성 실패: {e}")
return False
def _preprocess(self, image_bgr: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
"""BGR uint8 이미지를 모델 입력(NCHW float32, 정규화)로 변환"""
orig_h, orig_w = image_bgr.shape[:2]
# BGR -> RGB
image_rgb = image_bgr[:, :, ::-1]
# 리사이즈 (W,H)
target_w, target_h = self._model_input_size
resized = cv2.resize(image_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
# float32 [0,1] -> normalize(mean=0.5, std=1.0) == x - 0.5
tensor = resized.astype(np.float32) / 255.0
tensor = tensor - 0.5
# HWC -> CHW, 배치 축 추가
nchw = np.transpose(tensor, (2, 0, 1))[np.newaxis, ...]
return nchw, (orig_h, orig_w)
def _infer(self, input_tensor: np.ndarray) -> np.ndarray:
"""ONNX 추론 수행 후 [H,W] 마스크 확률맵 반환(0~1 범위 전처리 전)"""
outputs = self._session.run([self._output_name], {self._input_name: input_tensor})
pred = outputs[0]
# 예상 출력: [1, 1, H, W] 또는 [1, H, W]
pred = np.array(pred)
if pred.ndim == 4:
pred = pred[0, 0]
elif pred.ndim == 3:
pred = pred[0]
# 이제 pred는 [H, W]
return pred
def _postprocess(self, mask_pred: np.ndarray, orig_size: Tuple[int, int]) -> np.ndarray:
"""모델 출력 마스크를 원본 해상도로 보간하고 0..255 uint8로 변환"""
orig_h, orig_w = orig_size
# 모델 출력(H,W)을 원본 크기로 리사이즈 (W,H)
mask_resized = cv2.resize(mask_pred, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
# min-max 정규화 안전 처리
ma = float(mask_resized.max())
mi = float(mask_resized.min())
denom = (ma - mi) if (ma - mi) != 0 else 1.0
mask_norm = (mask_resized - mi) / denom
mask_u8 = (mask_norm * 255.0).clip(0, 255).astype(np.uint8)
return mask_u8
async def remove_background(self, image: np.ndarray, model_name: str = None) -> tuple:
"""이미지에서 배경을 제거하여 (배경제거된 BGR 이미지, 알파 마스크) 반환"""
try:
start_time = time.time()
logger.info(f"배경제거 시작(Bria ONNX): image.shape={image.shape}, model_name={model_name}")
if self._session is None:
raise RuntimeError("Bria RMBG ONNX 세션이 로드되지 않았습니다.")
input_tensor, (orig_h, orig_w) = self._preprocess(image)
mask_pred = self._infer(input_tensor)
alpha_mask = self._postprocess(mask_pred, (orig_h, orig_w))
# 흰색 배경 합성 (BGR)
if image.ndim == 3 and image.shape[2] == 3:
mask_3 = np.stack([alpha_mask] * 3, axis=-1)
result_bgr = (
image.astype(np.float32) * (mask_3.astype(np.float32) / 255.0)
+ 255.0 * (1.0 - (mask_3.astype(np.float32) / 255.0))
).clip(0, 255).astype(np.uint8)
else:
# 비정상 입력 대비
result_bgr = image
duration = time.time() - start_time
try:
logger.info(
f"Bria ONNX mask stats: min={int(alpha_mask.min())}, max={int(alpha_mask.max())}, "
f"mean={float(alpha_mask.mean()):.3f}"
)
except Exception:
pass
logger.info(f"'bria-rmbg' processed in {duration:.3f}s")
return result_bgr, alpha_mask
except Exception as e:
logger.error(f"배경 제거 처리 실패(Bria ONNX): {e}", exc_info=True)
return image, None