165 lines
7.3 KiB
Python
165 lines
7.3 KiB
Python
import requests
|
|
import cv2
|
|
import base64
|
|
import numpy as np
|
|
import os
|
|
import logging
|
|
|
|
class Request_AI_Server:
|
|
"""IOPaint 서버 연동 인페인팅 모델 (REST API /api/v1/inpaint 사용, 바이너리 PNG 반환)"""
|
|
def __init__(self, logger,
|
|
inpaint_server_url: str = "http://192.168.0.150:35756",
|
|
rembg_server_url: str | None = None):
|
|
"""두 개의 서버 URL을 분리해 받는다.
|
|
|
|
Args:
|
|
logger: Logger 객체
|
|
inpaint_server_url: 인페인트 서버 기본 URL (포트 포함, api/v1 제외)
|
|
rembg_server_url: RemoveBG 서버 URL. None 이면 inpaint_server_url 과 동일하게 사용
|
|
"""
|
|
self.logger = logger
|
|
|
|
self.inpaint_base_url = inpaint_server_url.rstrip('/')
|
|
self.rembg_base_url = (rembg_server_url or inpaint_server_url).rstrip('/')
|
|
|
|
self.inpaint_api_url = f"{self.inpaint_base_url}/api/v1/inpaint"
|
|
# RemoveBG 플러그인은 고정 엔드포인트
|
|
self.rembg_api_url = f"{self.rembg_base_url}/api/v1/run_plugin_gen_image"
|
|
|
|
def request_inpaint(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
|
|
|
try:
|
|
# 서버 상태 먼저 확인
|
|
if not self.is_server_alive(self.inpaint_base_url):
|
|
self.logger.log("인페인팅 서버가 비정상입니다. 백업 인페인팅으로 넘어갑니다.", level=logging.WARNING)
|
|
return None
|
|
|
|
image_data = None
|
|
|
|
# image가 경로(str)라면 파일을 읽어서 np.ndarray로 변환
|
|
if isinstance(image, str) and os.path.isfile(image):
|
|
image_data = cv2.imread(image)
|
|
if image_data is None:
|
|
self.logger.log(f"이미지 파일을 읽을 수 없습니다: {image}", level=logging.ERROR)
|
|
return None
|
|
|
|
# 이미지를 base64로 인코딩
|
|
_, img_encoded = cv2.imencode('.png', image_data)
|
|
_, mask_encoded = cv2.imencode('.png', mask)
|
|
img_b64 = base64.b64encode(img_encoded).decode('utf-8')
|
|
mask_b64 = base64.b64encode(mask_encoded).decode('utf-8')
|
|
payload = {
|
|
"image": img_b64,
|
|
"mask": mask_b64
|
|
}
|
|
response = requests.post(self.inpaint_api_url, json=payload)
|
|
if response.status_code != 200:
|
|
print("인페인팅 서버 에러:", response.text)
|
|
return None
|
|
# 응답이 바이너리 PNG 이미지이므로 바로 디코딩
|
|
nparr = np.frombuffer(response.content, np.uint8)
|
|
result = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
|
return result
|
|
except Exception as e:
|
|
self.logger.log(f"인페인팅 서버 에러: {e}", level=logging.ERROR, exc_info=True)
|
|
return None
|
|
|
|
def request_rembg(self, image: np.ndarray) -> np.ndarray:
|
|
"""RemoveBG 플러그인 호출 후 결과 이미지를 흰 배경 중앙 배치로 후처리."""
|
|
try:
|
|
# 서버 상태 먼저 확인
|
|
if not self.is_server_alive(self.rembg_base_url):
|
|
self.logger.log("rembg 서버가 비정상입니다.", level=logging.WARNING)
|
|
return None
|
|
|
|
# 입력 이미지 로드/확정
|
|
if isinstance(image, str) and os.path.isfile(image):
|
|
image_data = cv2.imread(image)
|
|
elif isinstance(image, np.ndarray):
|
|
image_data = image
|
|
else:
|
|
self.logger.log(f"이미지 파일을 읽을 수 없습니다: {image}", level=logging.ERROR)
|
|
return None
|
|
|
|
# base64 인코딩 (data URL)
|
|
_, img_encoded = cv2.imencode('.png', image_data)
|
|
img_b64 = base64.b64encode(img_encoded).decode('utf-8')
|
|
payload = {
|
|
"name": "RemoveBG",
|
|
"image": f"data:image/png;base64,{img_b64}",
|
|
"scale": 1
|
|
}
|
|
|
|
response = requests.post(self.rembg_api_url, json=payload)
|
|
if response.status_code != 200:
|
|
self.logger.log(f"rembg 서버 에러: {response.text}", level=logging.ERROR)
|
|
return None
|
|
|
|
# PNG 바이너리 → numpy
|
|
nparr = np.frombuffer(response.content, np.uint8)
|
|
result_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
|
|
|
|
if result_img is None or result_img.ndim != 3:
|
|
return result_img # 실패 시 원본 반환
|
|
|
|
# ---- 후처리: 마스크 정제 및 중앙 배치 ----
|
|
# 1) 초기 마스크
|
|
if result_img.shape[2] == 4:
|
|
mask_init = (result_img[:, :, 3] > 200).astype(np.uint8)
|
|
rgba_img = result_img
|
|
else:
|
|
gray = cv2.cvtColor(result_img[:, :, :3], cv2.COLOR_BGR2GRAY)
|
|
mask_init = (gray < 230).astype(np.uint8)
|
|
alpha_channel = (mask_init * 255).astype(np.uint8)
|
|
rgba_img = np.dstack([result_img, alpha_channel])
|
|
|
|
# 2) 모폴로지 정제
|
|
kernel = np.ones((3, 3), np.uint8)
|
|
mask = cv2.erode(mask_init, kernel, iterations=1)
|
|
mask = cv2.dilate(mask, kernel, iterations=2)
|
|
|
|
# 3) 최대 연결 요소만 유지
|
|
num, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
|
|
if num > 1:
|
|
largest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
|
|
mask = (labels == largest).astype(np.uint8)
|
|
|
|
ys, xs = np.where(mask > 0)
|
|
if len(xs) == 0 or len(ys) == 0:
|
|
# 객체 감지 실패 → 흰 배경 합성만
|
|
white_bg = np.full_like(rgba_img[:, :, :3], 255)
|
|
return white_bg
|
|
|
|
# 4) 바운딩 박스 + 마진
|
|
top, left = ys.min(), xs.min()
|
|
bottom, right = ys.max(), xs.max()
|
|
crop_rgba = rgba_img[top:bottom + 1, left:right + 1]
|
|
|
|
ch, cw = crop_rgba.shape[:2]
|
|
margin = int(max(ch, cw) * 0.1)
|
|
crop_rgba = cv2.copyMakeBorder(
|
|
crop_rgba, margin, margin, margin, margin,
|
|
borderType=cv2.BORDER_CONSTANT,
|
|
value=[255, 255, 255, 0]
|
|
)
|
|
|
|
# 5) 흰 배경으로 합성
|
|
bgr_crop = crop_rgba[:, :, :3].astype(np.float32)
|
|
alpha_crop = crop_rgba[:, :, 3:4].astype(np.float32) / 255.0
|
|
white_bg = np.full_like(bgr_crop, 255.0)
|
|
final_img = (bgr_crop * alpha_crop + white_bg * (1 - alpha_crop)).astype(np.uint8)
|
|
return final_img
|
|
except Exception as e:
|
|
self.logger.log(f"rembg 서버 에러: {e}", level=logging.ERROR, exc_info=True)
|
|
return None
|
|
|
|
def is_server_alive(self, base_url: str, timeout: int = 1) -> bool:
|
|
"""주어진 서버(base_url)에 /api/v1/model 헬스체크를 호출한다."""
|
|
import requests as _rq
|
|
try:
|
|
model_url = base_url.rstrip('/') + '/api/v1/model'
|
|
response = _rq.get(model_url, timeout=timeout)
|
|
return response.status_code == 200
|
|
except Exception as e:
|
|
self.logger.log(f"서버 상태 확인 실패 ({base_url}): {e}", level=logging.WARNING)
|
|
return False |