285 lines
11 KiB
Python
285 lines
11 KiB
Python
"""
|
|
MIGAN ONNX 인페인팅 모델 구현 (실제 ONNX 파이프라인 사용)
|
|
"""
|
|
import os
|
|
import time
|
|
import logging
|
|
from typing import Optional, Union
|
|
import cv2
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
from PIL import Image
|
|
|
|
# OpenCV 내부 최적화 off
|
|
cv2.setUseOptimized(False)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _np_uint8_2d(arr, name="mask"):
|
|
if arr is None:
|
|
raise ValueError(f"{name} is None")
|
|
if not isinstance(arr, np.ndarray):
|
|
raise TypeError(f"{name} must be np.ndarray, got {type(arr)}")
|
|
if arr.ndim != 2:
|
|
raise ValueError(f"{name} must be 2D, got shape={arr.shape}")
|
|
if arr.dtype != np.uint8:
|
|
# 안전 변환
|
|
arr = arr.astype(np.uint8, copy=False)
|
|
return arr
|
|
|
|
|
|
class MiganInpainter:
|
|
"""
|
|
MIGAN ONNX 파이프라인 래퍼
|
|
- 입력: image_path(str), mask(gray uint8 HxW) ※ 텍스트영역=255
|
|
- 내부에서 mask를 (이진화→반전)하여 MI-GAN 규칙(255=known, 0=hole)으로 맞춤
|
|
- 출력: BGR uint8(H,W,3)
|
|
"""
|
|
def __init__(self,
|
|
model_path: str = None,
|
|
device: str = "cuda",
|
|
fp16: bool = True,
|
|
use_cuda: bool = False,
|
|
intra_threads: int = 0,
|
|
inter_threads: int = 0):
|
|
self.model_path = model_path
|
|
self.device = device
|
|
self.fp16 = fp16
|
|
self.use_cuda = bool(use_cuda)
|
|
self.intra_threads = int(intra_threads or 0)
|
|
self.inter_threads = int(inter_threads or 0)
|
|
self.loaded = False
|
|
|
|
if not model_path or not os.path.exists(model_path):
|
|
logger.error(f"MIGAN ONNX 파일을 찾을 수 없습니다: {model_path}")
|
|
raise FileNotFoundError(f"MIGAN ONNX 파일이 없습니다: {model_path}")
|
|
|
|
self.session = None
|
|
self._session = None
|
|
self.in_image = None
|
|
self.in_mask = None
|
|
self.out_name = None
|
|
|
|
async def _get_or_create_session(self):
|
|
"""ONNX 런타임 세션을 생성하거나 기존 세션을 반환합니다."""
|
|
if self._session is None:
|
|
try:
|
|
logger.info("MIGAN ONNX 런타임 세션 생성 시도...")
|
|
import onnxruntime as ort
|
|
|
|
so = ort.SessionOptions()
|
|
if self.intra_threads > 0:
|
|
so.intra_op_num_threads = self.intra_threads
|
|
if self.inter_threads > 0:
|
|
so.inter_op_num_threads = self.inter_threads
|
|
|
|
providers = []
|
|
if self.use_cuda:
|
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
|
logger.info(f"MIGAN ONNX providers 설정: {providers}")
|
|
else:
|
|
providers = ['CPUExecutionProvider']
|
|
logger.info("MIGAN ONNX CPU-only mode 로 설정")
|
|
|
|
self._session = ort.InferenceSession(
|
|
self.model_path,
|
|
sess_options=so,
|
|
providers=providers
|
|
)
|
|
logger.info(f"MIGAN ONNX 세션 생성 완료. Providers: {self._session.get_providers()}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"MIGAN ONNX 세션 초기화 실패: {e}", exc_info=True)
|
|
if 'ort' in locals():
|
|
logger.error(f"사용 가능한 providers: {ort.get_available_providers()}")
|
|
raise RuntimeError(f"MIGAN 모델 초기화 실패: {e}")
|
|
|
|
return self._session
|
|
|
|
async def load_model(self):
|
|
"""모델을 비동기적으로 로드합니다."""
|
|
if self.loaded:
|
|
return
|
|
|
|
try:
|
|
logger.info("Loading MIGAN ONNX model...")
|
|
|
|
self.session = await self._get_or_create_session()
|
|
ins = self.session.get_inputs()
|
|
outs = self.session.get_outputs()
|
|
self.in_image = ins[0].name
|
|
self.in_mask = ins[1].name
|
|
self.out_name = outs[0].name
|
|
|
|
for i, inp in enumerate(ins):
|
|
logger.debug(f"MIGAN 입력 {i}: {inp.name}, 형태: {inp.shape}, 타입: {inp.type}")
|
|
for i, out in enumerate(outs):
|
|
logger.debug(f"MIGAN 출력 {i}: {out.name}, 형태: {out.shape}, 타입: {out.type}")
|
|
|
|
logger.debug(f"MIGAN 세션 준비 완료. providers={self.session.get_providers()}")
|
|
|
|
self.loaded = True
|
|
logger.info("MIGAN ONNX model loaded successfully")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load MIGAN model: {e}", exc_info=True)
|
|
raise
|
|
|
|
async def inpaint(self, image: Union[str, Image.Image, np.ndarray],
|
|
mask: Union[Image.Image, np.ndarray]) -> np.ndarray:
|
|
"""
|
|
인페인팅을 수행합니다.
|
|
|
|
Args:
|
|
image: 원본 이미지 (파일 경로, PIL Image, 또는 numpy array)
|
|
mask: 마스크 (PIL Image 또는 numpy array, 텍스트영역=255)
|
|
|
|
Returns:
|
|
인페인팅된 이미지 (BGR numpy array)
|
|
"""
|
|
if not self.loaded:
|
|
await self.load_model()
|
|
|
|
try:
|
|
# 1) 입력 이미지 로드
|
|
if isinstance(image, str):
|
|
bgr = cv2.imread(image, cv2.IMREAD_COLOR)
|
|
if bgr is None:
|
|
logger.error(f"MIGAN 이미지 로드 실패: {image}")
|
|
return None
|
|
elif isinstance(image, Image.Image):
|
|
bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
|
elif isinstance(image, np.ndarray):
|
|
if image.shape[2] == 3:
|
|
bgr = image.copy()
|
|
else:
|
|
logger.error(f"MIGAN 지원하지 않는 이미지 형태: {image.shape}")
|
|
return None
|
|
else:
|
|
logger.error(f"MIGAN 지원하지 않는 이미지 타입: {type(image)}")
|
|
return None
|
|
|
|
H, W = bgr.shape[:2]
|
|
|
|
# 2) 마스크 정규화: (이진화 → 반전) 해서 255=known, 0=hole 맞추기
|
|
if isinstance(mask, Image.Image):
|
|
mask_array = np.array(mask)
|
|
else:
|
|
mask_array = mask
|
|
|
|
mask_normalized = _np_uint8_2d(mask_array, name="mask")
|
|
if mask_normalized.shape != (H, W):
|
|
logger.error(f"MIGAN 마스크 크기 불일치: mask={mask_normalized.shape}, img={(H,W)}")
|
|
return None
|
|
|
|
# 이진화: 128 threshold 기준
|
|
_, mask_bin = cv2.threshold(mask_normalized, 128, 255, cv2.THRESH_BINARY)
|
|
# 마스크 반전: 텍스트영역 255 -> 0 (hole), 배경 0 -> 255 (known)
|
|
mask_known255 = 255 - mask_bin
|
|
|
|
# 3) RGB 변환 (파이프라인 입력은 RGB uint8)
|
|
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
|
|
|
# 4) ONNX 추론 - 배치 차원 추가 및 차원 순서 변경
|
|
start = time.time()
|
|
# ONNX 모델 입력 형태:
|
|
# - image: (1, 3, H, W) - 배치, 채널, 높이, 너비 순서
|
|
# - mask: (1, 1, H, W) - 배치, 채널(1), 높이, 너비 순서
|
|
|
|
# 이미지: (H, W, 3) -> (1, 3, H, W)
|
|
rgb_batch = np.expand_dims(rgb, 0).transpose(0, 3, 1, 2)
|
|
|
|
# 마스크: (H, W) -> (1, 1, H, W)
|
|
mask_batch = np.expand_dims(mask_known255, (0, 1))
|
|
|
|
logger.debug(f"MIGAN 입력 형태 - 이미지: {rgb_batch.shape}, 마스크: {mask_batch.shape}")
|
|
|
|
out = self.session.run(
|
|
[self.out_name],
|
|
{self.in_image: rgb_batch, self.in_mask: mask_batch}
|
|
)[0] # expect RGB uint8(1,3,H,W)
|
|
|
|
# 출력 차원 처리: (1,3,H,W) -> (H,W,3)
|
|
if out.ndim == 4 and out.shape[0] == 1:
|
|
out = out[0].transpose(1, 2, 0) # (1,3,H,W) -> (3,H,W) -> (H,W,3)
|
|
elif out.ndim == 3 and out.shape[0] == 3: # (3,H,W) -> (H,W,3)
|
|
out = out.transpose(1, 2, 0)
|
|
|
|
logger.debug(f"MIGAN 출력 형태: {out.shape}, dtype: {out.dtype}")
|
|
|
|
if not isinstance(out, np.ndarray) or out.ndim != 3 or out.dtype != np.uint8:
|
|
logger.error(f"MIGAN ONNX 출력 형식 오류: type={type(out)}, shape={getattr(out,'shape',None)}, dtype={getattr(out,'dtype',None)}")
|
|
return None
|
|
|
|
elapsed = (time.time() - start) * 1000.0
|
|
logger.debug(f"MIGAN 추론 완료: {elapsed:.2f} ms")
|
|
|
|
# 5) BGR로 되돌려 반환
|
|
bgr_out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
|
|
return bgr_out
|
|
|
|
except Exception as e:
|
|
error_msg = str(e).lower()
|
|
if "invalid rank" in error_msg or "invalid argument" in error_msg:
|
|
logger.error(f"MIGAN ONNX 입력 차원 오류: {e}")
|
|
logger.error(f"MIGAN 입력 이미지 형태: {rgb_batch.shape if 'rgb_batch' in locals() else 'N/A'}")
|
|
logger.error(f"MIGAN 입력 마스크 형태: {mask_batch.shape if 'mask_batch' in locals() else 'N/A'}")
|
|
else:
|
|
logger.error(f"MIGAN inpaint 예외: {e}", exc_info=True)
|
|
return None
|
|
|
|
def get_model_info(self) -> dict:
|
|
"""모델 정보를 반환합니다."""
|
|
return {
|
|
"model_type": "migan",
|
|
"model_path": self.model_path,
|
|
"device": self.device,
|
|
"fp16": self.fp16,
|
|
"use_cuda": self.use_cuda,
|
|
"loaded": self.loaded,
|
|
"providers": self.session.get_providers() if self.session else None
|
|
}
|
|
|
|
|
|
# 편의 함수: 설정으로부터 MIGAN 인스턴스 생성
|
|
def build_migan_from_config(config: dict, logger: Optional[object] = None, gpu_manager: Optional[object] = None) -> MiganInpainter:
|
|
"""
|
|
설정으로부터 MiganInpainter 인스턴스를 생성.
|
|
필수 키:
|
|
- migan_onnx_path
|
|
선택 키:
|
|
- migan_use_cuda (bool)
|
|
- migan_intra_threads (int)
|
|
- migan_inter_threads (int)
|
|
"""
|
|
onnx_path = config.get("migan_onnx_path", "")
|
|
if not onnx_path:
|
|
raise ValueError("config['migan_onnx_path'] 가 필요합니다.")
|
|
use_cuda = bool(config.get("migan_use_cuda", False))
|
|
intra = int(config.get("migan_intra_threads", 0) or 0)
|
|
inter = int(config.get("migan_inter_threads", 0) or 0)
|
|
|
|
inpainter = MiganInpainter(
|
|
model_path=onnx_path,
|
|
device="cuda" if use_cuda else "cpu",
|
|
fp16=False, # ONNX에서는 fp16 사용하지 않음
|
|
use_cuda=use_cuda,
|
|
intra_threads=intra,
|
|
inter_threads=inter,
|
|
)
|
|
|
|
# GPU 관리자를 인페인터 객체에 연결
|
|
if gpu_manager:
|
|
inpainter.gpu_manager = gpu_manager
|
|
if logger:
|
|
logger.log(f"MIGAN GPU 관리자 연결 완료: {type(gpu_manager).__name__}", level=logging.DEBUG)
|
|
else:
|
|
if logger:
|
|
logger.log(f"MIGAN GPU 관리자 없음: gpu_manager={gpu_manager}", level=logging.DEBUG)
|
|
|
|
# 디버깅: gpu_manager 속성 확인
|
|
if logger:
|
|
logger.log(f"MIGAN 인페인터 gpu_manager 속성: {hasattr(inpainter, 'gpu_manager')}, 값: {getattr(inpainter, 'gpu_manager', None)}", level=logging.DEBUG)
|
|
|
|
return inpainter |