30 lines
1.0 KiB
Python
30 lines
1.0 KiB
Python
from simple_lama_inpainting import SimpleLama
|
|
from PIL import Image
|
|
import numpy as np
|
|
import cv2
|
|
|
|
def inpaint_with_simple_lama(image, mask, device="cuda"):
|
|
"""
|
|
simple-lama-inpainting을 사용해 인페인팅을 수행합니다.
|
|
image: 파일 경로(str), np.ndarray, 또는 PIL.Image.Image
|
|
mask: 파일 경로(str), np.ndarray, 또는 PIL.Image.Image (흑백)
|
|
device: "cuda" 또는 "cpu"
|
|
return: np.ndarray (BGR)
|
|
"""
|
|
# 이미지 로딩
|
|
if isinstance(image, str):
|
|
image = Image.open(image)
|
|
elif isinstance(image, np.ndarray):
|
|
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
|
# mask도 동일하게 처리
|
|
if isinstance(mask, str):
|
|
mask = Image.open(mask)
|
|
elif isinstance(mask, np.ndarray):
|
|
mask = Image.fromarray(mask)
|
|
# 인페인팅
|
|
simple_lama = SimpleLama(device=device)
|
|
result = simple_lama(image, mask)
|
|
# PIL.Image -> np.ndarray(BGR)
|
|
result_np = cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)
|
|
|
|
return result_np |