56 lines
2.0 KiB
Python
56 lines
2.0 KiB
Python
# migan_inpaint.py (이름 예시)
|
|
import torch
|
|
import numpy as np
|
|
import cv2
|
|
from PIL import Image
|
|
from iopaint.model.mi_gan import MIGAN
|
|
from iopaint.schema import InpaintRequest
|
|
|
|
# MIGAN 인스턴스(모델 로드)는 미리 만들어서 재사용 권장!
|
|
# def get_migan(device="cuda", model_path="modules/migan/migan_traced.pt"):
|
|
def get_migan(device, model_path):
|
|
migan = MIGAN(device)
|
|
migan.model = torch.jit.load(model_path, map_location=device).eval()
|
|
migan.device = torch.device(device)
|
|
return migan
|
|
|
|
def inpaint_with_migan(image, mask, device="cuda", model_path="modules/migan/migan_traced.pt", migan_obj=None):
|
|
"""
|
|
MIGAN을 사용해 인페인팅을 수행합니다.
|
|
image: 파일 경로(str), np.ndarray, 또는 PIL.Image.Image
|
|
mask: 파일 경로(str), np.ndarray, 또는 PIL.Image.Image (흑백)
|
|
device: "cuda" 또는 "cpu"
|
|
migan_obj: 이미 로드한 MIGAN 인스턴스 (권장)
|
|
return: np.ndarray (BGR)
|
|
"""
|
|
# 이미지 로딩 (RGB)
|
|
if isinstance(image, str):
|
|
image = cv2.imread(image)
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
elif isinstance(image, Image.Image):
|
|
image = np.array(image.convert("RGB"))
|
|
elif isinstance(image, np.ndarray):
|
|
if image.shape[2] == 4:
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
|
elif image.shape[2] == 3:
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
# 이미 RGB일 수도 있으니 확인 후 위 생략 가능
|
|
|
|
# 마스크 로딩 (흑백)
|
|
if isinstance(mask, str):
|
|
mask = cv2.imread(mask, cv2.IMREAD_GRAYSCALE)
|
|
elif isinstance(mask, Image.Image):
|
|
mask = np.array(mask.convert("L"))
|
|
elif isinstance(mask, np.ndarray):
|
|
if len(mask.shape) == 3:
|
|
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
|
|
|
|
# MIGAN 인스턴스 준비
|
|
if migan_obj is None:
|
|
migan_obj = get_migan(device=device, model_path=model_path)
|
|
|
|
# 인페인팅 (InpaintRequest 기본값)
|
|
result_bgr = migan_obj(image, mask, InpaintRequest())
|
|
|
|
return result_bgr
|