136 lines
4.3 KiB
Python
136 lines
4.3 KiB
Python
from fastapi import FastAPI, Request, HTTPException, Response, UploadFile, File, Form
|
|
from pydantic import BaseModel, Field
|
|
from typing import Dict, Optional, List
|
|
from app.celery_worker import celery_app
|
|
from app.supabase_auth import check_user_permission
|
|
from celery.result import AsyncResult
|
|
import time
|
|
import os, shutil
|
|
import uuid
|
|
import logging
|
|
|
|
# Module-wide state: the FastAPI application that the route decorators below
# register on, and a logger named after this module.
app = FastAPI()
logger = logging.getLogger(__name__)
|
|
|
|
# ==== 입력 모델 ====
|
|
class TranslateRequest(BaseModel):
    """Request body for ``/translate_me``.

    Every field is forwarded to the worker as a keyword argument via
    ``req.dict()``.
    """

    # Free-form mappings; their key/value schema is not defined in this file.
    toggle_states: Dict
    unwanted_texts: Dict
    # Image payload as a string (presumably base64 like the response models — confirm).
    image_data: str
    # Checked by validate_user before the task is queued.
    user_id: str
    ocr_method: Optional[str] = Field(default="paddleocr")
    inpaint_method: Optional[str] = Field(default="lama")
|
|
|
|
class InpaintRequest(BaseModel):
    """Request body for ``/inpaint_me``.

    Forwarded to the worker as keyword arguments via ``req.dict()``.
    """

    # Mask and source image as strings (presumably base64 — confirm with the worker).
    mask_image_data: str
    image_data: str
    # Checked by validate_user before the task is queued.
    user_id: str
    inpaint_method: Optional[str] = Field(default="lama")
|
|
|
|
class OCRRequest(BaseModel):
    """Request body for ``/ocr_me``.

    Forwarded to the worker as keyword arguments via ``req.dict()``.
    """

    # Image payload as a string (presumably base64 — confirm with the worker).
    image_data: str
    # Checked by validate_user before the task is queued.
    user_id: str
    ocr_method: Optional[str] = Field(default="paddleocr")
|
|
|
|
# ==== 리턴 모델 ====
|
|
class OCRBox(BaseModel):
    """One recognized text snippet together with its coordinates."""

    text: str
    # Flat list of integer coordinates, e.g. [x1, y1, x2, y2, ...].
    box: List[int]
|
|
|
|
class TranslateResponse(BaseModel):
    """Result shape for a translation job.

    The endpoints visible in this file return plain dicts; this model only
    documents the expected result.
    """

    ocr_texts: List[str]
    ocr_boxes: List[OCRBox]
    translated_texts: List[str]
    # Resulting image, base64-encoded.
    inpainted_image: str
|
|
|
|
class InpaintResponse(BaseModel):
    """Result shape for an inpainting job."""

    # Resulting image, base64-encoded.
    inpainted_image: str
|
|
|
|
class OCRResponse(BaseModel):
    """Result shape for an OCR job: recognized texts and their boxes."""

    ocr_texts: List[str]
    ocr_boxes: List[OCRBox]
|
|
|
|
# ==== 사용자 인증 ====
|
|
async def validate_user(user_id):
    """Abort the request with HTTP 403 unless ``user_id`` is permitted.

    Raises:
        HTTPException: status 403 when ``check_user_permission`` returns a
            falsy value.
    """
    # NOTE(review): user_id comes from the client's request payload; unless
    # check_user_permission also verifies a credential, this does not prove
    # the caller actually is that user — confirm.
    if await check_user_permission(user_id):
        return
    raise HTTPException(status_code=403, detail="권한이 없습니다.")
|
|
|
|
# ==== 셀러리 태스크 (워커 시스템으로 전송) ====
|
|
def start_celery_task(task_name, **kwargs):
    """Dispatch ``task_name`` to the worker system by name.

    All extra keyword arguments become the task's keyword arguments. The
    return value is whatever ``celery_app.send_task`` returns; callers here
    read its ``.id``.
    """
    dispatched = celery_app.send_task(task_name, kwargs=kwargs)
    return dispatched
|
|
|
|
# ==== 엔드포인트 ====
|
|
@app.post("/translate_me")
async def translate_me(req: TranslateRequest):
    """Queue a translation job for a permitted user.

    Generates an output filename of the form ``<uuid4 hex>_<unix time>.png``.
    That name is passed to the worker along with every request field.

    Returns:
        ``{"task_id": ..., "filename": ...}``.
    """
    await validate_user(req.user_id)

    token = uuid.uuid4().hex
    stamp = int(time.time())
    filename = f"{token}_{stamp}.png"

    # Hand the request (plus the generated filename) to the worker's translate_task.
    payload = req.dict()
    payload["filename"] = filename
    task = start_celery_task("worker.translate_task", **payload)

    logger.info(f"번역 태스크 등록: {task.id}, 파일명: {filename}")
    return {"task_id": task.id, "filename": filename}
|
|
|
|
@app.post("/inpaint_me")
async def inpaint_me(req: InpaintRequest):
    """Queue an inpainting job for a permitted user.

    Returns:
        ``{"task_id": ...}``.
    """
    await validate_user(req.user_id)

    # Hand the validated request to the worker's inpaint_task.
    job = start_celery_task("worker.inpaint_task", **req.dict())
    job_id = job.id

    logger.info(f"인페인팅 태스크 등록: {job_id}")
    return {"task_id": job_id}
|
|
|
|
@app.post("/ocr_me")
async def ocr_me(req: OCRRequest):
    """Queue an OCR job for a permitted user.

    Returns:
        ``{"task_id": ...}``.
    """
    await validate_user(req.user_id)

    # Hand the validated request to the worker's ocr_task.
    job = start_celery_task("worker.ocr_task", **req.dict())
    job_id = job.id

    logger.info(f"OCR 태스크 등록: {job_id}")
    return {"task_id": job_id}
|
|
|
|
@app.get("/task_status/{task_id}")
async def get_task_status(task_id: str):
    """Report the Celery state of ``task_id`` and, once finished, its result.

    Returns:
        ``{"task_id": ..., "status": ..., "result": ...}``. ``result`` is
        ``None`` until the task reaches one of the backend's ready states. If
        the stored result is an exception (failed/revoked task), it is
        returned as an ``"ExceptionType: message"`` string. The exception
        object itself would not serialize to useful JSON.

    Raises:
        HTTPException: 404 (with the error text) if querying the result
            backend raises.

    NOTE(review): per Celery's documentation an id that was never submitted
    is typically reported as PENDING rather than raising, so the 404 path
    mostly covers backend/configuration errors — confirm.
    """
    try:
        result = AsyncResult(task_id, app=celery_app)
        # Read the state exactly once. The original called `.status` and
        # `.ready()` separately; each queries the backend, so the task could
        # finish in between and the response would be inconsistent.
        state = result.state
        payload = result.result if state in result.backend.READY_STATES else None
    except Exception as e:
        logger.exception("작업 상태 조회 실패: %s", task_id)
        raise HTTPException(status_code=404, detail=f"작업을 찾을 수 없습니다: {str(e)}") from e

    # Failed/revoked tasks store an exception instance as their result; turn it
    # into readable text so the client actually sees what went wrong.
    if isinstance(payload, BaseException):
        payload = f"{type(payload).__name__}: {payload}"

    return {"task_id": task_id, "status": state, "result": payload}
|
|
|
|
@app.post("/upload_image")
async def upload_image(user_id: str = Form(...), file: UploadFile = File(...)):
    """Save an uploaded image to ``images/<user_id>/original/<filename>``.

    Args:
        user_id: Form field naming the per-user directory.
        file: The uploaded file. Only the final path component of its
            client-supplied name is used.

    Returns:
        ``{"url": "/images/<user_id>/original/<filename>"}``.

    Raises:
        HTTPException: 400 if ``user_id`` or the file name is missing, empty,
            or could point outside the per-user directory.
    """
    # Both values are client-controlled. Without these checks a user_id or
    # filename like "../../x" (or an absolute path) would let the caller
    # write anywhere the server process can.
    filename = os.path.basename(file.filename or "")
    for part in (user_id, filename):
        if part in ("", ".", "..") or "\x00" in part or os.path.basename(part) != part:
            raise HTTPException(status_code=400, detail="잘못된 파일 경로입니다.")

    # NOTE(review): unlike the task endpoints, this does not call
    # validate_user — confirm uploads are meant to be unauthenticated.
    original_dir = os.path.join("images", user_id, "original")
    os.makedirs(original_dir, exist_ok=True)
    file_path = os.path.join(original_dir, filename)

    # NOTE(review): copyfileobj is blocking file I/O inside an async endpoint;
    # large uploads will stall the event loop while copying.
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    return {"url": f"/images/{user_id}/original/{filename}"}
|
|
|
|
@app.get("/download_image/")
def download_image(user_id: str, filename: str):
    """Return a translated image once, then delete it from disk.

    Reads ``images/<user_id>/translated/<filename>`` and returns it as
    ``image/png``. The file is removed afterwards, so each result can be
    downloaded a single time.

    Args:
        user_id: Query parameter naming the per-user directory.
        filename: Query parameter naming the file inside ``translated/``.

    Returns:
        The file bytes as an ``image/png`` response. If the file does not
        exist, ``{"error": "파일이 없습니다"}`` with HTTP 200.

    Raises:
        HTTPException: 400 if ``user_id`` or ``filename`` is empty or contains
            a path component that could leave the per-user directory.
    """
    # Both values come from the query string. Without this check a request
    # such as ?filename=../../../some/file would read AND then delete a file
    # outside images/.
    for part in (user_id, filename):
        if part in ("", ".", "..") or "\x00" in part or os.path.basename(part) != part:
            raise HTTPException(status_code=400, detail="잘못된 파일 경로입니다.")

    file_path = os.path.join("images", user_id, "translated", filename)

    # Open directly rather than checking os.path.exists first. This avoids the
    # check-then-open race with a concurrent download deleting the same file.
    try:
        with open(file_path, "rb") as f:
            data = f.read()
    except FileNotFoundError:
        # NOTE(review): kept as HTTP 200 + error body for existing clients;
        # a 404 status would be more conventional.
        return {"error": "파일이 없습니다"}

    # Delete right after reading (one-shot download). A concurrent request may
    # already have removed it; the bytes are in memory, so that is harmless.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass

    return Response(content=data, media_type="image/png")
|
|
|
|
@app.get("/")
async def root():
    """Identify the service and report that it is running."""
    return dict(message="이미지 번역 메인 서버", status="running")
|
|
|
|
@app.get("/health")
async def health_check():
    """Health probe: report healthy together with the current Unix time."""
    now = time.time()
    return dict(status="healthy", timestamp=now)