118 lines
4.2 KiB
Python
118 lines
4.2 KiB
Python
import requests
|
||
import base64
|
||
import time
|
||
import os
|
||
from PIL import Image
|
||
from io import BytesIO
|
||
|
||
# --- 설정 ---
|
||
SERVER_URL = "http://127.0.0.1:8008/api/v1/remove_bg"
|
||
IMAGE_PATH = "tests/rembg_test/456.webp"
|
||
OUTPUT_PATH = "tests/rembg_test/output_bria.png"
|
||
NUM_TESTS = 5
|
||
|
||
# --- 테스트 준비 ---
|
||
def image_to_base64(filepath):
|
||
""" 이미지를 읽어 base64로 인코딩합니다. """
|
||
try:
|
||
with Image.open(filepath) as img:
|
||
# RGBA를 가질 수 있는 경우 RGB로 변환하여 데이터 일관성 확보
|
||
if img.mode == 'RGBA':
|
||
img = img.convert('RGB')
|
||
|
||
buffered = BytesIO()
|
||
img.save(buffered, format="PNG")
|
||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||
except FileNotFoundError:
|
||
print(f"오류: 테스트 이미지 파일을 찾을 수 없습니다 - {filepath}")
|
||
return None
|
||
except Exception as e:
|
||
print(f"오류: 이미지 처리 중 문제 발생 - {e}")
|
||
return None
|
||
|
||
def save_base64_image(base64_string, filepath):
|
||
""" base64 문자열을 이미지 파일로 저장합니다. """
|
||
try:
|
||
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
||
img_data = base64.b64decode(base64_string)
|
||
with open(filepath, 'wb') as f:
|
||
f.write(img_data)
|
||
print(f"✅ 결과 이미지가 '{filepath}'에 저장되었습니다.")
|
||
except Exception as e:
|
||
print(f"오류: 결과 이미지 저장 실패 - {e}")
|
||
|
||
# --- 테스트 실행 ---
|
||
if __name__ == "__main__":
|
||
print("--- Bria RMBG ONNX 모델 배경 제거 테스트 시작 ---")
|
||
|
||
# 1. 이미지 인코딩
|
||
print(f"테스트 이미지 로딩: {IMAGE_PATH}")
|
||
b64_image = image_to_base64(IMAGE_PATH)
|
||
|
||
if not b64_image:
|
||
exit()
|
||
|
||
# 2. API 요청 및 시간 측정
|
||
timings = []
|
||
last_response_image = None
|
||
|
||
for i in range(NUM_TESTS):
|
||
print(f"[{i+1}/{NUM_TESTS}] 요청 전송 중...", end=" ", flush=True)
|
||
|
||
payload = {
|
||
"image": b64_image,
|
||
}
|
||
|
||
try:
|
||
start_time = time.perf_counter()
|
||
response = requests.post(SERVER_URL, json=payload, params={"response_format": "base64", "image_format": "png"})
|
||
end_time = time.perf_counter()
|
||
|
||
duration = (end_time - start_time) * 1000 # ms 단위로 변환
|
||
timings.append(duration)
|
||
|
||
response.raise_for_status()
|
||
|
||
response_data = response.json()
|
||
last_response_image = response_data.get("image")
|
||
|
||
if i == 0:
|
||
print(f"성공! (Cold Start): {duration:.2f} ms")
|
||
else:
|
||
print(f"성공!: {duration:.2f} ms")
|
||
|
||
except requests.exceptions.RequestException as e:
|
||
print(f"실패. API 요청 오류: {e}")
|
||
if hasattr(e, 'response') and e.response is not None:
|
||
print("서버 응답:", e.response.text)
|
||
break
|
||
except Exception as e:
|
||
print(f"실패. 예상치 못한 오류: {e}")
|
||
break
|
||
|
||
# 3. 결과 분석 및 저장
|
||
if timings:
|
||
print("\n--- 테스트 결과 분석 ---")
|
||
print(f"총 요청 횟수: {len(timings)}회")
|
||
if timings:
|
||
print(f"첫 요청 시간 (Cold Start): {timings[0]:.2f} ms")
|
||
|
||
if len(timings) > 1:
|
||
warm_timings = timings[1:]
|
||
avg_warm_time = sum(warm_timings) / len(warm_timings)
|
||
min_warm_time = min(warm_timings)
|
||
max_warm_time = max(warm_timings)
|
||
print(f"안정화된 추론 시간 (평균): {avg_warm_time:.2f} ms")
|
||
print(f"안정화된 추론 시간 (최소): {min_warm_time:.2f} ms")
|
||
print(f"안정화된 추론 시간 (최대): {max_warm_time:.2f} ms")
|
||
|
||
if last_response_image:
|
||
save_base64_image(last_response_image, OUTPUT_PATH)
|
||
else:
|
||
print("오류: 마지막 요청에서 이미지를 받지 못해 결과 파일을 저장할 수 없습니다.")
|
||
else:
|
||
print("\n테스트가 실행되지 않았습니다.")
|
||
|
||
print("\n--- 테스트 종료 ---")
|
||
print("ℹ️ 서버 VRAM 사용량 및 모델 로딩 시간은 'main_server.log'를 확인해주세요.")
|