inpaintServer/main.py

453 lines
16 KiB
Python

#!/usr/bin/env python3
"""
인페인팅 서버 메인 애플리케이션
iopaint와 호환되는 API를 제공합니다.
"""
import time
import logging
import json
import asyncio
from contextlib import asynccontextmanager
from collections import defaultdict, deque
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from app.core.config import settings
from app.core.worker_manager import worker_manager
from app.core.session_pool import SessionPool, session_pool
from app.api.endpoints import router
from app.monitoring.dashboard import monitor_app
from app.core.batch_manager import batch_manager
# from app.utils.background_task import manage_state_background # TODO: 경로 확인 필요
from app.utils.discord_notifier import send_discord_notification
# 로깅 설정
import logging.handlers
import os
import logging.config
# 로그 디렉토리 생성
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
# 로그 회전 설정: 최대 10MB, 7일 보관
rotating_handler = logging.handlers.TimedRotatingFileHandler(
filename=os.path.join(log_dir, "main.log"),
when="D", # 일별 회전
interval=1, # 1일마다
backupCount=7, # 7일 보관
encoding="utf-8"
)
rotating_handler.setFormatter(logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
))
# 파일 크기 제한 핸들러 추가 (10MB)
size_handler = logging.handlers.RotatingFileHandler(
filename=os.path.join(log_dir, "main_size.log"),
maxBytes=10*1024*1024, # 10MB
backupCount=5, # 최대 5개 파일
encoding="utf-8"
)
size_handler.setFormatter(logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
))
# 콘솔 핸들러
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
))
# 루트 로거 설정
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
root_logger.addHandler(rotating_handler)
root_logger.addHandler(size_handler)
root_logger.addHandler(console_handler)
logger = logging.getLogger(__name__)
# 서버 시작 시간 기록
start_time = time.time()
# API 통계 수집 클래스
class APIStatsCollector:
def __init__(self):
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.endpoint_usage = defaultdict(int)
self.endpoint_response_times = defaultdict(list)
self.endpoint_concurrent = defaultdict(int)
self.max_concurrent = 0
self.current_concurrent = 0
self.response_times = deque(maxlen=1000)
self.recent_errors = deque(maxlen=100)
self.start_time = time.time()
def start_request(self, endpoint: str):
"""요청 시작 시 호출"""
self.current_concurrent += 1
self.endpoint_concurrent[endpoint] += 1
self.max_concurrent = max(self.max_concurrent, self.current_concurrent)
def end_request(self, endpoint: str, success: bool, response_time: float, error: str = None):
"""요청 완료 시 호출"""
self.current_concurrent -= 1
self.endpoint_concurrent[endpoint] -= 1
self.total_requests += 1
self.endpoint_usage[endpoint] += 1
if success:
self.successful_requests += 1
else:
self.failed_requests += 1
if error:
self.recent_errors.append({
"timestamp": time.time(),
"endpoint": endpoint,
"error": error
})
self.response_times.append(response_time)
self.endpoint_response_times[endpoint].append(response_time)
# 최근 100개만 유지
if len(self.endpoint_response_times[endpoint]) > 100:
self.endpoint_response_times[endpoint] = self.endpoint_response_times[endpoint][-100:]
def get_stats(self):
"""현재 통계 반환"""
uptime = time.time() - self.start_time
# 전체 응답시간 통계
if self.response_times:
avg_response_time = sum(self.response_times) / len(self.response_times)
min_response_time = min(self.response_times)
max_response_time = max(self.response_times)
else:
avg_response_time = min_response_time = max_response_time = 0
# 엔드포인트별 상세 통계
endpoint_stats = {}
for endpoint, times in self.endpoint_response_times.items():
if times:
endpoint_stats[endpoint] = {
"count": self.endpoint_usage[endpoint],
"avg_time": sum(times) / len(times),
"min_time": min(times),
"max_time": max(times),
"current_concurrent": self.endpoint_concurrent[endpoint]
}
# 기본 엔드포인트 프리시드 (요청이 없어도 0으로 노출)
for endpoint in DEFAULT_ENDPOINTS:
if endpoint not in endpoint_stats:
endpoint_stats[endpoint] = {
"count": 0,
"avg_time": 0.0,
"min_time": 0.0,
"max_time": 0.0,
"current_concurrent": self.endpoint_concurrent[endpoint]
}
return {
"total_requests": self.total_requests,
"successful_requests": self.successful_requests,
"failed_requests": self.failed_requests,
"success_rate": (self.successful_requests / max(self.total_requests, 1)) * 100,
"endpoint_usage": dict(self.endpoint_usage),
"endpoint_stats": endpoint_stats,
"average_response_time": avg_response_time,
"min_response_time": min_response_time,
"max_response_time": max_response_time,
"current_concurrent": self.current_concurrent,
"max_concurrent": self.max_concurrent,
"requests_per_second": self.total_requests / max(uptime, 1),
"uptime": uptime,
"recent_errors": list(self.recent_errors)[-10:] # 최근 10개 에러
}
# 전역 통계 수집기
api_stats = APIStatsCollector()
# 대시보드/헬스 전용 경로는 API 통계에서 제외
# - 주기적 폴링으로 인해 실제 비즈니스 엔드포인트 통계를 왜곡시키지 않기 위함
EXCLUDED_ENDPOINTS = {
"/api/v1/health",
"/docs",
"/openapi.json",
"/redoc",
}
EXCLUDED_PREFIXES = [
"/api/v1/stats", # /api/v1/stats 및 /api/v1/stats/* 전체 제외
]
# 대시보드에 기본적으로 표시할 핵심 엔드포인트(요청이 없더라도 0으로 노출)
DEFAULT_ENDPOINTS = [
"POST /api/v1/inpaint",
"POST /api/v1/remove_bg",
"POST /api/v1/run_plugin_gen_image",
]
API_ERROR_LOG_PATH = os.path.join(log_dir, "api_errors.jsonl")
API_ERROR_MAX_BYTES = 10 * 1024 * 1024 # 10MB
API_ERROR_BACKUP_COUNT = 5
def _rotate_api_error_log_if_needed():
try:
if os.path.exists(API_ERROR_LOG_PATH) and os.path.getsize(API_ERROR_LOG_PATH) >= API_ERROR_MAX_BYTES:
ts = time.strftime("%Y%m%d-%H%M%S")
rotated_path = os.path.join(log_dir, f"api_errors_{ts}.jsonl")
os.replace(API_ERROR_LOG_PATH, rotated_path)
# 오래된 로테이션 파일 정리 (최신 N개만 유지)
rotated = [
os.path.join(log_dir, f) for f in os.listdir(log_dir)
if f.startswith("api_errors_") and f.endswith(".jsonl")
]
rotated.sort(key=lambda p: os.path.getmtime(p), reverse=True)
for old in rotated[API_ERROR_BACKUP_COUNT:]:
try:
os.remove(old)
except Exception:
pass
except Exception as e: # pragma: no cover
logger.warning(f"API 에러 로그 로테이션 실패: {e}")
def _append_api_error_log(record: dict):
"""에러 전용 JSONL 로그에 한 줄 추가"""
try:
_rotate_api_error_log_if_needed()
with open(API_ERROR_LOG_PATH, "a", encoding="utf-8") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
except Exception as e: # pragma: no cover
logger.warning(f"API 에러 로그 기록 실패: {e}")
async def save_status_periodically():
"""주기적으로 워커와 세션 상태를 파일에 저장합니다."""
logger.info("🔄 상태 저장 백그라운드 작업 시작됨")
iteration = 0
while True:
try:
iteration += 1
logger.debug(f"상태 저장 시도 #{iteration}")
# 워커 상태 수집
worker_status = worker_manager.get_status()
logger.debug(f"워커 상태 수집 완료: {worker_status}")
# 세션 상태 수집
session_status = session_pool.get_status()
logger.debug(f"세션 상태 수집 완료: {session_status}")
# API 통계 수집
api_statistics = api_stats.get_stats()
logger.debug(f"API 통계 수집 완료: 총 요청 {api_statistics['total_requests']}")
status = {
"worker_status": worker_status,
"session_status": session_status,
"api_stats": api_statistics,
"timestamp": time.time()
}
# 파일에 저장
with open("status.json", "w") as f:
json.dump(status, f, indent=2)
logger.debug(f"상태 저장 완료 #{iteration}: {time.strftime('%H:%M:%S')}")
except Exception as e:
logger.error(f"상태 저장 실패 #{iteration}: {e}")
import traceback
logger.error(f"상세 오류: {traceback.format_exc()}")
await asyncio.sleep(5)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""애플리케이션 생명주기 관리"""
# 서버 시작 시
logger.info("🚀 인페인팅 서버 시작 중...")
app.state.start_time = time.time() # settings 대신 app.state에 저장
# app.state에 공유 객체 저장
app.state.worker_manager = worker_manager
app.state.session_pool = session_pool
logger.info("✅ 공유 객체를 app.state에 저장 완료")
# 상태 저장 백그라운드 작업 시작
logger.info("🔄 상태 저장 백그라운드 작업 생성 중...")
status_task = asyncio.create_task(save_status_periodically())
logger.info("✅ 상태 저장 백그라운드 작업 생성 완료")
try:
# ONNX Runtime과 RemBG가 자동으로 CUDA 감지
logger.info("🚀 세션 풀 초기화 (CUDA 자동 감지)")
await session_pool.initialize()
logger.info("✅ 세션 풀 초기화 완료")
# 워커 매니저 시작
await worker_manager.start()
logger.info("✅ 워커 매니저 시작 완료")
if settings.USE_MICRO_BATCHING:
await batch_manager.start()
logger.info("✅ 배치 관리자 시작 완료")
# app.state.background_task = asyncio.create_task(manage_state_background(app.state))
# logger.info("✅ 상태 저장 백그라운드 작업 생성 완료")
logger.info("🎉 인페인팅 서버 시작 완료!")
send_discord_notification("✅ 서버가 성공적으로 시작되었습니다.", level="success")
except Exception as e:
logger.error(f"❌ 서버 시작 실패: {e}")
raise
yield
# 서버 종료 시
logger.info("🛑 인페인팅 서버 종료 중...")
# 상태 저장 백그라운드 작업 취소
status_task.cancel()
try:
# 워커 매니저 중지
await worker_manager.stop()
logger.info("✅ 워커 매니저 중지 완료")
if settings.USE_MICRO_BATCHING:
await batch_manager.stop()
logger.info("✅ 배치 관리자 중지 완료")
# if app.state.background_task:
# app.state.background_task.cancel()
# try:
# await app.state.background_task
# except asyncio.CancelledError:
# logger.info("상태 저장 백그라운드 작업이 정상적으로 취소되었습니다.")
logger.info("👋 인페인팅 서버 종료 완료")
send_discord_notification("👋 서버가 종료되었습니다.", level="info")
except Exception as e:
logger.error(f"❌ 서버 종료 중 오류: {e}")
# 메인 애플리케이션 생성
app = FastAPI(
title=settings.APP_NAME,
version=settings.APP_VERSION,
lifespan=lifespan
)
# API 통계 수집 미들웨어
@app.middleware("http")
async def collect_api_stats(request: Request, call_next):
"""API 호출 통계를 수집하는 미들웨어"""
start_time = time.time()
path = request.url.path
# 통계 제외 대상이면 단순 패스스루 (카운팅/지표 반영 안 함)
if path in EXCLUDED_ENDPOINTS or any(path.startswith(p) for p in EXCLUDED_PREFIXES):
return await call_next(request)
endpoint = f"{request.method} {path}"
# 요청 시작
api_stats.start_request(endpoint)
try:
# 실제 요청 처리
response = await call_next(request)
# 응답 시간 계산
response_time = time.time() - start_time
# 성공/실패 판단 (2xx, 3xx는 성공)
success = 200 <= response.status_code < 400
# 통계 업데이트
api_stats.end_request(endpoint, success, response_time)
# 4xx/5xx는 에러 로그 파일에 기록
if not success:
_append_api_error_log({
"timestamp": time.time(),
"method": request.method,
"path": path,
"status": response.status_code,
"response_time_ms": int(response_time * 1000)
})
return response
except Exception as e:
# 에러 발생 시
response_time = time.time() - start_time
api_stats.end_request(endpoint, False, response_time, str(e))
_append_api_error_log({
"timestamp": time.time(),
"method": request.method,
"path": path,
"status": 500,
"error": str(e),
"response_time_ms": int(response_time * 1000)
})
raise
# CORS 미들웨어 추가
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# API 라우터 포함
app.include_router(router)
# 모니터링은 start_server.sh를 통해 독립적으로 실행됩니다.
# app.mount("/monitoring", monitor_app, name="monitoring")
# 모니터링은 독립적인 서버(포트 8001)에서 처리됩니다.
# status.json 파일을 통해 데이터를 공유합니다.
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="인페인팅 서버")
parser.add_argument("--dev", action="store_true", help="개발 모드로 실행")
parser.add_argument("--host", default=settings.HOST, help="호스트 주소")
parser.add_argument("--port", type=int, default=settings.PORT, help="포트 번호")
parser.add_argument("--workers", type=int, default=settings.WORKERS, help="워커 수")
args = parser.parse_args()
if args.dev:
logger.info("🔧 개발 모드로 실행합니다")
uvicorn.run(
"main:app",
host=args.host,
port=args.port,
reload=True,
log_level="info"
)
else:
logger.info("🚀 프로덕션 모드로 실행합니다")
uvicorn.run(
"main:app",
host=args.host,
port=args.port,
workers=args.workers,
log_level="info"
)