64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
tests/test_migan_module.py
|
|
|
|
- 로컬 이미지와 마스크로 MIGAN 파이프라인을 단독 테스트
|
|
- 너의 MaskModule 없이도 간단 폴리곤으로 마스크 생성 가능
|
|
"""
|
|
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
import logging
|
|
from src.modules.migan_module import MIGANPipelineONNXCompat
|
|
|
|
def get_logger():
|
|
lg = logging.getLogger("MIGAN_TEST")
|
|
if not lg.handlers:
|
|
lg.setLevel(logging.DEBUG)
|
|
h = logging.StreamHandler()
|
|
h.setFormatter(logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s"))
|
|
lg.addHandler(h)
|
|
# 네 프로젝트와 호환(.log 인터페이스)
|
|
class _Adapter:
|
|
def __init__(self, _lg): self._lg = _lg
|
|
def log(self, msg, level=logging.INFO, **kwargs): self._lg.log(level, msg)
|
|
return _Adapter(lg)
|
|
|
|
def make_demo_mask_like_yours(img_path: str):
|
|
"""네 MaskModule 스타일(텍스트영역=255) 흉내내기: 사각형 2개를 255로 채움 + 블러"""
|
|
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
|
h,w = img.shape[:2]
|
|
mask = np.zeros((h,w), np.uint8)
|
|
cv2.rectangle(mask, (int(0.1*w), int(0.2*h)), (int(0.4*w), int(0.3*h)), 255, -1)
|
|
cv2.rectangle(mask, (int(0.6*w), int(0.55*h)), (int(0.9*w), int(0.65*h)), 255, -1)
|
|
mask = cv2.GaussianBlur(mask, (15,15), 0)
|
|
return mask
|
|
|
|
def main():
|
|
logger = get_logger()
|
|
|
|
# 경로 설정
|
|
onnx_path = os.environ.get("MIGAN_ONNX", "migan_pipeline_v2.onnx")
|
|
img_path = os.environ.get("MIGAN_IMG", "examples/input.png")
|
|
out_path = os.environ.get("MIGAN_OUT", "examples/output_migan.png")
|
|
|
|
if not os.path.exists(onnx_path):
|
|
logger.log(f"ONNX 파일 없음: {onnx_path}", level=logging.ERROR); return
|
|
if not os.path.exists(img_path):
|
|
logger.log(f"입력 이미지 없음: {img_path}", level=logging.ERROR); return
|
|
|
|
migan = MIGANPipelineONNXCompat(onnx_path, logger=logger, use_cuda=True)
|
|
mask = make_demo_mask_like_yours(img_path)
|
|
|
|
out = migan.inpaint(img_path, mask)
|
|
if out is None:
|
|
logger.log("인페인팅 실패", level=logging.ERROR); return
|
|
|
|
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
|
|
cv2.imwrite(out_path, out)
|
|
logger.log(f"저장 완료: {out_path}")
|
|
|
|
if __name__ == "__main__":
|
|
main()
|