# IMG_Worker/modules/test/concurrency_test.py
#
# 175 lines
# 6.5 KiB
# Python

import os
import time
import argparse
import asyncio
import json
import requests
def submit_job_sync(api, img_path, prefix="detail", group="g1", seq=1, toggles=None):
    """Submit one image to the ``/v1/process-image`` endpoint (blocking call).

    Args:
        api: Base URL that the endpoint path is appended to.
        img_path: Value sent as ``file_path``.
        prefix: Value sent as ``file_prefix``.
        group: Value sent as ``group_id``.
        seq: 1-based sequence number; ``index`` is sent as ``seq - 1``.
        toggles: Optional mapping sent as ``toggle_overrides`` when truthy.

    Returns:
        The ``job_id`` field of the JSON response.

    Raises:
        Whatever ``raise_for_status()`` raises for a non-success response.
    """
    endpoint = f"{api}/v1/process-image"
    body = {
        "file_path": img_path,
        "index": seq - 1,
        "file_prefix": prefix,
        "group_id": group,
        "seq": seq,
    }
    # An empty or missing toggle mapping is left out of the request entirely.
    if toggles:
        body = {**body, "toggle_overrides": toggles}

    response = requests.post(endpoint, json=body, timeout=30)
    response.raise_for_status()
    reply = response.json()
    return reply["job_id"]
def submit_rembg_sync(api, img_path, prefix="thumb", toggles=None):
    """Submit one image to the ``/v1/remove-background`` endpoint (blocking call).

    Args:
        api: Base URL that the endpoint path is appended to.
        img_path: Value sent as ``file_path``.
        prefix: Value sent as ``file_prefix``.
        toggles: Optional mapping sent as ``toggle_overrides`` when truthy.

    Returns:
        The ``job_id`` field of the JSON response.

    Raises:
        Whatever ``raise_for_status()`` raises for a non-success response.
    """
    endpoint = f"{api}/v1/remove-background"
    body = {"file_path": img_path, "file_prefix": prefix}
    # An empty or missing toggle mapping is left out of the request entirely.
    if toggles:
        body = {**body, "toggle_overrides": toggles}

    response = requests.post(endpoint, json=body, timeout=30)
    response.raise_for_status()
    reply = response.json()
    return reply["job_id"]
async def wait_job_async(session, api, job_id, timeout=900):
    """Poll the job-status endpoint until the job reaches a terminal status.

    Args:
        session: Open ``aiohttp.ClientSession`` used for every request.
        api: Base URL that ``/v1/jobs/<job_id>`` is appended to.
        job_id: Job identifier to poll.
        timeout: Overall wait budget in seconds.

    Returns:
        The decoded JSON body of the first HTTP 200 response whose
        ``status`` is ``"done"``, ``"error"`` or ``"cancelled"``.

    Raises:
        TimeoutError: If no terminal status is seen within ``timeout``.
    """
    import aiohttp

    url = f"{api}/v1/jobs/{job_id}"
    # aiohttp documents the per-request timeout as a ClientTimeout object;
    # passing a bare number is the deprecated form.
    poll_timeout = aiohttp.ClientTimeout(total=15)
    # Monotonic clock: wall-clock adjustments must not stretch or shrink
    # the overall deadline.
    deadline = time.monotonic() + timeout

    # Priming request; unlike the polling loop below, errors here propagate.
    async with session.get(url) as resp:
        await resp.text()

    while time.monotonic() < deadline:
        try:
            async with session.get(url, timeout=poll_timeout) as r:
                if r.status == 200:
                    data = await r.json()
                    if data.get("status") in ("done", "error", "cancelled"):
                        return data
        except Exception:
            # Best-effort polling: a failed poll is retried after an extra
            # back-off on top of the regular interval below.
            await asyncio.sleep(0.2)
        await asyncio.sleep(0.2)
    raise TimeoutError("job wait timeout")
async def run_concurrent(api, images, mode="translate", prefix="detail", rembg_prefix="thumb", outdir=None, toggles=None, concurrency=4):
    """Submit every image and wait for its job, at most ``concurrency`` at a time.

    Args:
        api: Base URL of the worker API.
        images: Iterable of image paths to submit.
        mode: ``"translate"``, ``"rembg"``, or anything else for "both".
        prefix: ``file_prefix`` used for translate submissions.
        rembg_prefix: ``file_prefix`` used for background-removal submissions.
        outdir: Optional directory that finished result files are copied into.
        toggles: Optional toggle overrides forwarded to the submit helpers.
        concurrency: Maximum number of images in flight at once.

    Returns:
        A list with one ``(img, result)`` tuple per image, in input order.
        In "translate" and "rembg" mode ``result`` is the final job payload;
        in "both" mode it is ``(translate_job_id, rembg_job_id)`` and the
        jobs are not waited on.
    """
    import shutil

    import aiohttp

    sem = asyncio.Semaphore(concurrency)

    async def submit_and_wait(idx, img, session):
        # The semaphore is held across submit AND wait, so it bounds the
        # number of jobs in flight, not just the number of submissions.
        async with sem:
            # Submit: the synchronous HTTP helpers run in the default thread pool.
            loop = asyncio.get_running_loop()
            if mode == "translate":
                jid = await loop.run_in_executor(None, submit_job_sync, api, img, prefix, "grp-t", idx + 1, toggles)
            elif mode == "rembg":
                jid = await loop.run_in_executor(None, submit_rembg_sync, api, img, rembg_prefix, toggles)
            else:
                # both: submit the translate job first, then the rembg job;
                # neither is waited on.
                jid_t = await loop.run_in_executor(None, submit_job_sync, api, img, prefix, "grp-t", idx + 1, toggles)
                jid_r = await loop.run_in_executor(None, submit_rembg_sync, api, img, rembg_prefix, toggles)
                return (img, (jid_t, jid_r))

            # Wait for the job over the shared aiohttp session.
            res = await wait_job_async(session, api, jid)
            if outdir:
                rr = res.get("result") or {}
                # The path may sit one level deeper; tolerate a null nested result.
                path = rr.get("path") or (rr.get("result") or {}).get("path")
                if path and os.path.isfile(path):
                    tag = "t_" if mode == "translate" else "r_"
                    dest = os.path.join(outdir, f"{tag}{os.path.basename(path)}")
                    try:
                        shutil.copy2(path, dest)
                    except OSError as exc:
                        # Copying is best-effort, but report the failure
                        # instead of hiding it.
                        print(f"warning: could not copy {path} -> {dest}: {exc}", flush=True)
            return (img, res)

    # One session for the whole batch instead of one per image.
    async with aiohttp.ClientSession() as session:
        tasks = [submit_and_wait(i, img, session) for i, img in enumerate(images)]
        return await asyncio.gather(*tasks)
def main():
    """Command-line entry point: health-check the API, run the batch, print stats.

    Parses the command line, verifies that ``<api>/health`` reports ready,
    selects the sample images that exist on disk, runs ``run_concurrent``
    and prints one line per image followed by a short summary. Returns
    ``None`` in every case; an invalid ``--concurrency`` exits through
    ``argparse`` (status 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--api", default="http://127.0.0.1:8009")
    parser.add_argument("--samples", nargs="*", default=None)
    parser.add_argument("--mode", default="translate", choices=["translate", "rembg", "both"])
    parser.add_argument("--prefix", default="detail")
    parser.add_argument("--rembg-prefix", default="thumb")
    parser.add_argument("--outdir", default=None)
    parser.add_argument("--concurrency", type=int, default=4)
    parser.add_argument("--font", default=None, help="폰트 타입(예: 폰트3). toggle_overrides.font_type")
    args = parser.parse_args()
    # A semaphore created with 0 permits would block every task forever and a
    # negative value raises ValueError, so reject both before doing any work.
    if args.concurrency < 1:
        parser.error("--concurrency must be >= 1")
    api = args.api.rstrip("/")

    # Health check
    try:
        hr = requests.get(f"{api}/health", timeout=10)
        hr.raise_for_status()
        h = hr.json()
        if not h.get("ready"):
            print("Worker not ready. Start API server: python main.py", flush=True)
            return
    except Exception as e:
        print(f"Cannot reach API at {api}. Start server: python main.py | error={e}", flush=True)
        return

    # Sample images: defaults are resolved relative to this file's parent directory.
    root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    default_samples = [
        os.path.join(root, "test", "1.jpg"),
        os.path.join(root, "test", "2.jpg"),
        os.path.join(root, "test", "5.jpg"),
    ]
    samples = args.samples or default_samples
    # Paths that are not existing files are dropped.
    samples = [s for s in samples if os.path.isfile(s)]
    if not samples:
        print("No sample images found. Use --samples with absolute paths.")
        return
    if args.outdir:
        os.makedirs(args.outdir, exist_ok=True)

    # Toggle overrides
    toggles = {}
    if args.font:
        toggles["font_type"] = args.font

    # Run; perf_counter is the clock intended for measuring elapsed time.
    total_start = time.perf_counter()
    results = asyncio.run(
        run_concurrent(
            api,
            samples,
            mode=args.mode,
            prefix=args.prefix,
            rembg_prefix=args.rembg_prefix,
            outdir=args.outdir,
            toggles=toggles,
            concurrency=args.concurrency,
        )
    )
    total_s = time.perf_counter() - total_start

    # Output and simple statistics
    done = 0
    timings = {"download": [], "ocr": [], "translate": [], "mask": [], "inpaint": [], "render": [], "save": [], "total_ms": []}
    for img, res in results:
        # "both" mode returns a tuple of job ids instead of a job payload.
        if isinstance(res, tuple):
            print(f"submitted both for: {img} -> translate_jid={res[0]} rembg_jid={res[1]}")
            continue
        rr = (res or {}).get("result") or {}
        if rr.get("status") in ("translated", "removed"):
            done += 1
        t = rr.get("timings") or {}
        for key, bucket in timings.items():
            v = t.get(key)
            if v is not None:
                try:
                    bucket.append(float(v))
                except (TypeError, ValueError, OverflowError):
                    # Non-numeric timing values are skipped.
                    pass
        print(f"done: {img} -> status={rr.get('status')} path={rr.get('path')} timings_keys={list((rr.get('timings') or {}).keys())}")
    print(f"\nConcurrent batch finished: {done}/{len(results)} succeeded in {total_s:.2f}s")
    if timings["total_ms"]:
        avg = sum(timings["total_ms"]) / len(timings["total_ms"])
        print(f"avg total_ms: {avg:.1f} (n={len(timings['total_ms'])})")
# Run the concurrency test only when this file is executed as a script.
if __name__ == "__main__":
    main()