384 lines
18 KiB
Python
384 lines
18 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
src/modules/migan_module.py
|
|
|
|
- MI-GAN ONNX 파이프라인(전/후처리 포함)
|
|
- 핵심 호환 포인트:
|
|
1) 입력: image_path(str), mask(np.ndarray, 0~255 그레이, 텍스트영역=255)
|
|
2) 마스크는 내부에서 자동으로 (이진화 -> 반전) 하여 MI-GAN 규칙(255=known, 0=hole)에 맞춤
|
|
3) CUDA/CPU 자동 선택, 실패 시 CPU 폴백
|
|
|
|
의존성: onnxruntime, opencv-python, numpy
|
|
설정: toggle_states에 다음 키 사용(선택):
|
|
- "migan_onnx_path": 파이프라인 ONNX 경로(필수)
|
|
- "migan_use_cuda": True/False (기본 False)
|
|
- "migan_intra_threads": int (기본 0 = onnxruntime 기본)
|
|
- "migan_inter_threads": int (기본 0)
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import time
|
|
import logging
|
|
from typing import Optional, Dict, Any
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
|
|
# OpenCV 내부 최적화 off
|
|
cv2.setUseOptimized(False)
|
|
|
|
|
|
def _np_uint8_2d(arr, name="mask"):
|
|
if arr is None:
|
|
raise ValueError(f"{name} is None")
|
|
if not isinstance(arr, np.ndarray):
|
|
raise TypeError(f"{name} must be np.ndarray, got {type(arr)}")
|
|
if arr.ndim != 2:
|
|
raise ValueError(f"{name} must be 2D, got shape={arr.shape}")
|
|
if arr.dtype != np.uint8:
|
|
# 안전 변환
|
|
arr = arr.astype(np.uint8, copy=False)
|
|
return arr
|
|
|
|
|
|
def _ensure_logger(logger: Optional[object]) -> logging.Logger:
|
|
"""네 logger( .log(msg, level=) )가 없을 때를 대비한 기본 로거"""
|
|
if logger and hasattr(logger, "log"):
|
|
return logger
|
|
pylogger = logging.getLogger("MIGAN")
|
|
if not pylogger.handlers:
|
|
pylogger.setLevel(logging.DEBUG)
|
|
h = logging.StreamHandler(stream=sys.stdout)
|
|
h.setFormatter(logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s"))
|
|
pylogger.addHandler(h)
|
|
# .log 호환 어댑터
|
|
class _Adapter:
|
|
def __init__(self, _lg): self._lg = _lg
|
|
def log(self, msg, level=logging.DEBUG, **kwargs):
|
|
self._lg.log(level, msg)
|
|
return _Adapter(pylogger)
|
|
|
|
|
|
class MIGANPipelineONNXCompat:
|
|
"""
|
|
MI-GAN ONNX 파이프라인 래퍼
|
|
- 입력: image_path(str), mask(gray uint8 HxW) ※ 텍스트영역=255
|
|
- 내부에서 mask를 (이진화→반전)하여 MI-GAN 규칙(255=known, 0=hole)으로 맞춤
|
|
- 출력: BGR uint8(H,W,3)
|
|
"""
|
|
|
|
_SESSION_CACHE = {} # onnx_path -> InferenceSession 캐시
|
|
|
|
def __init__(self,
|
|
onnx_path: str,
|
|
logger: Optional[object] = None,
|
|
use_cuda: bool = False,
|
|
intra_threads: int = 0,
|
|
inter_threads: int = 0,
|
|
toggle_states: Optional[Dict[str, Any]] = None,
|
|
base_dir: Optional[str] = None):
|
|
self.logger = _ensure_logger(logger)
|
|
self.onnx_path = onnx_path
|
|
self.use_cuda = bool(use_cuda)
|
|
self.intra_threads = int(intra_threads or 0)
|
|
self.inter_threads = int(inter_threads or 0)
|
|
self.toggle_states = toggle_states or {}
|
|
self.base_dir = base_dir
|
|
|
|
# 프로바이더 캐시 파일 경로 결정
|
|
try:
|
|
cache_root = None
|
|
if self.base_dir:
|
|
cache_root = os.path.join(self.base_dir, "user_data")
|
|
else:
|
|
cache_root = os.path.dirname(self.onnx_path) or os.getcwd()
|
|
os.makedirs(cache_root, exist_ok=True)
|
|
self._provider_cache_path = os.path.join(cache_root, "migan_provider.json")
|
|
except Exception:
|
|
self._provider_cache_path = None
|
|
|
|
if not os.path.exists(self.onnx_path):
|
|
self.logger.log(f"[MIGAN] ONNX 파일을 찾을 수 없습니다: {self.onnx_path}", level=logging.ERROR)
|
|
raise FileNotFoundError(self.onnx_path)
|
|
|
|
self.session = self._get_or_create_session()
|
|
ins = self.session.get_inputs()
|
|
outs = self.session.get_outputs()
|
|
self.in_image = ins[0].name
|
|
self.in_mask = ins[1].name
|
|
self.out_name = outs[0].name
|
|
|
|
# 입력/출력 형태 정보 로깅 (디버깅용)
|
|
for i, inp in enumerate(ins):
|
|
self.logger.log(f"[MIGAN] 입력 {i}: {inp.name}, 형태: {inp.shape}, 타입: {inp.type}", level=logging.DEBUG)
|
|
for i, out in enumerate(outs):
|
|
self.logger.log(f"[MIGAN] 출력 {i}: {out.name}, 형태: {out.shape}, 타입: {out.type}", level=logging.DEBUG)
|
|
|
|
self.logger.log(f"[MIGAN] 세션 준비 완료. providers={self.session.get_providers()}", level=logging.DEBUG)
|
|
|
|
def _get_or_create_session(self) -> ort.InferenceSession:
|
|
key = (self.onnx_path, self.use_cuda, self.intra_threads, self.inter_threads)
|
|
if key in self._SESSION_CACHE:
|
|
return self._SESSION_CACHE[key]
|
|
|
|
# override 우선 적용: auto|dml|cpu
|
|
try:
|
|
override = (self.toggle_states or {}).get("migan_provider_override", "auto")
|
|
if isinstance(override, str):
|
|
override = override.lower()
|
|
if override == "dml":
|
|
self.use_cuda = True
|
|
elif override == "cpu":
|
|
self.use_cuda = False
|
|
else:
|
|
# auto: 캐시 확인
|
|
cached = self._read_provider_cache()
|
|
if cached == "dml":
|
|
self.use_cuda = True
|
|
elif cached == "cpu":
|
|
self.use_cuda = False
|
|
except Exception:
|
|
pass
|
|
|
|
so = ort.SessionOptions()
|
|
if self.intra_threads > 0:
|
|
so.intra_op_num_threads = self.intra_threads
|
|
if self.inter_threads > 0:
|
|
so.inter_op_num_threads = self.inter_threads
|
|
|
|
# 메모리 절약형 설정 (안전하게 속성 존재 여부 확인)
|
|
if hasattr(so, 'enable_cpu_mem_arena'):
|
|
so.enable_cpu_mem_arena = True # CPU 메모리 최적화 (유지)
|
|
if hasattr(so, 'enable_mem_pattern'):
|
|
so.enable_mem_pattern = False # 🔧 메모리 절약을 위해 비활성화
|
|
if hasattr(so, 'graph_optimization_level'):
|
|
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC # 🔧 기본 최적화만 사용
|
|
|
|
# DirectML 관련 메모리 절약 최적화 (GPU 사용 시, 속성 존재 여부 확인)
|
|
if self.use_cuda or (hasattr(self, 'gpu_manager') and self.gpu_manager and self.gpu_manager.can_use_cuda):
|
|
if hasattr(so, 'enable_cuda_graph'):
|
|
so.enable_cuda_graph = False # 🔧 메모리 절약을 위해 비활성화
|
|
|
|
# GPU 관리자가 있으면 최적 provider 사용, 없으면 기존 방식
|
|
gpu_manager_available = hasattr(self, 'gpu_manager') and self.gpu_manager
|
|
self.logger.log(f"[MIGAN] GPU 관리자 사용 가능: {gpu_manager_available}", level=logging.DEBUG)
|
|
|
|
# DirectML 기반 GPU 가속 (Windows 호환성 최우선)
|
|
providers = []
|
|
if self.use_cuda: # use_cuda를 GPU 가속 플래그로 사용
|
|
try:
|
|
available_providers = ort.get_available_providers()
|
|
self.logger.log(f"[MIGAN] 사용 가능한 providers: {available_providers}", level=logging.DEBUG)
|
|
|
|
if 'DmlExecutionProvider' in available_providers:
|
|
providers = [("DmlExecutionProvider", {}), ("CPUExecutionProvider", {})]
|
|
self.logger.log(f"[MIGAN] DirectML 가속 활성화", level=logging.INFO)
|
|
else:
|
|
providers = [("CPUExecutionProvider", {})]
|
|
self.logger.log(f"[MIGAN] DirectML 미지원, CPU 모드로 전환", level=logging.WARNING)
|
|
except Exception as e:
|
|
self.logger.log(f"[MIGAN] DirectML 설정 실패: {e}", level=logging.WARNING)
|
|
providers = [("CPUExecutionProvider", {})]
|
|
else:
|
|
providers = [("CPUExecutionProvider", {})]
|
|
self.logger.log(f"[MIGAN] CPU 전용 모드", level=logging.DEBUG)
|
|
|
|
self.logger.log(f"[MIGAN] 최종 providers: {providers}", level=logging.DEBUG)
|
|
|
|
# DirectML 기반 단계별 폴백 전략
|
|
sess = None
|
|
if self.use_cuda:
|
|
fallback_attempts = [
|
|
(providers, "DirectML 가속"),
|
|
([("CPUExecutionProvider", {})], "CPU 폴백")
|
|
]
|
|
else:
|
|
fallback_attempts = [
|
|
([("CPUExecutionProvider", {})], "CPU 전용")
|
|
]
|
|
|
|
for attempt_providers, attempt_name in fallback_attempts:
|
|
try:
|
|
self.logger.log(f"[MIGAN] {attempt_name} 시도: {attempt_providers}", level=logging.DEBUG)
|
|
sess = ort.InferenceSession(self.onnx_path, sess_options=so, providers=attempt_providers)
|
|
actual_providers = sess.get_providers()
|
|
self.logger.log(f"[MIGAN] {attempt_name} 성공! 실제 providers: {actual_providers}", level=logging.INFO)
|
|
break
|
|
except Exception as e:
|
|
self.logger.log(f"[MIGAN] {attempt_name} 실패: {e}", level=logging.WARNING)
|
|
if sess:
|
|
sess = None
|
|
continue
|
|
|
|
if sess is None:
|
|
raise RuntimeError("[MIGAN] 모든 폴백 시도 실패")
|
|
|
|
self._SESSION_CACHE[key] = sess
|
|
# 성공 프로바이더 캐시
|
|
try:
|
|
used = sess.get_providers()
|
|
prov = 'dml' if any('Dml' in p for p in used) else 'cpu'
|
|
self._write_provider_cache(prov)
|
|
except Exception:
|
|
pass
|
|
return sess
|
|
|
|
# ───────────────────────────────────────────────────────────────
|
|
# 퍼블릭 API: 파이프라인에서 바로 호출
|
|
# ───────────────────────────────────────────────────────────────
|
|
def inpaint(self, image_path: str, mask_gray_255_text: np.ndarray) -> Optional[np.ndarray]:
|
|
"""
|
|
Args
|
|
image_path: 원본 이미지 경로 (BGR 로딩)
|
|
mask_gray_255_text: 0~255 그레이, '텍스트영역=255' 형태(네 MaskModule 출력)
|
|
(GaussianBlur 포함되어 있을 수 있음)
|
|
Return
|
|
inpainted BGR 이미지 (np.ndarray) or None
|
|
"""
|
|
try:
|
|
# 1) 입력 이미지 로드
|
|
bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
|
|
if bgr is None:
|
|
self.logger.log(f"[MIGAN] 이미지 로드 실패: {image_path}", level=logging.ERROR)
|
|
return None
|
|
|
|
H, W = bgr.shape[:2]
|
|
|
|
# 2) 마스크 정규화: (이진화 → 반전) 해서 255=known, 0=hole 맞추기
|
|
mask = _np_uint8_2d(mask_gray_255_text, name="mask")
|
|
if mask.shape != (H, W):
|
|
self.logger.log(f"[MIGAN] 마스크 크기 불일치: mask={mask.shape}, img={(H,W)}", level=logging.ERROR)
|
|
return None
|
|
|
|
# 이진화: 128 스레시hold 기준
|
|
_, mask_bin = cv2.threshold(mask, 128, 255, cv2.THRESH_BINARY)
|
|
# 너의 마스크는 “텍스트=255”, MI-GAN은 “hole=0”이므로 반전
|
|
# (텍스트영역 255 -> 0), (배경 0 -> 255)
|
|
mask_known255 = 255 - mask_bin
|
|
|
|
# 3) RGB 변환 (파이프라인 입력은 RGB uint8)
|
|
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
|
|
|
# 4) ONNX 추론 - 배치 차원 추가 및 차원 순서 변경
|
|
start = time.time()
|
|
# ONNX 모델 입력 형태:
|
|
# - image: (1, 3, H, W) - 배치, 채널, 높이, 너비 순서
|
|
# - mask: (1, 1, H, W) - 배치, 채널(1), 높이, 너비 순서
|
|
|
|
# 이미지: (H, W, 3) -> (1, 3, H, W)
|
|
rgb_batch = np.expand_dims(rgb, 0).transpose(0, 3, 1, 2)
|
|
|
|
# 마스크: (H, W) -> (1, 1, H, W)
|
|
mask_batch = np.expand_dims(mask_known255, (0, 1))
|
|
|
|
self.logger.log(f"[MIGAN] 입력 형태 - 이미지: {rgb_batch.shape}, 마스크: {mask_batch.shape}", level=logging.DEBUG)
|
|
|
|
out = self.session.run(
|
|
[self.out_name],
|
|
{self.in_image: rgb_batch, self.in_mask: mask_batch}
|
|
)[0] # expect RGB uint8(1,3,H,W)
|
|
|
|
# 출력 차원 처리: (1,3,H,W) -> (H,W,3)
|
|
if out.ndim == 4 and out.shape[0] == 1:
|
|
out = out[0].transpose(1, 2, 0) # (1,3,H,W) -> (3,H,W) -> (H,W,3)
|
|
elif out.ndim == 3 and out.shape[0] == 3: # (3,H,W) -> (H,W,3)
|
|
out = out.transpose(1, 2, 0)
|
|
|
|
self.logger.log(f"[MIGAN] 출력 형태: {out.shape}, dtype: {out.dtype}", level=logging.DEBUG)
|
|
|
|
if not isinstance(out, np.ndarray) or out.ndim != 3 or out.dtype != np.uint8:
|
|
self.logger.log(f"[MIGAN] ONNX 출력 형식 오류: type={type(out)}, shape={getattr(out,'shape',None)}, dtype={getattr(out,'dtype',None)}",
|
|
level=logging.ERROR)
|
|
return None
|
|
|
|
elapsed = (time.time() - start) * 1000.0
|
|
self.logger.log(f"[MIGAN] 추론 완료: {elapsed:.2f} ms", level=logging.DEBUG)
|
|
|
|
# 5) BGR로 되돌려 반환
|
|
bgr_out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
|
|
return bgr_out
|
|
|
|
except Exception as e:
|
|
error_msg = str(e).lower()
|
|
if "invalid rank" in error_msg or "invalid argument" in error_msg:
|
|
self.logger.log(f"[MIGAN] ONNX 입력 차원 오류: {e}", level=logging.ERROR)
|
|
self.logger.log(f"[MIGAN] 입력 이미지 형태: {rgb_batch.shape if 'rgb_batch' in locals() else 'N/A'}", level=logging.ERROR)
|
|
self.logger.log(f"[MIGAN] 입력 마스크 형태: {mask_batch.shape if 'mask_batch' in locals() else 'N/A'}", level=logging.ERROR)
|
|
else:
|
|
self.logger.log(f"[MIGAN] inpaint 예외: {e}", level=logging.ERROR, exc_info=True)
|
|
return None
|
|
|
|
|
|
# ───────────────────────────────────────────────────────────────
|
|
# ImageProcessor3에서 바로 부를 수 있게 하는 편의 함수
|
|
# ───────────────────────────────────────────────────────────────
|
|
def build_migan_from_toggle(toggle_states: dict, logger: Optional[object] = None, gpu_manager: Optional[object] = None) -> MIGANPipelineONNXCompat:
|
|
"""
|
|
toggle_states로부터 설정을 읽어 MIGANPipelineONNXCompat 인스턴스를 생성.
|
|
필수 키:
|
|
- migan_onnx_path
|
|
선택 키:
|
|
- migan_use_cuda (bool)
|
|
- migan_intra_threads (int)
|
|
- migan_inter_threads (int)
|
|
"""
|
|
onnx_path = toggle_states.get("migan_onnx_path", "")
|
|
if not onnx_path:
|
|
raise ValueError("toggle_states['migan_onnx_path'] 가 필요합니다.")
|
|
use_accel = toggle_states.get("migan_use_accel", None)
|
|
if use_accel is None:
|
|
use_accel = toggle_states.get("migan_use_cuda", False) # 호환 키
|
|
use_cuda = bool(use_accel)
|
|
intra = int(toggle_states.get("migan_intra_threads", 0) or 0)
|
|
inter = int(toggle_states.get("migan_inter_threads", 0) or 0)
|
|
pipeline = MIGANPipelineONNXCompat(
|
|
onnx_path=onnx_path,
|
|
logger=logger,
|
|
use_cuda=use_cuda,
|
|
intra_threads=intra,
|
|
inter_threads=inter,
|
|
toggle_states=toggle_states,
|
|
)
|
|
# GPU 관리자를 파이프라인 객체에 연결
|
|
if gpu_manager:
|
|
pipeline.gpu_manager = gpu_manager
|
|
if logger:
|
|
logger.log(f"[MIGAN] GPU 관리자 연결 완료: {type(gpu_manager).__name__}", level=logging.DEBUG)
|
|
else:
|
|
if logger:
|
|
logger.log(f"[MIGAN] GPU 관리자 없음: gpu_manager={gpu_manager}", level=logging.DEBUG)
|
|
|
|
# 디버깅: gpu_manager 속성 확인
|
|
if logger:
|
|
logger.log(f"[MIGAN] 파이프라인 gpu_manager 속성: {hasattr(pipeline, 'gpu_manager')}, 값: {getattr(pipeline, 'gpu_manager', None)}", level=logging.DEBUG)
|
|
|
|
return pipeline
|
|
|
|
# ───────────────────────────────────────────────────────────────
|
|
# 내부 유틸: 프로바이더 캐시
|
|
# ───────────────────────────────────────────────────────────────
|
|
def _read_provider_cache(self):
|
|
try:
|
|
if not getattr(self, "_provider_cache_path", None) or not os.path.exists(self._provider_cache_path):
|
|
return None
|
|
import json
|
|
with open(self._provider_cache_path, 'r', encoding='utf-8') as f:
|
|
data = json.load(f)
|
|
prov = (data or {}).get('last_success_provider', '').lower()
|
|
return prov if prov in ('dml', 'cpu') else None
|
|
except Exception:
|
|
return None
|
|
|
|
def _write_provider_cache(self, provider: str):
|
|
try:
|
|
if not getattr(self, "_provider_cache_path", None):
|
|
return
|
|
import json
|
|
os.makedirs(os.path.dirname(self._provider_cache_path), exist_ok=True)
|
|
with open(self._provider_cache_path, 'w', encoding='utf-8') as f:
|
|
json.dump({"last_success_provider": provider}, f, ensure_ascii=False)
|
|
except Exception:
|
|
pass
|