# Source listing: 191 lines, 7.1 KiB, Python
import os
|
|
import time
|
|
import json
|
|
import argparse
|
|
import requests
|
|
|
|
|
|
def submit_job(api, img_path, prefix="detail", group="g1", seq=1, toggles=None):
    """Queue an image-processing job and return the job id the server assigns.

    Sends a POST to ``{api}/v1/process-image``. ``index`` is sent as the
    zero-based counterpart of the one-based ``seq``. ``toggle_overrides`` is
    only included when ``toggles`` is truthy. Non-2xx responses raise through
    ``raise_for_status``.
    """
    url = f"{api}/v1/process-image"
    overrides = {"toggle_overrides": toggles} if toggles else {}
    body = dict(
        file_path=img_path,
        index=seq - 1,
        file_prefix=prefix,
        group_id=group,
        seq=seq,
        **overrides,
    )
    response = requests.post(url, json=body, timeout=30)
    response.raise_for_status()
    return response.json()["job_id"]
|
|
|
|
|
|
def submit_rembg(api, img_path, prefix="thumb", toggles=None):
    """Queue a background-removal job and return the job id the server assigns.

    Sends ``file_path``/``file_prefix`` (plus ``toggle_overrides`` when
    ``toggles`` is truthy) to ``{api}/v1/remove-background``. Non-2xx
    responses raise through ``raise_for_status``.
    """
    extra = {"toggle_overrides": toggles} if toggles else {}
    body = {"file_path": img_path, "file_prefix": prefix, **extra}
    resp = requests.post(f"{api}/v1/remove-background", json=body, timeout=30)
    resp.raise_for_status()
    return resp.json()["job_id"]
|
|
|
|
|
|
def wait_job(api, job_id, timeout=300, poll_interval=0.2):
    """Poll ``GET {api}/v1/jobs/{job_id}`` until the job reaches a terminal status.

    Args:
        api: Base URL of the API (no trailing slash).
        job_id: Identifier returned by a submit call.
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between polls.

    Returns:
        The decoded job JSON once its ``status`` is ``"done"``, ``"error"``
        or ``"cancelled"``.

    Raises:
        TimeoutError: if no terminal status is observed within ``timeout``.
            The message names the job and the last status seen.
        requests.RequestException: if an individual poll request fails.
    """
    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    last_status = None
    while time.monotonic() < deadline:
        r = requests.get(f"{api}/v1/jobs/{job_id}", timeout=15)
        if r.status_code == 200:
            data = r.json()
            last_status = data.get("status")
            if last_status in ("done", "error", "cancelled"):
                return data
        else:
            # Non-200 replies are treated as "not ready yet" and polled again;
            # remember the code so a timeout can report it.
            last_status = f"HTTP {r.status_code}"
        time.sleep(poll_interval)
    raise TimeoutError(
        f"job wait timeout: job_id={job_id} last_status={last_status!r} after {timeout}s"
    )
|
|
|
|
|
|
def _parse_args():
    """Define the benchmark CLI and return the parsed argument namespace."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--api", default="http://127.0.0.1:8009", help="API base URL")
    parser.add_argument("--samples", nargs="*", default=None, help="sample image paths")
    parser.add_argument("--provider-ocr", default=None, help="force OCR provider: auto|dml|cpu")
    parser.add_argument("--provider-migan", default=None, help="force MIGAN provider: auto|dml|cpu (via env/.env)")
    parser.add_argument("--accel-migan", default=None, help="use accel for MIGAN: 1|0")
    parser.add_argument("--mode", default="both", choices=["translate", "rembg", "both"], help="test mode")
    parser.add_argument("--prefix", default="detail", help="file_prefix for translate (detail|option|thumb)")
    parser.add_argument("--rembg-prefix", default="thumb", help="file_prefix for rembg (thumb|detail)")
    parser.add_argument("--outdir", default=None, help="copy result files to this directory")
    return parser.parse_args()


def _check_health(api):
    """Query ``{api}/health``; return True only when it reports ``ready``.

    Prints the health payload, or a hint on how to start the server when the
    API is unreachable, returns an HTTP error, or returns invalid JSON.
    """
    try:
        hr = requests.get(f"{api}/health", timeout=10)
        hr.raise_for_status()
        h = hr.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a body that is not valid JSON.
        print(f"Cannot reach API at {api}. Start server: python main.py | error={e}", flush=True)
        return False
    print("health:", h, flush=True)
    if not isinstance(h, dict) or not h.get("ready"):
        print("Worker not ready. Start API server: python main.py", flush=True)
        return False
    return True


def _apply_runtime_providers(api, args):
    """Optionally re-initialise the OCR provider and reset MIGAN before the run."""
    if args.provider_ocr:
        rr = requests.post(f"{api}/v1/ocr/reinit", json={"provider": args.provider_ocr}, timeout=30)
        print("reinit ocr:", rr.status_code, rr.text, flush=True)
    if args.accel_migan is not None or args.provider_migan:
        # The --accel-migan flag decides use_cuda. The --provider-migan value is
        # not sent anywhere here (its help text says the provider is chosen via
        # env/.env); passing it only triggers the reset.
        use_accel = None
        if args.accel_migan is not None:
            use_accel = str(args.accel_migan).strip().lower() not in ("0", "false", "no", "off")
        mr = requests.post(f"{api}/v1/migan/reset", json={"use_cuda": use_accel}, timeout=30)
        print("reset migan:", mr.status_code, mr.text, flush=True)


def _resolve_samples(sample_args):
    """Return the existing sample files and report any that are missing.

    Without ``--samples``, defaults to test/1.jpg, test/2.jpg and test/5.jpg
    under the parent directory of this script.
    """
    root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    candidates = sample_args or [os.path.join(root, "test", name) for name in ("1.jpg", "2.jpg", "5.jpg")]
    present = [s for s in candidates if os.path.isfile(s)]
    missing = [s for s in candidates if not os.path.isfile(s)]
    if missing:
        print("Missing sample files:", missing, flush=True)
    return present


def _collect_results(api, label, tag, job_ids, outdir):
    """Wait for each ``(img, job_id)`` pair, print its outcome, and optionally copy outputs.

    Result files are copied to ``outdir`` as ``{tag}_{basename}``. A job that
    times out is reported and skipped, so one slow job does not abort the report.

    Returns:
        The job payloads returned by ``wait_job`` for jobs that finished.
    """
    import shutil

    results = []
    if not job_ids:
        return results
    print(f"\nWaiting {label} results...")
    for img, jid in job_ids:
        try:
            res = wait_job(api, jid, timeout=900)
        except TimeoutError as e:
            print(f" {label} timeout: {img} -> {jid} ({e})", flush=True)
            continue
        results.append(res)
        rr = res.get("result") or {}
        path = rr.get("path")
        # job_status is the terminal job state (done/error/cancelled); status is
        # whatever the job's "result" payload reports (None when absent).
        print(f" {label} done: {img} -> job_status={res.get('status')} status={rr.get('status')} path={path}")
        if outdir and path and os.path.isfile(path):
            shutil.copy2(path, os.path.join(outdir, f"{tag}_{os.path.basename(path)}"))
    return results


def _timing_values(results, key):
    """Return ``result.timings[key]`` as floats for every job that reports a numeric value."""
    vals = []
    for job in results:
        timings = (job.get("result") or {}).get("timings") or {}
        if key not in timings:
            continue
        try:
            vals.append(float(timings[key]))
        except (TypeError, ValueError):
            # Non-numeric timing entry: skip it rather than abort the report.
            pass
    return vals


def _percentile(sorted_vals, pct):
    """Return the nearest-rank ``pct`` percentile of a non-empty ascending list."""
    import math

    rank = math.ceil(pct * len(sorted_vals) / 100)
    return sorted_vals[max(rank, 1) - 1]


def _print_timing_stats(results):
    """Print avg/p50/p90 per pipeline stage for translate job results (milliseconds)."""
    keys = ["total_ms", "download", "ocr", "translate", "mask", "inpaint", "render", "save"]
    print("\nTiming stats (translate, ms):")
    for k in keys:
        vals = sorted(_timing_values(results, k))
        if not vals:
            continue
        avg = sum(vals) / len(vals)
        p50 = _percentile(vals, 50)
        p90 = _percentile(vals, 90)
        print(f" {k:10s} avg={avg:8.1f} p50={p50:8.1f} p90={p90:8.1f} n={len(vals)}")


def main():
    """Benchmark the image API end to end.

    The run has six steps:

    1. Health check.
    2. Optional OCR/MIGAN provider reset.
    3. Sample resolution.
    4. Submission of translate and/or rembg jobs.
    5. Result collection, optionally copying outputs to ``--outdir``.
    6. A timing summary.

    The summary covers submit-to-finish wall time for the whole batch, plus
    per-stage percentiles for translate jobs.
    """
    args = _parse_args()
    api = args.api.rstrip("/")

    # Preparation: health check
    if not _check_health(api):
        return

    # Optional runtime provider selection
    _apply_runtime_providers(api, args)

    # Load sample images
    samples = _resolve_samples(args.samples)
    if not samples:
        print("No sample images found. Use --samples with absolute paths.", flush=True)
        return

    # Submit jobs and measure total elapsed time (perf_counter: monotonic, high resolution)
    total_start = time.perf_counter()
    translate_ids, rembg_ids = [], []

    if args.mode in ("translate", "both"):
        print("\nSubmitting translate jobs...")
        for i, img in enumerate(samples, start=1):
            jid = submit_job(api, img, prefix=args.prefix, group="bench-t", seq=i)
            print(f" translate submitted: {img} -> {jid}")
            translate_ids.append((img, jid))

    if args.mode in ("rembg", "both"):
        print("\nSubmitting rembg jobs...")
        for img in samples:
            jid = submit_rembg(api, img, prefix=args.rembg_prefix)
            print(f" rembg submitted: {img} -> {jid}")
            rembg_ids.append((img, jid))

    outdir = args.outdir
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    # Collect results: translate, then rembg
    translate_results = _collect_results(api, "translate", "t", translate_ids, outdir)
    _collect_results(api, "rembg", "r", rembg_ids, outdir)

    total_s = time.perf_counter() - total_start
    total_jobs = len(translate_ids) + len(rembg_ids)
    if total_jobs:
        print(f"\nBatch finished: {total_jobs} jobs in {total_s:.2f}s (avg {total_s/total_jobs:.2f}s/job)")

    # Timing statistics (translate)
    if translate_results:
        _print_timing_stats(translate_results)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
|
|
|
|
|