# TransWorker/modules/inpainting_module.py
#
# 28 lines
# 1.2 KiB
# Python

import numpy as np
import requests
import cv2
import base64
class IOPaintInpainting:
    """Inpainting backend that delegates the work to an IOPaint server.

    The image and mask are sent to the server's REST endpoint
    ``/api/v1/inpaint`` as base64-encoded PNGs, and the response body is
    decoded as a binary PNG image.
    """

    # Path of the inpaint endpoint, relative to the server base URL.
    _API_PATH = "/api/v1/inpaint"

    def __init__(self, server_url="http://localhost:8080", timeout=None):
        """Remember which server to talk to.

        Args:
            server_url: Base URL of the IOPaint server (scheme, host, port).
                A trailing slash is tolerated.
            timeout: Value forwarded to ``requests.post(..., timeout=...)``.
                ``None`` (the default) waits indefinitely, as before.
        """
        # Build the endpoint from the caller's base URL instead of a fixed
        # localhost address, so the constructor argument actually takes effect.
        self.api_url = server_url.rstrip("/") + self._API_PATH
        self.timeout = timeout

    def inpaint(
        self,
        image: np.ndarray,
        mask: np.ndarray,
        api_url: "str | None" = None,
    ) -> "np.ndarray | None":
        """Send ``image`` and ``mask`` to the server and return the result.

        Args:
            image: Image array; it is PNG-encoded with ``cv2.imencode``.
            mask: Mask array; it is PNG-encoded with ``cv2.imencode``.
            api_url: Optional full endpoint URL overriding the one configured
                in the constructor for this call only.

        Returns:
            The response decoded with ``cv2.imdecode(..., cv2.IMREAD_COLOR)``,
            or ``None`` when encoding fails or the server answers with a
            status code other than 200. ``cv2.imdecode`` may itself yield
            ``None`` if the response body is not a decodable image.

        Raises:
            requests.RequestException: Propagated unchanged when the HTTP
                request itself fails (connection error, timeout, ...).
        """
        # A per-call URL wins; otherwise use the endpoint built in __init__.
        endpoint = api_url or self.api_url

        # Encode both inputs as PNG and honour the success flags instead of
        # silently posting an unusable payload.
        image_ok, image_png = cv2.imencode('.png', image)
        mask_ok, mask_png = cv2.imencode('.png', mask)
        if not (image_ok and mask_ok):
            print("IOPaint: failed to PNG-encode the image or mask")
            return None

        payload = {
            "image": base64.b64encode(image_png).decode('utf-8'),
            "mask": base64.b64encode(mask_png).decode('utf-8'),
        }

        response = requests.post(endpoint, json=payload, timeout=self.timeout)
        if response.status_code != 200:
            print("IOPaint 서버 에러:", response.text)
            return None

        # The response body is a binary PNG image, so decode it directly.
        raw = np.frombuffer(response.content, np.uint8)
        return cv2.imdecode(raw, cv2.IMREAD_COLOR)