392 lines
15 KiB
Python
392 lines
15 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
Gemma Translation API Python Client
|
|
- FastAPI 서버(/batch_translate, /translate_ocr_step1, /translate_ocr_step2 ...) 래퍼
|
|
- 안전한 재시도, 타임아웃, 배치 슬라이싱 지원
|
|
- 네 OCR 결과(dict 리스트) -> 번역 문자열 리스트 정렬 유지
|
|
|
|
사용 예:
|
|
from gemma_client import GemmaTranslator
|
|
gt = GemmaTranslator(base_url="http://<SERVER_IP>", timeout=120)
|
|
|
|
# A) OCR 결과를 그대로 번역 (id, text 유지)
|
|
ko_list = gt.translate_ocr_texts(
|
|
product_name="휴대용 선풍기",
|
|
category="가전/계절가전",
|
|
ocr_results=[{"text":"强力送风"}, {"text":"USB-C 快速充电"}]
|
|
)
|
|
|
|
# B) 순수 텍스트 리스트를 번역
|
|
ko_list = gt.batch_translate_texts(
|
|
product_name="휴대용 선풍기",
|
|
category="가전/계절가전",
|
|
text_list=["大风力无刷电机","Type-C 充电"]
|
|
)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import time
|
|
import json
|
|
import random
|
|
import logging
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
|
|
import requests
|
|
|
|
|
|
_JSON = Dict[str, Any]
|
|
|
|
|
|
class GemmaTranslatorError(RuntimeError):
|
|
pass
|
|
|
|
|
|
class GemmaTranslator:
|
|
"""
|
|
vLLM 번역 서버 클라이언트.
|
|
|
|
Params
|
|
------
|
|
base_url : str
|
|
예) "http://localhost" (HAProxy 경유 시 포트 생략 / 80)
|
|
개별 인스턴스 직접 붙으려면 "http://<IP>:8000"
|
|
timeout : int
|
|
요청 타임아웃(초)
|
|
max_retries : int
|
|
요청 재시도 횟수
|
|
backoff : float
|
|
재시도 backoff base (지수)
|
|
session : requests.Session | None
|
|
세션 주입 가능
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
base_url: Optional[str] = None,
|
|
timeout: int = 120,
|
|
max_retries: int = 2,
|
|
backoff: float = 0.6,
|
|
session: Optional[requests.Session] = None,
|
|
logger: Optional[logging.Logger] = None,
|
|
) -> None:
|
|
self.base_url = (base_url or os.getenv("GEMMA_API_BASE") or "http://localhost").rstrip("/")
|
|
self.timeout = timeout
|
|
self.max_retries = max_retries
|
|
self.backoff = backoff
|
|
self.sess = session or requests.Session()
|
|
self.log = logger or logging.getLogger(__name__)
|
|
|
|
# -----------------------------
|
|
# 내부 HTTP 헬퍼 (retry 포함)
|
|
# -----------------------------
|
|
def _post(self, path: str, payload: _JSON) -> _JSON:
|
|
url = f"{self.base_url}{path}"
|
|
last_err: Optional[Exception] = None
|
|
for attempt in range(self.max_retries + 1):
|
|
try:
|
|
r = self.sess.post(url, json=payload, timeout=self.timeout)
|
|
r.raise_for_status()
|
|
return r.json()
|
|
except Exception as e:
|
|
last_err = e
|
|
# 429/5xx/연결오류 등 재시도
|
|
if attempt < self.max_retries:
|
|
sleep_s = (self.backoff ** attempt) + random.uniform(0, 0.2)
|
|
self.log.warning(f"[GemmaTranslator] POST {url} 실패({e}), 재시도 {attempt+1}/{self.max_retries} 대기 {sleep_s:.2f}s")
|
|
time.sleep(sleep_s)
|
|
else:
|
|
break
|
|
raise GemmaTranslatorError(f"POST {url} 실패: {last_err}")
|
|
|
|
# -----------------------------
|
|
# 공개 API
|
|
# -----------------------------
|
|
def health(self) -> _JSON:
|
|
url = f"{self.base_url}/health"
|
|
r = self.sess.get(url, timeout=10)
|
|
r.raise_for_status()
|
|
return r.json()
|
|
|
|
def metrics(self) -> _JSON:
|
|
url = f"{self.base_url}/metrics"
|
|
r = self.sess.get(url, timeout=10)
|
|
r.raise_for_status()
|
|
return r.json()
|
|
|
|
# ---- A) 순수 텍스트 리스트 번역 (/batch_translate) ----
|
|
def batch_translate_texts(
|
|
self,
|
|
product_name: str,
|
|
category: str,
|
|
text_list: List[str],
|
|
delimiter: str = " / ",
|
|
batch_size: int = 8,
|
|
) -> List[str]:
|
|
"""
|
|
text_list → 같은 길이의 ko 리스트 반환.
|
|
서버에서 추가 배치/토큰 슬라이싱을 하므로, 클라에서는 적당한 batch_size만 넘기면 됨.
|
|
"""
|
|
if not text_list:
|
|
return []
|
|
|
|
out: List[str] = []
|
|
for i in range(0, len(text_list), batch_size):
|
|
chunk = text_list[i : i + batch_size]
|
|
payload = {
|
|
"product_name": product_name,
|
|
"category": category,
|
|
"text_list": chunk,
|
|
"delimiter": delimiter,
|
|
"batch_size": min(batch_size, len(chunk)),
|
|
}
|
|
resp = self._post("/batch-translate", payload)
|
|
# 서버는 translated_texts 길이를 입력과 동일하게 맞춰줌
|
|
out.extend(resp.get("translated_texts", chunk))
|
|
return out
|
|
|
|
# ---- B) OCR 결과 번역: [{text:str, ...}] -> [ko, ...] ----
|
|
def translate_ocr_texts(
|
|
self,
|
|
product_name: str,
|
|
category: str,
|
|
ocr_results: List[Dict[str, Any]],
|
|
batch_size: int = 16,
|
|
) -> List[str]:
|
|
"""
|
|
입력: OCR 결과 리스트(각 항목에 최소 'text' 키 필요)
|
|
처리: 서버 /ocr-translate 로 {id, source} 배열을 보내고, 반환 {id, result} 정렬
|
|
출력: 원래 순서 유지한 ko 문자열 리스트
|
|
"""
|
|
if not ocr_results:
|
|
return []
|
|
|
|
# 유효성 검증 (새 스키마 준수)
|
|
if not product_name or len(product_name.strip()) < 1:
|
|
raise GemmaTranslatorError("product_name은 1자 이상이어야 합니다.")
|
|
if not category or len(category.strip()) < 1:
|
|
raise GemmaTranslatorError("category는 1자 이상이어야 합니다.")
|
|
|
|
# id 부여 및 source 필터링 (빈 텍스트 스킵, 최소 1자)
|
|
items = []
|
|
source_to_orig_idx = {} # source id to original index mapping
|
|
for i, d in enumerate(ocr_results):
|
|
source = (d.get("text") or "").strip()
|
|
if len(source) >= 1: # 최소 1자 이상
|
|
item_id = len(items) + 1
|
|
items.append({"id": item_id, "source": source})
|
|
source_to_orig_idx[item_id] = i
|
|
|
|
if not items:
|
|
return [""] * len(ocr_results) # 원래 길이만큼 빈 문자열 반환
|
|
|
|
# 서버에 배치로 전송
|
|
out_ko = [""] * len(ocr_results) # 원래 ocr_results 길이 유지
|
|
for i in range(0, len(items), batch_size):
|
|
chunk = items[i : i + batch_size]
|
|
payload = {
|
|
"product_name": product_name,
|
|
"category": category,
|
|
"items": chunk,
|
|
}
|
|
resp = self._post("/ocr-translate", payload)
|
|
|
|
result_items = resp.get("items", [])
|
|
# {id, result} 배열 → id 기준으로 매핑
|
|
for obj in result_items:
|
|
item_id = int(obj.get("id", 0))
|
|
orig_idx = source_to_orig_idx.get(item_id)
|
|
if orig_idx is not None:
|
|
out_ko[orig_idx] = str(obj.get("result", "")).strip()
|
|
|
|
return out_ko
|
|
|
|
# ---- C) 옵션 번역: [{id, source:[...]}] → [{id, translations:[...]}] ----
|
|
def translate_option_groups(
|
|
self,
|
|
product_name: str,
|
|
category: str,
|
|
option_groups: List[Dict[str, Any]],
|
|
batch_size: int = 8,
|
|
) -> List[Dict[str, Any]]:
|
|
"""
|
|
option_groups 예:
|
|
[{"id": 1, "source": ["红色","蓝色"]}, {"id": 2, "source": ["小号","大号"]}]
|
|
반환(서버 결과 그대로):
|
|
[{"id": 1, "translations": ["핑크","블루"]}, {"id": 2, "translations": ["소형","대형"]}]
|
|
"""
|
|
if not option_groups:
|
|
return []
|
|
|
|
out: List[Dict[str, Any]] = []
|
|
for i in range(0, len(option_groups), batch_size):
|
|
chunk = option_groups[i : i + batch_size]
|
|
payload = {
|
|
"product_name": product_name,
|
|
"category": category,
|
|
"items": chunk,
|
|
}
|
|
resp = self._post("/option-translate", payload)
|
|
out.extend(resp.get("result", []))
|
|
return out
|
|
|
|
# ---- E) LLM API Translate (/api/v1/llm/run) ----
|
|
def run_llm_translation(
|
|
self,
|
|
product_name: str,
|
|
category: str,
|
|
ocr_results: List[Dict[str, Any]],
|
|
job_type: str = "ocr_translator_step1",
|
|
prompt_name: str = "ocr_translator_step1",
|
|
retry_count: int = 3,
|
|
steps: int = 1
|
|
) -> List[str]:
|
|
"""
|
|
새로운 LLM 번역 API (/api/v1/llm/run) 사용
|
|
입력: OCR 결과 리스트 [{'text': '...'}, ...]
|
|
출력: 번역된 문자열 리스트 (원래 순서 유지)
|
|
|
|
Args:
|
|
product_name: 상품명
|
|
category: 카테고리
|
|
ocr_results: OCR 결과 리스트
|
|
job_type: 작업 타입
|
|
prompt_name: 프롬프트 이름
|
|
retry_count: 재시도 횟수 (1-5)
|
|
steps: 번역 단계 (1=직역만, 2=직역+마케팅톤 변환)
|
|
"""
|
|
if not ocr_results:
|
|
return []
|
|
|
|
# 1. items 구성
|
|
items = []
|
|
source_to_orig_idx = {}
|
|
for i, res in enumerate(ocr_results):
|
|
text = (res.get("text") or "").strip()
|
|
if text:
|
|
item_id = len(items) + 1
|
|
items.append({"id": item_id, "source": text})
|
|
source_to_orig_idx[item_id] = i
|
|
|
|
if not items:
|
|
return [""] * len(ocr_results)
|
|
|
|
# 2. API 요청 payload 구성
|
|
payload = {
|
|
"category": category,
|
|
"items": items,
|
|
"job_type": job_type,
|
|
"product_name": product_name,
|
|
"prompt_name": prompt_name
|
|
}
|
|
|
|
# Query parameters: retry, steps
|
|
# steps 값 검증 (1 또는 2만 허용)
|
|
steps = max(1, min(2, steps))
|
|
path = f"/api/v1/llm/run?retry={retry_count}&steps={steps}"
|
|
|
|
try:
|
|
# 3. 요청 전송
|
|
# _post 메서드는 base_url + path로 요청하므로, base_url 설정이 중요함.
|
|
# 사용자 제공 URL이 /api/v1/llm/run 이므로, base_url이 호스트 루트여야 함.
|
|
# 만약 base_url이 /api 등을 포함하고 있다면 조정 필요.
|
|
# 여기서는 _post가 path를 그대로 붙인다고 가정.
|
|
|
|
resp = self._post(path, payload)
|
|
|
|
# 4. 응답 처리
|
|
# 응답 스키마: { "success": bool, "results": "string", "error": "string", ... }
|
|
if not resp.get("success"):
|
|
error_msg = resp.get("error", "Unknown error")
|
|
self.log.error(f"[GemmaTranslator] LLM API Error: {error_msg}")
|
|
# 실패 시 빈 문자열 또는 원본 반환? 여기선 빈 문자열 리스트로 둠 (혹은 예외 발생)
|
|
raise GemmaTranslatorError(f"LLM API returned success=False: {error_msg}")
|
|
|
|
results_data = resp.get("results")
|
|
|
|
# 결과가 없거나 비어있으면 처리
|
|
if not results_data:
|
|
# results가 None이거나 빈 리스트/문자열인 경우
|
|
return [""] * len(ocr_results)
|
|
|
|
translated_items = []
|
|
|
|
# 1. 리스트인 경우 (표준 서버 응답: List[Dict])
|
|
if isinstance(results_data, list):
|
|
translated_items = results_data
|
|
|
|
# 2. 딕셔너리인 경우 (단일 객체로 온 경우 대비)
|
|
elif isinstance(results_data, dict):
|
|
translated_items = [results_data]
|
|
|
|
# 3. 문자열인 경우 (JSON 텍스트로 온 경우 - 유연한 대응)
|
|
elif isinstance(results_data, str):
|
|
try:
|
|
translated_items = json.loads(results_data)
|
|
except json.JSONDecodeError:
|
|
# JSON 파싱 실패 시 Markdown 코드 블록 제거 시도
|
|
clean_str = results_data.strip()
|
|
|
|
# ```json 또는 ``` 제거
|
|
if clean_str.startswith("```"):
|
|
# 첫 줄 제거 (```json ... 또는 ``` ...)
|
|
parts = clean_str.split('\n', 1)
|
|
if len(parts) > 1:
|
|
clean_str = parts[1]
|
|
else:
|
|
clean_str = clean_str.replace("```", "")
|
|
|
|
if clean_str.endswith("```"):
|
|
clean_str = clean_str[:-3]
|
|
|
|
try:
|
|
translated_items = json.loads(clean_str.strip())
|
|
except json.JSONDecodeError:
|
|
self.log.error(f"[GemmaTranslator] LLM results is not valid JSON: {results_data[:100]}...")
|
|
raise GemmaTranslatorError("LLM API results parsing failed")
|
|
|
|
else:
|
|
self.log.error(f"[GemmaTranslator] Unexpected results type: {type(results_data)}")
|
|
raise GemmaTranslatorError(f"Unexpected results type: {type(results_data)}")
|
|
|
|
# 5. 결과 매핑
|
|
|
|
out_ko = [""] * len(ocr_results)
|
|
# translated_items가 리스트인지 확인
|
|
if isinstance(translated_items, list):
|
|
for item in translated_items:
|
|
# item 구조 확인: {"id": 1, "translation": "..."} (서버 로그 기반)
|
|
item_id = int(item.get("id", 0))
|
|
|
|
# 결과 필드 찾기 (translation > result > translated > source)
|
|
res_text = ""
|
|
if "translation" in item:
|
|
res_text = item["translation"]
|
|
elif "result" in item:
|
|
res_text = item["result"]
|
|
elif "translated" in item:
|
|
res_text = item["translated"]
|
|
elif "source" in item:
|
|
# source만 있고 번역이 없는 경우? (거의 없겠지만)
|
|
res_text = item["source"]
|
|
|
|
# None 체크
|
|
if res_text is None:
|
|
res_text = ""
|
|
|
|
orig_idx = source_to_orig_idx.get(item_id)
|
|
if orig_idx is not None:
|
|
out_ko[orig_idx] = str(res_text).strip()
|
|
else:
|
|
self.log.warning(f"[GemmaTranslator] Unexpected translated_items type: {type(translated_items)}")
|
|
|
|
return out_ko
|
|
|
|
except Exception as e:
|
|
self.log.error(f"[GemmaTranslator] run_llm_translation failed: {e}")
|
|
# Fallback behavior or re-raise?
|
|
# User wants to replace google translate. If this fails, maybe we should return original texts or empty?
|
|
# Or re-raise to let caller handle (e.g. fallback to Google).
|
|
raise e
|
|
|