TransWorker/modules/lama_inpaint.py

30 lines
1.0 KiB
Python

from simple_lama_inpainting import SimpleLama
from PIL import Image
import numpy as np
import cv2
def inpaint_with_simple_lama(image, mask, device="cuda"):
"""
simple-lama-inpainting을 사용해 인페인팅을 수행합니다.
image: 파일 경로(str), np.ndarray, 또는 PIL.Image.Image
mask: 파일 경로(str), np.ndarray, 또는 PIL.Image.Image (흑백)
device: "cuda" 또는 "cpu"
return: np.ndarray (BGR)
"""
# 이미지 로딩
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
# mask도 동일하게 처리
if isinstance(mask, str):
mask = Image.open(mask)
elif isinstance(mask, np.ndarray):
mask = Image.fromarray(mask)
# 인페인팅
simple_lama = SimpleLama(device=device)
result = simple_lama(image, mask)
# PIL.Image -> np.ndarray(BGR)
result_np = cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)
return result_np