IT_Server/modules/migan/inpaint_with_migan.py

56 lines
2.0 KiB
Python

# migan_inpaint.py (이름 예시)
import torch
import numpy as np
import cv2
from PIL import Image
from iopaint.model.mi_gan import MIGAN
from iopaint.schema import InpaintRequest
# MIGAN 인스턴스(모델 로드)는 미리 만들어서 재사용 권장!
# def get_migan(device="cuda", model_path="modules/migan/migan_traced.pt"):
def get_migan(device, model_path):
migan = MIGAN(device)
migan.model = torch.jit.load(model_path, map_location=device).eval()
migan.device = torch.device(device)
return migan
def inpaint_with_migan(image, mask, device="cuda", model_path="modules/migan/migan_traced.pt", migan_obj=None):
"""
MIGAN을 사용해 인페인팅을 수행합니다.
image: 파일 경로(str), np.ndarray, 또는 PIL.Image.Image
mask: 파일 경로(str), np.ndarray, 또는 PIL.Image.Image (흑백)
device: "cuda" 또는 "cpu"
migan_obj: 이미 로드한 MIGAN 인스턴스 (권장)
return: np.ndarray (BGR)
"""
# 이미지 로딩 (RGB)
if isinstance(image, str):
image = cv2.imread(image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif isinstance(image, Image.Image):
image = np.array(image.convert("RGB"))
elif isinstance(image, np.ndarray):
if image.shape[2] == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
elif image.shape[2] == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 이미 RGB일 수도 있으니 확인 후 위 생략 가능
# 마스크 로딩 (흑백)
if isinstance(mask, str):
mask = cv2.imread(mask, cv2.IMREAD_GRAYSCALE)
elif isinstance(mask, Image.Image):
mask = np.array(mask.convert("L"))
elif isinstance(mask, np.ndarray):
if len(mask.shape) == 3:
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
# MIGAN 인스턴스 준비
if migan_obj is None:
migan_obj = get_migan(device=device, model_path=model_path)
# 인페인팅 (InpaintRequest 기본값)
result_bgr = migan_obj(image, mask, InpaintRequest())
return result_bgr