175 lines
6.5 KiB
Python
175 lines
6.5 KiB
Python
import os
|
|
import time
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import requests
|
|
|
|
|
|
def submit_job_sync(api, img_path, prefix="detail", group="g1", seq=1, toggles=None):
    """Submit one image to the process-image endpoint and return its job id.

    This is a blocking call made with ``requests``; it POSTs a JSON body to
    ``{api}/v1/process-image``.

    Args:
        api: Base URL of the API server; the endpoint path is appended as-is.
        img_path: Sent as ``file_path``.
        prefix: Sent as ``file_prefix``.
        group: Sent as ``group_id``.
        seq: Sent as ``seq``; ``index`` is sent as ``seq - 1``.
        toggles: Optional mapping sent as ``toggle_overrides``. The key is
            left out of the body when this is empty or ``None``.

    Returns:
        The ``job_id`` field of the JSON response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        KeyError: If the response JSON has no ``job_id`` field.
    """
    body = dict(
        file_path=img_path,
        index=seq - 1,
        file_prefix=prefix,
        group_id=group,
        seq=seq,
    )
    if toggles:
        body.update(toggle_overrides=toggles)

    response = requests.post(f"{api}/v1/process-image", json=body, timeout=30)
    response.raise_for_status()
    job = response.json()
    return job["job_id"]
|
|
|
|
|
|
def submit_rembg_sync(api, img_path, prefix="thumb", toggles=None):
    """Submit one image to the remove-background endpoint and return its job id.

    This is a blocking call made with ``requests``; it POSTs a JSON body to
    ``{api}/v1/remove-background``.

    Args:
        api: Base URL of the API server; the endpoint path is appended as-is.
        img_path: Sent as ``file_path``.
        prefix: Sent as ``file_prefix``.
        toggles: Optional mapping sent as ``toggle_overrides``. The key is
            left out of the body when this is empty or ``None``.

    Returns:
        The ``job_id`` field of the JSON response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        KeyError: If the response JSON has no ``job_id`` field.
    """
    body = dict(file_path=img_path, file_prefix=prefix)
    if toggles:
        body.update(toggle_overrides=toggles)

    response = requests.post(f"{api}/v1/remove-background", json=body, timeout=30)
    response.raise_for_status()
    job = response.json()
    return job["job_id"]
|
|
|
|
|
|
async def wait_job_async(session, api, job_id, timeout=900):
    """Poll a job's status endpoint until it reaches a terminal state.

    Args:
        session: HTTP client session whose ``get(url, ...)`` returns an async
            context manager yielding a response with ``status``, ``text()``
            and ``json()`` (an ``aiohttp.ClientSession`` in this script).
        api: Base URL of the API server.
        job_id: Identifier of the job to poll at ``{api}/v1/jobs/{job_id}``.
        timeout: Overall wait budget in seconds.

    Returns:
        The decoded JSON body of the first HTTP 200 response whose ``status``
        field is ``"done"``, ``"error"`` or ``"cancelled"``.

    Raises:
        TimeoutError: If no terminal status is seen within ``timeout`` seconds.
    """
    terminal_states = ("done", "error", "cancelled")
    poll_interval = 0.2
    url = f"{api}/v1/jobs/{job_id}"
    # Monotonic clock: the deadline must not move when the wall clock is adjusted.
    deadline = time.monotonic() + timeout

    # Priming request. Best-effort only: a transient failure here must not
    # abort the wait, because the loop below retries the very same request.
    try:
        async with session.get(url) as resp:
            await resp.text()
    except Exception:
        pass

    while time.monotonic() < deadline:
        try:
            async with session.get(url, timeout=15) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    if data.get("status") in terminal_states:
                        return data
        except Exception:
            # Deliberately broad: connection errors, per-request timeouts and
            # malformed bodies are all treated as "not ready yet". Pause a
            # little longer than usual before the next attempt.
            await asyncio.sleep(poll_interval)
        await asyncio.sleep(poll_interval)
    raise TimeoutError("job wait timeout")
|
|
|
|
|
|
def _copy_result_to_outdir(res, outdir, mode):
    """Copy a finished job's output file into ``outdir`` (best-effort)."""
    import shutil
    import sys

    result = res.get("result") or {}
    # The path may sit directly on the result or one level deeper; tolerate a
    # nested "result" that is null or not a mapping.
    nested = result.get("result")
    if not isinstance(nested, dict):
        nested = {}
    path = result.get("path") or nested.get("path")
    if not path or not os.path.isfile(path):
        return

    tag = "t_" if mode == "translate" else "r_"
    target = os.path.join(outdir, f"{tag}{os.path.basename(path)}")
    try:
        shutil.copy2(path, target)
    except OSError as exc:
        # A failed copy must not fail the batch, but it should be visible.
        print(f"warning: could not copy {path} -> {target}: {exc}", file=sys.stderr, flush=True)


async def run_concurrent(api, images, mode="translate", prefix="detail", rembg_prefix="thumb", outdir=None, toggles=None, concurrency=4):
    """Submit every image as a job and wait for the results concurrently.

    Args:
        api: Base URL of the API server.
        images: Image paths to submit; the position in this sequence (plus
            one) is used as the ``seq`` of translate jobs.
        mode: ``"translate"`` or ``"rembg"`` submits one job per image and
            waits for it. Any other value (``"both"`` on the command line)
            submits a translate job and then a rembg job and returns their
            ids without waiting.
        prefix: File prefix passed to ``submit_job_sync``.
        rembg_prefix: File prefix passed to ``submit_rembg_sync``.
        outdir: If set, the output file of each finished job is copied here,
            prefixed with ``t_`` (translate) or ``r_`` (otherwise).
        toggles: Optional toggle overrides forwarded to the submit helpers.
        concurrency: Maximum number of images in the submit-and-wait phase at
            once. Values below 1 are treated as 1.

    Returns:
        A list with one tuple per image, in input order: ``(img, job_data)``
        for waited jobs, or ``(img, (translate_job_id, rembg_job_id))`` when
        both jobs were submitted without waiting.

    Raises:
        Whatever the submit helpers or ``wait_job_async`` raise; the first
        such exception aborts the batch.
    """
    import aiohttp

    # A semaphore created with 0 would block every job forever, so always
    # allow at least one job through.
    sem = asyncio.Semaphore(max(1, concurrency))

    async def submit_and_wait(session, idx, img):
        async with sem:
            # Submit: the helpers use blocking HTTP, so run them in the
            # default thread pool.
            loop = asyncio.get_running_loop()
            if mode == "translate":
                jid = await loop.run_in_executor(None, submit_job_sync, api, img, prefix, "grp-t", idx + 1, toggles)
            elif mode == "rembg":
                jid = await loop.run_in_executor(None, submit_rembg_sync, api, img, rembg_prefix, toggles)
            else:
                # both: translate first, then rembg as well; no waiting.
                jid_t = await loop.run_in_executor(None, submit_job_sync, api, img, prefix, "grp-t", idx + 1, toggles)
                jid_r = await loop.run_in_executor(None, submit_rembg_sync, api, img, rembg_prefix, toggles)
                return (img, (jid_t, jid_r))

            # Wait for the job over the shared aiohttp session.
            res = await wait_job_async(session, api, jid)
            if outdir:
                _copy_result_to_outdir(res, outdir, mode)
            return (img, res)

    # One session for the whole batch so connections are pooled and reused
    # instead of being opened and torn down per job.
    async with aiohttp.ClientSession() as session:
        jobs = [submit_and_wait(session, i, img) for i, img in enumerate(images)]
        return await asyncio.gather(*jobs)
|
|
|
|
|
|
def main():
    """Command-line entry point: run one concurrent batch and print a summary.

    Steps, in order: parse the arguments, check ``{api}/health``, keep the
    sample images that exist on disk, run ``run_concurrent``, then print one
    line per image followed by the success count and the mean ``total_ms``.
    Returns early, after printing a message, when the API is not ready or
    not reachable, or when no sample image exists.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--api", default="http://127.0.0.1:8009")
    cli.add_argument("--samples", nargs="*", default=None)
    cli.add_argument("--mode", default="translate", choices=["translate", "rembg", "both"])
    cli.add_argument("--prefix", default="detail")
    cli.add_argument("--rembg-prefix", default="thumb")
    cli.add_argument("--outdir", default=None)
    cli.add_argument("--concurrency", type=int, default=4)
    cli.add_argument("--font", default=None, help="폰트 타입(예: 폰트3). toggle_overrides.font_type")
    args = cli.parse_args()

    api = args.api.rstrip("/")

    # Health check: any failure while fetching or decoding the health
    # document is reported as "cannot reach".
    try:
        health_resp = requests.get(f"{api}/health", timeout=10)
        health_resp.raise_for_status()
        health = health_resp.json()
        if not health.get("ready"):
            print("Worker not ready. Start API server: python main.py", flush=True)
            return
    except Exception as e:
        print(f"Cannot reach API at {api}. Start server: python main.py | error={e}", flush=True)
        return

    # Sample images: the defaults live in ../test relative to this file.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    candidates = args.samples or [
        os.path.join(project_root, "test", file_name)
        for file_name in ("1.jpg", "2.jpg", "5.jpg")
    ]
    samples = [candidate for candidate in candidates if os.path.isfile(candidate)]
    if not samples:
        print("No sample images found. Use --samples with absolute paths.")
        return

    if args.outdir:
        os.makedirs(args.outdir, exist_ok=True)

    # Toggle overrides.
    toggles = {"font_type": args.font} if args.font else {}

    # Run the batch and time it.
    started = time.time()
    results = asyncio.run(
        run_concurrent(
            api,
            samples,
            mode=args.mode,
            prefix=args.prefix,
            rembg_prefix=args.rembg_prefix,
            outdir=args.outdir,
            toggles=toggles,
            concurrency=args.concurrency,
        )
    )
    total_s = time.time() - started

    # Per-image output and simple statistics.
    stage_names = ("download", "ocr", "translate", "mask", "inpaint", "render", "save", "total_ms")
    timings = {stage: [] for stage in stage_names}
    done = 0
    for img, res in results:
        if isinstance(res, tuple):
            # Submit-only entry: a pair of job ids, nothing to summarise.
            print(f"submitted both for: {img} -> translate_jid={res[0]} rembg_jid={res[1]}")
            continue

        rr = (res or {}).get("result") or {}
        if rr.get("status") in ("translated", "removed"):
            done += 1

        reported = rr.get("timings") or {}
        for stage, bucket in timings.items():
            raw = reported.get(stage)
            if raw is None:
                continue
            try:
                bucket.append(float(raw))
            except Exception:
                continue
        print(f"done: {img} -> status={rr.get('status')} path={rr.get('path')} timings_keys={list((rr.get('timings') or {}).keys())}")

    print(f"\nConcurrent batch finished: {done}/{len(results)} succeeded in {total_s:.2f}s")
    if timings["total_ms"]:
        avg = sum(timings["total_ms"]) / len(timings["total_ms"])
        print(f"avg total_ms: {avg:.1f} (n={len(timings['total_ms'])})")
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|