# IMG_Worker/modules/test/perf_test.py
#
# (repository-viewer metadata: 191 lines, 7.1 KiB, Python)

import argparse
import json
import os
import shutil
import time

import requests
def submit_job(api, img_path, prefix="detail", group="g1", seq=1, toggles=None):
    """Enqueue a translate job via ``POST {api}/v1/process-image``.

    Args:
        api: Base URL of the worker API; the endpoint path is appended directly,
            so pass it without a trailing slash.
        img_path: Source image path, sent as ``file_path``.
        prefix: Value sent as ``file_prefix``.
        group: Value sent as ``group_id``.
        seq: 1-based sequence number; ``index`` is sent as ``seq - 1``.
        toggles: Optional overrides, sent as ``toggle_overrides`` only when truthy.

    Returns:
        The ``job_id`` field of the JSON response.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        KeyError: if the response JSON has no ``job_id``.
    """
    body = dict(
        file_path=img_path,
        index=seq - 1,
        file_prefix=prefix,
        group_id=group,
        seq=seq,
    )
    # Leave the key out entirely when there is nothing to override.
    if toggles:
        body.update(toggle_overrides=toggles)
    resp = requests.post(f"{api}/v1/process-image", json=body, timeout=30)
    resp.raise_for_status()
    return resp.json()["job_id"]
def submit_rembg(api, img_path, prefix="thumb", toggles=None):
    """Enqueue a background-removal job via ``POST {api}/v1/remove-background``.

    Args:
        api: Base URL of the worker API, without a trailing slash.
        img_path: Source image path, sent as ``file_path``.
        prefix: Value sent as ``file_prefix``.
        toggles: Optional overrides, sent as ``toggle_overrides`` only when truthy.

    Returns:
        The ``job_id`` field of the JSON response.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        KeyError: if the response JSON has no ``job_id``.
    """
    endpoint = f"{api}/v1/remove-background"
    body = {"file_path": img_path, "file_prefix": prefix}
    if toggles:
        body["toggle_overrides"] = toggles
    response = requests.post(endpoint, json=body, timeout=30)
    response.raise_for_status()
    job_id = response.json()["job_id"]
    return job_id
def wait_job(api, job_id, timeout=300, interval=0.2):
    """Poll ``GET {api}/v1/jobs/{job_id}`` until the job reaches a terminal status.

    Args:
        api: Base URL of the worker API, without a trailing slash.
        job_id: Identifier returned by one of the submit endpoints.
        timeout: Overall wait budget in seconds. A single in-flight request
            (15 s request timeout) may run past this deadline.
        interval: Seconds to sleep between polls.

    Returns:
        The job JSON whose ``status`` is ``"done"``, ``"error"`` or ``"cancelled"``.

    Raises:
        TimeoutError: if no terminal status is seen before the deadline.
        requests.RequestException: on network failures (not retried).

    Non-200 responses are not treated as failures; they are simply polled again.
    """
    terminal = ("done", "error", "cancelled")
    # monotonic() cannot jump with wall-clock adjustments (NTP, manual changes),
    # which with time.time() could silently shorten or extend the deadline.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        resp = requests.get(f"{api}/v1/jobs/{job_id}", timeout=15)
        if resp.status_code == 200:
            data = resp.json()
            if data.get("status") in terminal:
                return data
        time.sleep(interval)
    # Name the job so a hang in a multi-job batch can be pinned down.
    raise TimeoutError(f"job wait timeout: job_id={job_id} after {timeout}s")
def _parse_args():
    """Build the CLI parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--api", default="http://127.0.0.1:8009", help="API base URL")
    parser.add_argument("--samples", nargs="*", default=None, help="sample image paths")
    parser.add_argument("--provider-ocr", default=None, help="force OCR provider: auto|dml|cpu")
    parser.add_argument("--provider-migan", default=None, help="force MIGAN provider: auto|dml|cpu (via env/.env)")
    parser.add_argument("--accel-migan", default=None, help="use accel for MIGAN: 1|0")
    parser.add_argument("--mode", default="both", choices=["translate", "rembg", "both"], help="test mode")
    parser.add_argument("--prefix", default="detail", help="file_prefix for translate (detail|option|thumb)")
    parser.add_argument("--rembg-prefix", default="thumb", help="file_prefix for rembg (thumb|detail)")
    parser.add_argument("--outdir", default=None, help="copy result files to this directory")
    return parser.parse_args()


def _check_health(api):
    """Return True if ``GET {api}/health`` succeeds and reports ``ready``; print the reason otherwise."""
    try:
        resp = requests.get(f"{api}/health", timeout=10)
        resp.raise_for_status()
        health = resp.json()
    except (requests.RequestException, ValueError) as e:  # ValueError: body is not JSON
        print(f"Cannot reach API at {api}. Start server: python main.py | error={e}", flush=True)
        return False
    print("health:", health, flush=True)
    if not (isinstance(health, dict) and health.get("ready")):
        print("Worker not ready. Start API server: python main.py", flush=True)
        return False
    return True


def _configure_providers(api, args):
    """Apply the optional runtime provider overrides (OCR reinit, MIGAN reset)."""
    if args.provider_ocr:
        resp = requests.post(f"{api}/v1/ocr/reinit", json={"provider": args.provider_ocr}, timeout=30)
        print("reinit ocr:", resp.status_code, resp.text)
    if args.accel_migan is not None or args.provider_migan:
        # --accel-migan takes precedence; with only --provider-migan, use_cuda is sent
        # as None. The --provider-migan value itself is never transmitted (its help text
        # says it is applied "via env/.env" — presumably on the worker side).
        use_accel = None
        if args.accel_migan is not None:
            # Case-insensitive so that e.g. "FALSE" or "off" does not enable CUDA.
            use_accel = str(args.accel_migan).strip().lower() not in ("0", "false", "no", "off")
        resp = requests.post(f"{api}/v1/migan/reset", json={"use_cuda": use_accel}, timeout=30)
        print("reset migan:", resp.status_code, resp.text)


def _resolve_samples(paths):
    """Return the existing files among *paths* (bundled defaults if empty); print the missing ones."""
    if not paths:
        root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        paths = [os.path.join(root, "test", name) for name in ("1.jpg", "2.jpg", "5.jpg")]
    found, missing = [], []
    for path in paths:
        (found if os.path.isfile(path) else missing).append(path)
    if missing:
        print("Missing sample files:", missing, flush=True)
    return found


def _submit_all(samples, label, submit):
    """Call ``submit(img, seq)`` (seq is 1-based) for every sample; return ``[(img, job_id), ...]``."""
    print(f"\nSubmitting {label} jobs...")
    ids = []
    for seq, img in enumerate(samples, start=1):
        jid = submit(img, seq)
        print(f" {label} submitted: {img} -> {jid}")
        ids.append((img, jid))
    return ids


def _collect_results(api, ids, label, outdir, copy_prefix):
    """Wait for every ``(img, job_id)`` in *ids*, print its outcome, optionally copy the output.

    Returns the list of job JSON objects produced by ``wait_job``, in submission order.
    """
    results = []
    if not ids:
        return results
    print(f"\nWaiting {label} results...")
    for img, jid in ids:
        res = wait_job(api, jid, timeout=900)
        results.append(res)
        result = res.get("result") or {}
        path = result.get("path")
        # Print the job status too: an error/cancelled job may carry no result payload,
        # in which case the result status alone would just read "None".
        print(f" {label} done: {img} -> job={res.get('status')} status={result.get('status')} path={path}")
        # The path is checked on the local filesystem, so this copies only when the
        # worker writes somewhere visible from this machine.
        if outdir and path and os.path.isfile(path):
            shutil.copy2(path, os.path.join(outdir, f"{copy_prefix}{os.path.basename(path)}"))
    return results


def _timing_values(results, key):
    """Return ``result.timings[key]`` as floats across *results*, skipping missing/non-numeric values."""
    vals = []
    for res in results:
        timings = (res.get("result") or {}).get("timings") or {}
        if key not in timings:
            continue
        try:
            vals.append(float(timings[key]))
        except (TypeError, ValueError, OverflowError):
            pass  # non-numeric timing entry: leave it out of the stats
    return vals


def _nearest_rank(sorted_vals, pct):
    """Return the *pct*-th percentile (nearest-rank) of a non-empty ascending list; *pct* is an int."""
    # ceil(pct * n / 100) in exact integer arithmetic (no float rounding at boundaries).
    rank = -(-pct * len(sorted_vals) // 100)
    return sorted_vals[max(rank, 1) - 1]


def _print_timing_stats(results):
    """Print avg / p50 / p90 (ms) for each known translate timing key present in *results*."""
    keys = ["total_ms", "download", "ocr", "translate", "mask", "inpaint", "render", "save"]
    print("\nTiming stats (translate, ms):")
    for key in keys:
        vals = sorted(_timing_values(results, key))
        if not vals:
            continue
        n = len(vals)
        avg = sum(vals) / n
        p50 = vals[n // 2]  # upper median when n is even
        p90 = _nearest_rank(vals, 90)
        print(f" {key:10s} avg={avg:8.1f} p50={p50:8.1f} p90={p90:8.1f} n={n}")


def main():
    """Benchmark the worker API: submit sample jobs, wait for them, report timings.

    Steps (configured by the CLI flags, see ``--help``):
      1. ``GET {api}/health``; stop unless it answers with a truthy ``ready``.
      2. Optionally reinitialise the OCR provider and/or reset MIGAN.
      3. Submit translate and/or rembg jobs for every existing sample image.
      4. Wait for every job (900 s each), optionally copying outputs into ``--outdir``
         with a ``t_`` / ``r_`` filename prefix.
      5. Print batch wall time and per-stage translate timing statistics.
    """
    args = _parse_args()
    api = args.api.rstrip("/")

    if not _check_health(api):
        return
    _configure_providers(api, args)

    samples = _resolve_samples(args.samples)
    if not samples:
        print("No sample images found. Use --samples with absolute paths.", flush=True)
        return

    # Create the output directory before submitting so a bad --outdir fails fast.
    outdir = args.outdir
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    # Every job is submitted before any is awaited, so the measured time spans the
    # whole batch rather than one job at a time.
    total_start = time.perf_counter()
    translate_ids, rembg_ids = [], []
    if args.mode in ("translate", "both"):
        translate_ids = _submit_all(
            samples,
            "translate",
            lambda img, seq: submit_job(api, img, prefix=args.prefix, group="bench-t", seq=seq),
        )
    if args.mode in ("rembg", "both"):
        rembg_ids = _submit_all(
            samples,
            "rembg",
            lambda img, _seq: submit_rembg(api, img, prefix=args.rembg_prefix),
        )

    translate_results = _collect_results(api, translate_ids, "translate", outdir, "t_")
    _collect_results(api, rembg_ids, "rembg", outdir, "r_")

    total_s = time.perf_counter() - total_start
    total_jobs = len(translate_ids) + len(rembg_ids)
    if total_jobs:
        print(f"\nBatch finished: {total_jobs} jobs in {total_s:.2f}s (avg {total_s/total_jobs:.2f}s/job)")

    if translate_results:
        _print_timing_stats(translate_results)
# Script entry point: run the benchmark only when executed directly, not on import.
if "__main__" == __name__:
    main()