IMG_Worker/modules/gemma_client.py

254 lines
8.9 KiB
Python

# -*- coding: utf-8 -*-
"""
Gemma Translation API Python Client
- FastAPI 서버(/batch_translate, /translate_ocr_step1, /translate_ocr_step2 ...) 래퍼
- 안전한 재시도, 타임아웃, 배치 슬라이싱 지원
- 네 OCR 결과(dict 리스트) -> 번역 문자열 리스트 정렬 유지
사용 예:
from gemma_client import GemmaTranslator
gt = GemmaTranslator(base_url="http://<SERVER_IP>", timeout=120)
# A) OCR 결과를 그대로 번역 (id, text 유지)
ko_list = gt.translate_ocr_texts(
product_name="휴대용 선풍기",
category="가전/계절가전",
ocr_results=[{"text":"强力送风"}, {"text":"USB-C 快速充电"}]
)
# B) 순수 텍스트 리스트를 번역
ko_list = gt.batch_translate_texts(
product_name="휴대용 선풍기",
category="가전/계절가전",
text_list=["大风力无刷电机","Type-C 充电"]
)
"""
from __future__ import annotations
import os
import time
import json
import random
import logging
from typing import List, Dict, Any, Optional, Tuple
import requests
_JSON = Dict[str, Any]
class GemmaTranslatorError(RuntimeError):
pass
class GemmaTranslator:
"""
vLLM 번역 서버 클라이언트.
Params
------
base_url : str
예) "http://localhost" (HAProxy 경유 시 포트 생략 / 80)
개별 인스턴스 직접 붙으려면 "http://<IP>:8000"
timeout : int
요청 타임아웃(초)
max_retries : int
요청 재시도 횟수
backoff : float
재시도 backoff base (지수)
session : requests.Session | None
세션 주입 가능
"""
def __init__(
self,
base_url: Optional[str] = None,
timeout: int = 120,
max_retries: int = 2,
backoff: float = 0.6,
session: Optional[requests.Session] = None,
logger: Optional[logging.Logger] = None,
) -> None:
self.base_url = (base_url or os.getenv("GEMMA_API_BASE") or "http://localhost").rstrip("/")
self.timeout = timeout
self.max_retries = max_retries
self.backoff = backoff
self.sess = session or requests.Session()
self.log = logger or logging.getLogger(__name__)
# -----------------------------
# 내부 HTTP 헬퍼 (retry 포함)
# -----------------------------
def _post(self, path: str, payload: _JSON) -> _JSON:
url = f"{self.base_url}{path}"
last_err: Optional[Exception] = None
for attempt in range(self.max_retries + 1):
try:
r = self.sess.post(url, json=payload, timeout=self.timeout)
r.raise_for_status()
return r.json()
except Exception as e:
last_err = e
# 429/5xx/연결오류 등 재시도
if attempt < self.max_retries:
sleep_s = (self.backoff ** attempt) + random.uniform(0, 0.2)
self.log.warning(f"[GemmaTranslator] POST {url} 실패({e}), 재시도 {attempt+1}/{self.max_retries} 대기 {sleep_s:.2f}s")
time.sleep(sleep_s)
else:
break
raise GemmaTranslatorError(f"POST {url} 실패: {last_err}")
# -----------------------------
# 공개 API
# -----------------------------
def health(self) -> _JSON:
url = f"{self.base_url}/health"
r = self.sess.get(url, timeout=10)
r.raise_for_status()
return r.json()
def metrics(self) -> _JSON:
url = f"{self.base_url}/metrics"
r = self.sess.get(url, timeout=10)
r.raise_for_status()
return r.json()
# ---- A) 순수 텍스트 리스트 번역 (/batch_translate) ----
def batch_translate_texts(
self,
product_name: str,
category: str,
text_list: List[str],
delimiter: str = " / ",
batch_size: int = 8,
) -> List[str]:
"""
text_list → 같은 길이의 ko 리스트 반환.
서버에서 추가 배치/토큰 슬라이싱을 하므로, 클라에서는 적당한 batch_size만 넘기면 됨.
"""
if not text_list:
return []
out: List[str] = []
for i in range(0, len(text_list), batch_size):
chunk = text_list[i : i + batch_size]
payload = {
"product_name": product_name,
"category": category,
"text_list": chunk,
"delimiter": delimiter,
"batch_size": min(batch_size, len(chunk)),
}
resp = self._post("/batch-translate", payload)
# 서버는 translated_texts 길이를 입력과 동일하게 맞춰줌
out.extend(resp.get("translated_texts", chunk))
return out
# ---- B) OCR 결과 번역: [{text:str, ...}] -> [ko, ...] ----
def translate_ocr_texts(
self,
product_name: str,
category: str,
ocr_results: List[Dict[str, Any]],
batch_size: int = 16,
) -> List[str]:
"""
입력: OCR 결과 리스트(각 항목에 최소 'text' 키 필요)
처리: 서버 /ocr-translate 로 {id, source} 배열을 보내고, 반환 {id, result} 정렬
출력: 원래 순서 유지한 ko 문자열 리스트
"""
if not ocr_results:
return []
# 유효성 검증 (새 스키마 준수)
if not product_name or len(product_name.strip()) < 1:
raise GemmaTranslatorError("product_name은 1자 이상이어야 합니다.")
if not category or len(category.strip()) < 1:
raise GemmaTranslatorError("category는 1자 이상이어야 합니다.")
# id 부여 및 source 필터링 (빈 텍스트 스킵, 최소 1자)
items = []
source_to_orig_idx = {} # source id to original index mapping
for i, d in enumerate(ocr_results):
source = (d.get("text") or "").strip()
if len(source) >= 1: # 최소 1자 이상
item_id = len(items) + 1
items.append({"id": item_id, "source": source})
source_to_orig_idx[item_id] = i
if not items:
return [""] * len(ocr_results) # 원래 길이만큼 빈 문자열 반환
# 서버에 배치로 전송
out_ko = [""] * len(ocr_results) # 원래 ocr_results 길이 유지
for i in range(0, len(items), batch_size):
chunk = items[i : i + batch_size]
payload = {
"product_name": product_name,
"category": category,
"items": chunk,
}
resp = self._post("/ocr-translate", payload)
result_items = resp.get("items", [])
# {id, result} 배열 → id 기준으로 매핑
for obj in result_items:
item_id = int(obj.get("id", 0))
orig_idx = source_to_orig_idx.get(item_id)
if orig_idx is not None:
out_ko[orig_idx] = str(obj.get("result", "")).strip()
return out_ko
# ---- C) 옵션 번역: [{id, source:[...]}] → [{id, translations:[...]}] ----
def translate_option_groups(
self,
product_name: str,
category: str,
option_groups: List[Dict[str, Any]],
batch_size: int = 8,
) -> List[Dict[str, Any]]:
"""
option_groups 예:
[{"id": 1, "source": ["红色","蓝色"]}, {"id": 2, "source": ["小号","大号"]}]
반환(서버 결과 그대로):
[{"id": 1, "translations": ["핑크","블루"]}, {"id": 2, "translations": ["소형","대형"]}]
"""
if not option_groups:
return []
out: List[Dict[str, Any]] = []
for i in range(0, len(option_groups), batch_size):
chunk = option_groups[i : i + batch_size]
payload = {
"product_name": product_name,
"category": category,
"items": chunk,
}
resp = self._post("/option-translate", payload)
out.extend(resp.get("result", []))
return out
# ---- D) (선택) 카피 다듬기 ---- (이 메서드는 새 스펙에서 OCR가 단일 단계이므로 제거 또는 주석 처리)
# def polish_translations(
# self,
# product_name: str,
# category: str,
# id_text_pairs: List[Dict[str, Any]],
# batch_size: int = 16,
# ) -> List[Dict[str, Any]]:
# """
# 입력: [{"id":1,"translation":"..."}]
# 반환: [{"id":1,"result":"..."}]
# """
# if not id_text_pairs:
# return []
# out: List[Dict[str, Any]] = []
# for i in range(0, len(id_text_pairs), batch_size):
# chunk = id_text_pairs[i : i + batch_size]
# payload = {"product_name": product_name, "category": category, "items": chunk}
# resp = self._post("/translate_ocr_step2", payload)
# out.extend(resp.get("result", []))
# return out