inpaintServer/tests/test_rembg_onnx.py

118 lines
4.2 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import base64
import time
import os
from PIL import Image
from io import BytesIO
# --- 설정 ---
SERVER_URL = "http://127.0.0.1:8008/api/v1/remove_bg"
IMAGE_PATH = "tests/rembg_test/456.webp"
OUTPUT_PATH = "tests/rembg_test/output_bria.png"
NUM_TESTS = 5
# --- 테스트 준비 ---
def image_to_base64(filepath):
""" 이미지를 읽어 base64로 인코딩합니다. """
try:
with Image.open(filepath) as img:
# RGBA를 가질 수 있는 경우 RGB로 변환하여 데이터 일관성 확보
if img.mode == 'RGBA':
img = img.convert('RGB')
buffered = BytesIO()
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
except FileNotFoundError:
print(f"오류: 테스트 이미지 파일을 찾을 수 없습니다 - {filepath}")
return None
except Exception as e:
print(f"오류: 이미지 처리 중 문제 발생 - {e}")
return None
def save_base64_image(base64_string, filepath):
""" base64 문자열을 이미지 파일로 저장합니다. """
try:
os.makedirs(os.path.dirname(filepath), exist_ok=True)
img_data = base64.b64decode(base64_string)
with open(filepath, 'wb') as f:
f.write(img_data)
print(f"✅ 결과 이미지가 '{filepath}'에 저장되었습니다.")
except Exception as e:
print(f"오류: 결과 이미지 저장 실패 - {e}")
# --- 테스트 실행 ---
if __name__ == "__main__":
print("--- Bria RMBG ONNX 모델 배경 제거 테스트 시작 ---")
# 1. 이미지 인코딩
print(f"테스트 이미지 로딩: {IMAGE_PATH}")
b64_image = image_to_base64(IMAGE_PATH)
if not b64_image:
exit()
# 2. API 요청 및 시간 측정
timings = []
last_response_image = None
for i in range(NUM_TESTS):
print(f"[{i+1}/{NUM_TESTS}] 요청 전송 중...", end=" ", flush=True)
payload = {
"image": b64_image,
}
try:
start_time = time.perf_counter()
response = requests.post(SERVER_URL, json=payload, params={"response_format": "base64", "image_format": "png"})
end_time = time.perf_counter()
duration = (end_time - start_time) * 1000 # ms 단위로 변환
timings.append(duration)
response.raise_for_status()
response_data = response.json()
last_response_image = response_data.get("image")
if i == 0:
print(f"성공! (Cold Start): {duration:.2f} ms")
else:
print(f"성공!: {duration:.2f} ms")
except requests.exceptions.RequestException as e:
print(f"실패. API 요청 오류: {e}")
if hasattr(e, 'response') and e.response is not None:
print("서버 응답:", e.response.text)
break
except Exception as e:
print(f"실패. 예상치 못한 오류: {e}")
break
# 3. 결과 분석 및 저장
if timings:
print("\n--- 테스트 결과 분석 ---")
print(f"총 요청 횟수: {len(timings)}")
if timings:
print(f"첫 요청 시간 (Cold Start): {timings[0]:.2f} ms")
if len(timings) > 1:
warm_timings = timings[1:]
avg_warm_time = sum(warm_timings) / len(warm_timings)
min_warm_time = min(warm_timings)
max_warm_time = max(warm_timings)
print(f"안정화된 추론 시간 (평균): {avg_warm_time:.2f} ms")
print(f"안정화된 추론 시간 (최소): {min_warm_time:.2f} ms")
print(f"안정화된 추론 시간 (최대): {max_warm_time:.2f} ms")
if last_response_image:
save_base64_image(last_response_image, OUTPUT_PATH)
else:
print("오류: 마지막 요청에서 이미지를 받지 못해 결과 파일을 저장할 수 없습니다.")
else:
print("\n테스트가 실행되지 않았습니다.")
print("\n--- 테스트 종료 ---")
print(" 서버 VRAM 사용량 및 모델 로딩 시간은 'main_server.log'를 확인해주세요.")