# -*- coding: utf-8 -*- """ src/modules/migan_module.py - MI-GAN ONNX 파이프라인(전/후처리 포함) - 핵심 호환 포인트: 1) 입력: image_path(str), mask(np.ndarray, 0~255 그레이, 텍스트영역=255) 2) 마스크는 내부에서 자동으로 (이진화 -> 반전) 하여 MI-GAN 규칙(255=known, 0=hole)에 맞춤 3) CUDA/CPU 자동 선택, 실패 시 CPU 폴백 의존성: onnxruntime, opencv-python, numpy 설정: toggle_states에 다음 키 사용(선택): - "migan_onnx_path": 파이프라인 ONNX 경로(필수) - "migan_use_cuda": True/False (기본 False) - "migan_intra_threads": int (기본 0 = onnxruntime 기본) - "migan_inter_threads": int (기본 0) """ import os import sys import time import logging from typing import Optional, Dict, Any import cv2 import numpy as np import onnxruntime as ort # OpenCV 내부 최적화 off cv2.setUseOptimized(False) def _np_uint8_2d(arr, name="mask"): if arr is None: raise ValueError(f"{name} is None") if not isinstance(arr, np.ndarray): raise TypeError(f"{name} must be np.ndarray, got {type(arr)}") if arr.ndim != 2: raise ValueError(f"{name} must be 2D, got shape={arr.shape}") if arr.dtype != np.uint8: # 안전 변환 arr = arr.astype(np.uint8, copy=False) return arr def _ensure_logger(logger: Optional[object]) -> logging.Logger: """네 logger( .log(msg, level=) )가 없을 때를 대비한 기본 로거""" if logger and hasattr(logger, "log"): return logger pylogger = logging.getLogger("MIGAN") if not pylogger.handlers: pylogger.setLevel(logging.DEBUG) h = logging.StreamHandler(stream=sys.stdout) h.setFormatter(logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s")) pylogger.addHandler(h) # .log 호환 어댑터 class _Adapter: def __init__(self, _lg): self._lg = _lg def log(self, msg, level=logging.DEBUG, **kwargs): self._lg.log(level, msg) return _Adapter(pylogger) class MIGANPipelineONNXCompat: """ MI-GAN ONNX 파이프라인 래퍼 - 입력: image_path(str), mask(gray uint8 HxW) ※ 텍스트영역=255 - 내부에서 mask를 (이진화→반전)하여 MI-GAN 규칙(255=known, 0=hole)으로 맞춤 - 출력: BGR uint8(H,W,3) """ _SESSION_CACHE = {} # onnx_path -> InferenceSession 캐시 def __init__(self, onnx_path: str, logger: Optional[object] = None, use_cuda: bool = False, intra_threads: int = 0, inter_threads: int = 0, toggle_states: Optional[Dict[str, Any]] = None, base_dir: Optional[str] = None): self.logger = _ensure_logger(logger) self.onnx_path = onnx_path self.use_cuda = bool(use_cuda) self.intra_threads = int(intra_threads or 0) self.inter_threads = int(inter_threads or 0) self.toggle_states = toggle_states or {} self.base_dir = base_dir # 프로바이더 캐시 파일 경로 결정 try: cache_root = None if self.base_dir: cache_root = os.path.join(self.base_dir, "user_data") else: cache_root = os.path.dirname(self.onnx_path) or os.getcwd() os.makedirs(cache_root, exist_ok=True) self._provider_cache_path = os.path.join(cache_root, "migan_provider.json") except Exception: self._provider_cache_path = None if not os.path.exists(self.onnx_path): self.logger.log(f"[MIGAN] ONNX 파일을 찾을 수 없습니다: {self.onnx_path}", level=logging.ERROR) raise FileNotFoundError(self.onnx_path) self.session = self._get_or_create_session() ins = self.session.get_inputs() outs = self.session.get_outputs() self.in_image = ins[0].name self.in_mask = ins[1].name self.out_name = outs[0].name # 입력/출력 형태 정보 로깅 (디버깅용) for i, inp in enumerate(ins): self.logger.log(f"[MIGAN] 입력 {i}: {inp.name}, 형태: {inp.shape}, 타입: {inp.type}", level=logging.DEBUG) for i, out in enumerate(outs): self.logger.log(f"[MIGAN] 출력 {i}: {out.name}, 형태: {out.shape}, 타입: {out.type}", level=logging.DEBUG) self.logger.log(f"[MIGAN] 세션 준비 완료. providers={self.session.get_providers()}", level=logging.DEBUG) def _get_or_create_session(self) -> ort.InferenceSession: key = (self.onnx_path, self.use_cuda, self.intra_threads, self.inter_threads) if key in self._SESSION_CACHE: return self._SESSION_CACHE[key] # override 우선 적용: auto|dml|cpu try: override = (self.toggle_states or {}).get("migan_provider_override", "auto") if isinstance(override, str): override = override.lower() if override == "dml": self.use_cuda = True elif override == "cpu": self.use_cuda = False else: # auto: 캐시 확인 cached = self._read_provider_cache() if cached == "dml": self.use_cuda = True elif cached == "cpu": self.use_cuda = False except Exception: pass so = ort.SessionOptions() if self.intra_threads > 0: so.intra_op_num_threads = self.intra_threads if self.inter_threads > 0: so.inter_op_num_threads = self.inter_threads # 메모리 절약형 설정 (안전하게 속성 존재 여부 확인) if hasattr(so, 'enable_cpu_mem_arena'): so.enable_cpu_mem_arena = True # CPU 메모리 최적화 (유지) if hasattr(so, 'enable_mem_pattern'): so.enable_mem_pattern = False # 🔧 메모리 절약을 위해 비활성화 if hasattr(so, 'graph_optimization_level'): so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC # 🔧 기본 최적화만 사용 # DirectML 관련 메모리 절약 최적화 (GPU 사용 시, 속성 존재 여부 확인) if self.use_cuda or (hasattr(self, 'gpu_manager') and self.gpu_manager and self.gpu_manager.can_use_cuda): if hasattr(so, 'enable_cuda_graph'): so.enable_cuda_graph = False # 🔧 메모리 절약을 위해 비활성화 # GPU 관리자가 있으면 최적 provider 사용, 없으면 기존 방식 gpu_manager_available = hasattr(self, 'gpu_manager') and self.gpu_manager self.logger.log(f"[MIGAN] GPU 관리자 사용 가능: {gpu_manager_available}", level=logging.DEBUG) # DirectML 기반 GPU 가속 (Windows 호환성 최우선) providers = [] if self.use_cuda: # use_cuda를 GPU 가속 플래그로 사용 try: available_providers = ort.get_available_providers() self.logger.log(f"[MIGAN] 사용 가능한 providers: {available_providers}", level=logging.DEBUG) if 'DmlExecutionProvider' in available_providers: providers = [("DmlExecutionProvider", {}), ("CPUExecutionProvider", {})] self.logger.log(f"[MIGAN] DirectML 가속 활성화", level=logging.INFO) else: providers = [("CPUExecutionProvider", {})] self.logger.log(f"[MIGAN] DirectML 미지원, CPU 모드로 전환", level=logging.WARNING) except Exception as e: self.logger.log(f"[MIGAN] DirectML 설정 실패: {e}", level=logging.WARNING) providers = [("CPUExecutionProvider", {})] else: providers = [("CPUExecutionProvider", {})] self.logger.log(f"[MIGAN] CPU 전용 모드", level=logging.DEBUG) self.logger.log(f"[MIGAN] 최종 providers: {providers}", level=logging.DEBUG) # DirectML 기반 단계별 폴백 전략 sess = None if self.use_cuda: fallback_attempts = [ (providers, "DirectML 가속"), ([("CPUExecutionProvider", {})], "CPU 폴백") ] else: fallback_attempts = [ ([("CPUExecutionProvider", {})], "CPU 전용") ] for attempt_providers, attempt_name in fallback_attempts: try: self.logger.log(f"[MIGAN] {attempt_name} 시도: {attempt_providers}", level=logging.DEBUG) sess = ort.InferenceSession(self.onnx_path, sess_options=so, providers=attempt_providers) actual_providers = sess.get_providers() self.logger.log(f"[MIGAN] {attempt_name} 성공! 실제 providers: {actual_providers}", level=logging.INFO) break except Exception as e: self.logger.log(f"[MIGAN] {attempt_name} 실패: {e}", level=logging.WARNING) if sess: sess = None continue if sess is None: raise RuntimeError("[MIGAN] 모든 폴백 시도 실패") self._SESSION_CACHE[key] = sess # 성공 프로바이더 캐시 try: used = sess.get_providers() prov = 'dml' if any('Dml' in p for p in used) else 'cpu' self._write_provider_cache(prov) except Exception: pass return sess # ─────────────────────────────────────────────────────────────── # 퍼블릭 API: 파이프라인에서 바로 호출 # ─────────────────────────────────────────────────────────────── def inpaint(self, image_path: str, mask_gray_255_text: np.ndarray) -> Optional[np.ndarray]: """ Args image_path: 원본 이미지 경로 (BGR 로딩) mask_gray_255_text: 0~255 그레이, '텍스트영역=255' 형태(네 MaskModule 출력) (GaussianBlur 포함되어 있을 수 있음) Return inpainted BGR 이미지 (np.ndarray) or None """ try: # 1) 입력 이미지 로드 bgr = cv2.imread(image_path, cv2.IMREAD_COLOR) if bgr is None: self.logger.log(f"[MIGAN] 이미지 로드 실패: {image_path}", level=logging.ERROR) return None H, W = bgr.shape[:2] # 2) 마스크 정규화: (이진화 → 반전) 해서 255=known, 0=hole 맞추기 mask = _np_uint8_2d(mask_gray_255_text, name="mask") if mask.shape != (H, W): self.logger.log(f"[MIGAN] 마스크 크기 불일치: mask={mask.shape}, img={(H,W)}", level=logging.ERROR) return None # 이진화: 128 스레시hold 기준 _, mask_bin = cv2.threshold(mask, 128, 255, cv2.THRESH_BINARY) # 너의 마스크는 “텍스트=255”, MI-GAN은 “hole=0”이므로 반전 # (텍스트영역 255 -> 0), (배경 0 -> 255) mask_known255 = 255 - mask_bin # 3) RGB 변환 (파이프라인 입력은 RGB uint8) rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) # 4) ONNX 추론 - 배치 차원 추가 및 차원 순서 변경 start = time.time() # ONNX 모델 입력 형태: # - image: (1, 3, H, W) - 배치, 채널, 높이, 너비 순서 # - mask: (1, 1, H, W) - 배치, 채널(1), 높이, 너비 순서 # 이미지: (H, W, 3) -> (1, 3, H, W) rgb_batch = np.expand_dims(rgb, 0).transpose(0, 3, 1, 2) # 마스크: (H, W) -> (1, 1, H, W) mask_batch = np.expand_dims(mask_known255, (0, 1)) self.logger.log(f"[MIGAN] 입력 형태 - 이미지: {rgb_batch.shape}, 마스크: {mask_batch.shape}", level=logging.DEBUG) out = self.session.run( [self.out_name], {self.in_image: rgb_batch, self.in_mask: mask_batch} )[0] # expect RGB uint8(1,3,H,W) # 출력 차원 처리: (1,3,H,W) -> (H,W,3) if out.ndim == 4 and out.shape[0] == 1: out = out[0].transpose(1, 2, 0) # (1,3,H,W) -> (3,H,W) -> (H,W,3) elif out.ndim == 3 and out.shape[0] == 3: # (3,H,W) -> (H,W,3) out = out.transpose(1, 2, 0) self.logger.log(f"[MIGAN] 출력 형태: {out.shape}, dtype: {out.dtype}", level=logging.DEBUG) if not isinstance(out, np.ndarray) or out.ndim != 3 or out.dtype != np.uint8: self.logger.log(f"[MIGAN] ONNX 출력 형식 오류: type={type(out)}, shape={getattr(out,'shape',None)}, dtype={getattr(out,'dtype',None)}", level=logging.ERROR) return None elapsed = (time.time() - start) * 1000.0 self.logger.log(f"[MIGAN] 추론 완료: {elapsed:.2f} ms", level=logging.DEBUG) # 5) BGR로 되돌려 반환 bgr_out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR) return bgr_out except Exception as e: error_msg = str(e).lower() if "invalid rank" in error_msg or "invalid argument" in error_msg: self.logger.log(f"[MIGAN] ONNX 입력 차원 오류: {e}", level=logging.ERROR) self.logger.log(f"[MIGAN] 입력 이미지 형태: {rgb_batch.shape if 'rgb_batch' in locals() else 'N/A'}", level=logging.ERROR) self.logger.log(f"[MIGAN] 입력 마스크 형태: {mask_batch.shape if 'mask_batch' in locals() else 'N/A'}", level=logging.ERROR) else: self.logger.log(f"[MIGAN] inpaint 예외: {e}", level=logging.ERROR, exc_info=True) return None # ─────────────────────────────────────────────────────────────── # ImageProcessor3에서 바로 부를 수 있게 하는 편의 함수 # ─────────────────────────────────────────────────────────────── def build_migan_from_toggle(toggle_states: dict, logger: Optional[object] = None, gpu_manager: Optional[object] = None) -> MIGANPipelineONNXCompat: """ toggle_states로부터 설정을 읽어 MIGANPipelineONNXCompat 인스턴스를 생성. 필수 키: - migan_onnx_path 선택 키: - migan_use_cuda (bool) - migan_intra_threads (int) - migan_inter_threads (int) """ onnx_path = toggle_states.get("migan_onnx_path", "") if not onnx_path: raise ValueError("toggle_states['migan_onnx_path'] 가 필요합니다.") use_accel = toggle_states.get("migan_use_accel", None) if use_accel is None: use_accel = toggle_states.get("migan_use_cuda", False) # 호환 키 use_cuda = bool(use_accel) intra = int(toggle_states.get("migan_intra_threads", 0) or 0) inter = int(toggle_states.get("migan_inter_threads", 0) or 0) pipeline = MIGANPipelineONNXCompat( onnx_path=onnx_path, logger=logger, use_cuda=use_cuda, intra_threads=intra, inter_threads=inter, toggle_states=toggle_states, ) # GPU 관리자를 파이프라인 객체에 연결 if gpu_manager: pipeline.gpu_manager = gpu_manager if logger: logger.log(f"[MIGAN] GPU 관리자 연결 완료: {type(gpu_manager).__name__}", level=logging.DEBUG) else: if logger: logger.log(f"[MIGAN] GPU 관리자 없음: gpu_manager={gpu_manager}", level=logging.DEBUG) # 디버깅: gpu_manager 속성 확인 if logger: logger.log(f"[MIGAN] 파이프라인 gpu_manager 속성: {hasattr(pipeline, 'gpu_manager')}, 값: {getattr(pipeline, 'gpu_manager', None)}", level=logging.DEBUG) return pipeline # ─────────────────────────────────────────────────────────────── # 내부 유틸: 프로바이더 캐시 # ─────────────────────────────────────────────────────────────── def _read_provider_cache(self): try: if not getattr(self, "_provider_cache_path", None) or not os.path.exists(self._provider_cache_path): return None import json with open(self._provider_cache_path, 'r', encoding='utf-8') as f: data = json.load(f) prov = (data or {}).get('last_success_provider', '').lower() return prov if prov in ('dml', 'cpu') else None except Exception: return None def _write_provider_cache(self, provider: str): try: if not getattr(self, "_provider_cache_path", None): return import json os.makedirs(os.path.dirname(self._provider_cache_path), exist_ok=True) with open(self._provider_cache_path, 'w', encoding='utf-8') as f: json.dump({"last_success_provider": provider}, f, ensure_ascii=False) except Exception: pass