# (source-viewer header captured with this file: 211 lines, 9.3 KiB, Python)
import base64
import collections
import logging
import os
import random
import subprocess
import threading
import time

import cv2
import numpy as np
import requests
|
|
|
class IOPaintManager:
    """Manage a pool of local IOPaint server processes and route inpaint requests to them.

    Every instance is an ``iop.exe start`` process (``<base_dir>/iop/iop.exe``)
    serving the ``lama`` model on a random localhost port taken from
    ``port_range``. ``self.lock`` serializes reserving, releasing, launching
    and shutting down instances.
    """

    # Seconds a freshly spawned server gets to boot before readiness polling starts.
    STARTUP_GRACE_SECONDS = 8
    # Seconds between readiness probes in wait_for_server_ready().
    READY_POLL_INTERVAL = 0.5
    # Seconds to wait for a terminated server to exit before it is killed.
    SHUTDOWN_TIMEOUT_SECONDS = 5
    # Seconds to wait for an exited server's output readers to reach EOF.
    OUTPUT_WAIT_SECONDS = 3

    class ServerInstance:
        """Bookkeeping for one spawned IOPaint server process.

        When the process exposes stdout/stderr pipes, daemon threads drain them
        continuously into bounded line buffers. An undrained PIPE eventually
        fills the OS buffer and blocks the server on its next write.
        """

        # Number of trailing lines retained per stream for diagnostics.
        OUTPUT_TAIL_LINES = 200

        def __init__(self, port, process):
            self.port = port
            self.process = process
            self.busy = False
            self.last_used = time.time()
            self.stdout_tail = collections.deque(maxlen=self.OUTPUT_TAIL_LINES)
            self.stderr_tail = collections.deque(maxlen=self.OUTPUT_TAIL_LINES)
            self._output_lock = threading.Lock()
            self._readers = []
            streams = (
                ("out", getattr(process, "stdout", None), self.stdout_tail),
                ("err", getattr(process, "stderr", None), self.stderr_tail),
            )
            for tag, stream, sink in streams:
                if stream is None:
                    continue
                reader = threading.Thread(
                    target=self._drain,
                    args=(stream, sink),
                    name=f"iopaint-{port}-{tag}",
                    daemon=True,
                )
                reader.start()
                self._readers.append(reader)

        def _drain(self, stream, sink):
            """Copy lines from ``stream`` into ``sink`` until EOF, then close the stream."""
            try:
                for raw in stream:
                    line = raw.decode(errors="ignore") if isinstance(raw, bytes) else raw
                    with self._output_lock:
                        sink.append(line.rstrip())
            except (OSError, ValueError):
                # The pipe was closed underneath the reader (e.g. during
                # shutdown); there is nothing more to collect.
                pass
            finally:
                stream.close()

        def mark_busy(self):
            """Flag the instance as serving a request and stamp ``last_used``."""
            self.busy = True
            self.last_used = time.time()

        def mark_idle(self):
            """Flag the instance as free and stamp ``last_used``."""
            self.busy = False
            self.last_used = time.time()

        def is_alive(self):
            """Return True while the server process has not exited."""
            return self.process.poll() is None

        def output_tail(self, wait=0.0):
            """Return ``(stdout, stderr)`` text captured so far, oldest line first.

            If the process has already exited, wait up to ``wait`` seconds in
            total for the reader threads to reach EOF so the tail is complete.
            """
            if wait > 0 and not self.is_alive():
                deadline = time.monotonic() + wait
                for reader in self._readers:
                    reader.join(timeout=max(0.0, deadline - time.monotonic()))
            with self._output_lock:
                return "\n".join(self.stdout_tail), "\n".join(self.stderr_tail)

    def __init__(self, logger, num_instances=1, port_range=(8099, 8199), base_dir=None, wait_ready=30, model_dir=None):
        """Create the manager and start ``num_instances`` servers.

        Args:
            logger: object exposing ``log(message, level=..., exc_info=...)``.
            num_instances: number of servers to start immediately.
            port_range: inclusive ``(low, high)`` range to pick ports from.
            base_dir: directory containing ``iop/``; defaults to the CWD.
            wait_ready: seconds each server gets to answer the readiness probe.
            model_dir: ``--model-dir`` passed to IOPaint; defaults to
                ``<base_dir>/iop/models``.
        """
        self.logger = logger
        self.instances = []
        self.port_range = port_range
        self.lock = threading.Lock()
        self.base_dir = base_dir or os.getcwd()
        self.model_dir = model_dir or os.path.join(self.base_dir, 'iop', 'models')
        self.exe_path = os.path.join(self.base_dir, 'iop', 'iop.exe')
        try:
            self._start_instances(num_instances, wait_ready)
        except BaseException:
            # Startup aborted (no free port, missing exe, Ctrl+C, ...): do not
            # leave the servers that were already spawned running as orphans.
            self.shutdown_all()
            raise

    def _get_random_port(self):
        """Pick a random port in ``port_range`` not used by a tracked instance.

        Only ports of this manager's instances are excluded; whether another
        process already listens on the port is not checked. Callers should
        hold ``self.lock`` so two launches cannot pick the same port.

        Raises:
            RuntimeError: every port in the range is taken.
        """
        used_ports = {inst.port for inst in self.instances}
        low, high = self.port_range
        candidates = [p for p in range(low, high + 1) if p not in used_ports]
        if not candidates:
            self.logger.log("사용 가능한 포트가 없습니다.", level=logging.ERROR)
            raise RuntimeError("사용 가능한 포트가 없습니다.")
        return random.choice(candidates)

    def _find_instance(self, port):
        """Return the tracked instance listening on ``port``, or None."""
        return next((inst for inst in self.instances if inst.port == port), None)

    def wait_for_server_ready(self, port, timeout=30):
        """Poll ``/api/v1/server-config`` on ``port`` until it answers HTTP 200.

        Stops early if the tracked process for ``port`` has exited.

        Returns:
            True once the server answered 200 within ``timeout`` seconds,
            otherwise False.
        """
        url = f"http://localhost:{port}/api/v1/server-config"
        instance = self._find_instance(port)
        # monotonic() so wall-clock adjustments cannot stretch or cut the wait.
        start = time.monotonic()
        last_error = None
        tries = 0
        self.logger.log(f"[{port}] 서버 준비 체크 시작 (최대 {timeout}초 대기)", level=logging.INFO)
        while time.monotonic() - start < timeout:
            if instance is not None and not instance.is_alive():
                self.logger.log(
                    f"[{port}] 서버 프로세스 종료됨 (exit code: {instance.process.returncode})",
                    level=logging.ERROR,
                )
                return False
            tries += 1
            try:
                r = requests.get(url, timeout=2)
            except requests.RequestException as e:
                last_error = str(e)
                # Connection refused is expected while the server boots, so
                # individual misses are logged quietly, without a traceback.
                self.logger.log(f"[{port}] 준비 체크 실패 (시도 {tries}회): {last_error}", level=logging.DEBUG)
            else:
                self.logger.log(f"응답 : {r}", level=logging.INFO)
                if r.status_code == 200:
                    elapsed = time.monotonic() - start
                    self.logger.log(f"[{port}] 서버 준비 완료! (시도 {tries}회, {elapsed:.1f}초 소요)", level=logging.INFO)
                    return True
                self.logger.log(f"[{port}] 응답 코드: {r.status_code}", level=logging.INFO)
            time.sleep(self.READY_POLL_INTERVAL)
        self.logger.log(f"[{port}] 서버 준비 실패 (총 {tries}회 시도, 마지막 에러: {last_error})", level=logging.ERROR)
        return False

    def _detect_device(self):
        """Return "cuda" when torch reports a usable GPU, otherwise "cpu"."""
        try:
            # torch is optional; any import or CUDA probing failure means CPU.
            import torch
            return "cuda" if torch.cuda.is_available() else "cpu"
        except Exception as e:
            self.logger.log(f"torch import 또는 GPU 체크 실패: {e}", level=logging.WARNING)
            return "cpu"

    def _launch_instance(self, device_type):
        """Spawn one IOPaint server on a free port, track it, and return it."""
        with self.lock:
            port = self._get_random_port()
            cmd = [self.exe_path, 'start', '--model=lama', f'--device={device_type}', '--port', str(port), '--model-dir', self.model_dir]
            self.logger.log(f"[{port}] 인스턴스 실행 명령: {' '.join(cmd)}", level=logging.INFO)
            # The pipes are drained by ServerInstance's reader threads.
            proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            instance = self.ServerInstance(port, proc)
            self.instances.append(instance)
        return instance

    def _log_process_output(self, instance):
        """Log the captured stdout/stderr tail of an instance that failed to start."""
        out, err = instance.output_tail(wait=self.OUTPUT_WAIT_SECONDS)
        port = instance.port
        self.logger.log(f"[{port}] 표준출력:\n{out}", level=logging.INFO)
        self.logger.log(f"[{port}] 표준에러:\n{err}", level=logging.INFO)

    def _start_instances(self, num, wait_ready):
        """Start ``num`` servers one after another and wait for each to be ready.

        A server that fails the readiness check stays tracked (it may still
        come up later); its captured output is logged for diagnosis.
        """
        self.logger.log(f"IOPaint 인스턴스 {num} 개 시작", level=logging.INFO)
        device_type = self._detect_device()
        for _ in range(num):
            instance = self._launch_instance(device_type)
            port = instance.port
            start_wait = self.STARTUP_GRACE_SECONDS
            try:
                # Equivalent to sleep(start_wait), but returns early if the
                # process dies during boot.
                instance.process.wait(timeout=start_wait)
            except subprocess.TimeoutExpired:
                pass  # still running after the grace period: the normal case
            self.logger.log(f"[{port}] 인스턴스 실행 명시대기: {start_wait}초", level=logging.INFO)
            if self.wait_for_server_ready(port, timeout=wait_ready):
                self.logger.log(f"IOPaint 인스턴스 {port} 준비됨", level=logging.INFO)
            else:
                self.logger.log(f"IOPaint 인스턴스 {port} 시작 실패", level=logging.ERROR)
                self._log_process_output(instance)

    def get_instance_info(self):
        """Return one dict per instance with its port, busy flag, liveness and last-use time."""
        return [
            {
                "port": inst.port,
                "busy": inst.busy,
                "alive": inst.is_alive(),
                "last_used": inst.last_used,
            }
            for inst in self.instances
        ]

    def get_idle_instance(self):
        """Reserve and return the first live, non-busy instance, or None if there is none."""
        with self.lock:
            for inst in self.instances:
                if not inst.busy and inst.is_alive():
                    inst.mark_busy()
                    self.logger.log(f"IOPaint 인스턴스 {inst.port} 사용 중", level=logging.INFO)
                    return inst
        return None

    def mark_instance_idle(self, port):
        """Release the instance on ``port`` after its work has finished (no-op if unknown)."""
        with self.lock:
            inst = self._find_instance(port)
            if inst is not None:
                inst.mark_idle()
                self.logger.log(f"IOPaint 인스턴스 {inst.port} 유휴", level=logging.INFO)

    def shutdown_all(self):
        """Terminate every server, killing any that outlive the shutdown timeout."""
        with self.lock:
            instances, self.instances = self.instances, []
        for inst in instances:
            if inst.is_alive():
                inst.process.terminate()
                try:
                    inst.process.wait(timeout=self.SHUTDOWN_TIMEOUT_SECONDS)
                except subprocess.TimeoutExpired:
                    inst.process.kill()
                    inst.process.wait()
                self.logger.log(f"IOPaint 인스턴스 {inst.port} 종료", level=logging.INFO)
        self.logger.log("모든 IOPaint 인스턴스 종료", level=logging.INFO)

    def _load_array(self, source, label, grayscale=False):
        """Return ``source`` as an array, reading it with OpenCV when it is a path.

        Logs "<label> 로딩 실패" and returns None when the result is None.
        """
        if isinstance(source, str):
            array = cv2.imread(source, cv2.IMREAD_GRAYSCALE) if grayscale else cv2.imread(source)
        else:
            array = source
        if array is None:
            self.logger.log(f"{label} 로딩 실패: {source}", level=logging.ERROR)
        return array

    @staticmethod
    def _encode_png_b64(array):
        """PNG-encode ``array`` and return it as base64 text, or None if encoding fails."""
        try:
            ok, buf = cv2.imencode('.png', array)
        except cv2.error:
            return None
        if not ok:
            return None
        return base64.b64encode(buf).decode('utf-8')

    def _request_inpaint(self, port, img_b64, mask_b64, timeout):
        """POST the encoded pair to the server on ``port``; return the decoded image or None."""
        api_url = f"http://localhost:{port}/api/v1/inpaint"
        try:
            response = requests.post(api_url, json={"image": img_b64, "mask": mask_b64}, timeout=timeout)
        except requests.RequestException as e:
            self.logger.log(f"IOPaint 요청 실패: {e}", level=logging.ERROR)
            return None
        if response.status_code != 200:
            self.logger.log(f"IOPaint 서버 에러: {response.text}", level=logging.ERROR)
            return None
        result = cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR)
        if result is None:
            self.logger.log("IOPaint 응답 이미지 디코딩 실패", level=logging.ERROR)
        return result

    def inpaint(self, image, mask, instance=None, request_timeout=300) -> "np.ndarray | None":
        """Inpaint ``image`` where ``mask`` is set, using an IOPaint server.

        Args:
            image: file path (read with ``cv2.imread``) or an already-loaded array.
            mask: file path (read as grayscale) or an already-loaded array.
            instance: server to use; when None an idle one is reserved. In both
                cases the instance is marked idle before this method returns.
            request_timeout: seconds ``requests.post`` may wait for the server,
                so a hung server cannot keep an instance busy forever.

        Returns:
            The result decoded with ``cv2.IMREAD_COLOR``, or None on any
            failure (every failure is logged).
        """
        try:
            image_np = self._load_array(image, "이미지")
            if image_np is None:
                return None
            mask_np = self._load_array(mask, "마스크", grayscale=True)
            if mask_np is None:
                return None

            # Encode before reserving a server so bad input never occupies one.
            img_b64 = self._encode_png_b64(image_np)
            mask_b64 = self._encode_png_b64(mask_np)
            if img_b64 is None or mask_b64 is None:
                self.logger.log("이미지/마스크 PNG 인코딩 실패", level=logging.ERROR)
                return None

            if instance is None:
                instance = self.get_idle_instance()
                if instance is None:
                    self.logger.log("사용 가능한 IOPaint 인스턴스가 없습니다.", level=logging.ERROR)
                    return None

            self.logger.log(f"IOPaint 인스턴스 {instance.port} 사용", level=logging.INFO)
            return self._request_inpaint(instance.port, img_b64, mask_b64, request_timeout)
        finally:
            # Release a reserved or caller-supplied instance on every exit path.
            if instance is not None:
                self.mark_instance_idle(instance.port)

    def add_instance(self, wait_ready=30):
        """Start one more server, wait up to ``wait_ready`` seconds for it, and return it.

        The instance is returned (and stays tracked) even if it failed the
        readiness check; its captured output is logged in that case.
        """
        instance = self._launch_instance(self._detect_device())
        if self.wait_for_server_ready(instance.port, timeout=wait_ready):
            self.logger.log(f"IOPaint 인스턴스 {instance.port} 시작", level=logging.INFO)
        else:
            self.logger.log(f"IOPaint 인스턴스 {instance.port} 시작 실패", level=logging.ERROR)
            self._log_process_output(instance)
        return instance
# if __name__ == '__main__':
#     manager = IOPaintManager(logger, num_instances=1)  # logger: any object with .log(msg, level=...)
#     # result = manager.inpaint(image, mask)  # 자동으로 idle 인스턴스에 요청
#     print(manager.get_instance_info())
#     manager.shutdown_all()