inpaintServer/main.py

366 lines
12 KiB
Python

#!/usr/bin/env python3
"""
인페인팅 서버 메인 애플리케이션
iopaint와 호환되는 API를 제공합니다.
"""
import time
import logging
import json
import asyncio
from contextlib import asynccontextmanager
from collections import defaultdict, deque
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from app.core.config import settings
from app.core.worker_manager import worker_manager
from app.core.session_pool import SessionPool, session_pool
from app.api.endpoints import router
from app.monitoring.dashboard import monitor_app
from app.core.batch_manager import batch_manager
# from app.utils.background_task import manage_state_background # TODO: 경로 확인 필요
from app.utils.discord_notifier import send_discord_notification
# 로깅 설정
import logging.handlers
import os
import logging.config
# 로그 디렉토리 생성
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
# 로그 회전 설정: 최대 10MB, 7일 보관
rotating_handler = logging.handlers.TimedRotatingFileHandler(
filename=os.path.join(log_dir, "main.log"),
when="D", # 일별 회전
interval=1, # 1일마다
backupCount=7, # 7일 보관
encoding="utf-8"
)
rotating_handler.setFormatter(logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
))
# 파일 크기 제한 핸들러 추가 (10MB)
size_handler = logging.handlers.RotatingFileHandler(
filename=os.path.join(log_dir, "main_size.log"),
maxBytes=10*1024*1024, # 10MB
backupCount=5, # 최대 5개 파일
encoding="utf-8"
)
size_handler.setFormatter(logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
))
# 콘솔 핸들러
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
))
# 루트 로거 설정
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
root_logger.addHandler(rotating_handler)
root_logger.addHandler(size_handler)
root_logger.addHandler(console_handler)
logger = logging.getLogger(__name__)
# 서버 시작 시간 기록
start_time = time.time()
# API 통계 수집 클래스
class APIStatsCollector:
def __init__(self):
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.endpoint_usage = defaultdict(int)
self.endpoint_response_times = defaultdict(list)
self.endpoint_concurrent = defaultdict(int)
self.max_concurrent = 0
self.current_concurrent = 0
self.response_times = deque(maxlen=1000)
self.recent_errors = deque(maxlen=100)
self.start_time = time.time()
def start_request(self, endpoint: str):
"""요청 시작 시 호출"""
self.current_concurrent += 1
self.endpoint_concurrent[endpoint] += 1
self.max_concurrent = max(self.max_concurrent, self.current_concurrent)
def end_request(self, endpoint: str, success: bool, response_time: float, error: str = None):
"""요청 완료 시 호출"""
self.current_concurrent -= 1
self.endpoint_concurrent[endpoint] -= 1
self.total_requests += 1
self.endpoint_usage[endpoint] += 1
if success:
self.successful_requests += 1
else:
self.failed_requests += 1
if error:
self.recent_errors.append({
"timestamp": time.time(),
"endpoint": endpoint,
"error": error
})
self.response_times.append(response_time)
self.endpoint_response_times[endpoint].append(response_time)
# 최근 100개만 유지
if len(self.endpoint_response_times[endpoint]) > 100:
self.endpoint_response_times[endpoint] = self.endpoint_response_times[endpoint][-100:]
def get_stats(self):
"""현재 통계 반환"""
uptime = time.time() - self.start_time
# 전체 응답시간 통계
if self.response_times:
avg_response_time = sum(self.response_times) / len(self.response_times)
min_response_time = min(self.response_times)
max_response_time = max(self.response_times)
else:
avg_response_time = min_response_time = max_response_time = 0
# 엔드포인트별 상세 통계
endpoint_stats = {}
for endpoint, times in self.endpoint_response_times.items():
if times:
endpoint_stats[endpoint] = {
"count": self.endpoint_usage[endpoint],
"avg_time": sum(times) / len(times),
"min_time": min(times),
"max_time": max(times),
"current_concurrent": self.endpoint_concurrent[endpoint]
}
return {
"total_requests": self.total_requests,
"successful_requests": self.successful_requests,
"failed_requests": self.failed_requests,
"success_rate": (self.successful_requests / max(self.total_requests, 1)) * 100,
"endpoint_usage": dict(self.endpoint_usage),
"endpoint_stats": endpoint_stats,
"average_response_time": avg_response_time,
"min_response_time": min_response_time,
"max_response_time": max_response_time,
"current_concurrent": self.current_concurrent,
"max_concurrent": self.max_concurrent,
"requests_per_second": self.total_requests / max(uptime, 1),
"uptime": uptime,
"recent_errors": list(self.recent_errors)[-10:] # 최근 10개 에러
}
# 전역 통계 수집기
api_stats = APIStatsCollector()
async def save_status_periodically():
"""주기적으로 워커와 세션 상태를 파일에 저장합니다."""
logger.info("🔄 상태 저장 백그라운드 작업 시작됨")
iteration = 0
while True:
try:
iteration += 1
logger.debug(f"상태 저장 시도 #{iteration}")
# 워커 상태 수집
worker_status = worker_manager.get_status()
logger.debug(f"워커 상태 수집 완료: {worker_status}")
# 세션 상태 수집
session_status = session_pool.get_status()
logger.debug(f"세션 상태 수집 완료: {session_status}")
# API 통계 수집
api_statistics = api_stats.get_stats()
logger.debug(f"API 통계 수집 완료: 총 요청 {api_statistics['total_requests']}")
status = {
"worker_status": worker_status,
"session_status": session_status,
"api_stats": api_statistics,
"timestamp": time.time()
}
# 파일에 저장
with open("status.json", "w") as f:
json.dump(status, f, indent=2)
logger.debug(f"상태 저장 완료 #{iteration}: {time.strftime('%H:%M:%S')}")
except Exception as e:
logger.error(f"상태 저장 실패 #{iteration}: {e}")
import traceback
logger.error(f"상세 오류: {traceback.format_exc()}")
await asyncio.sleep(5)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""애플리케이션 생명주기 관리"""
# 서버 시작 시
logger.info("🚀 인페인팅 서버 시작 중...")
app.state.start_time = time.time() # settings 대신 app.state에 저장
# app.state에 공유 객체 저장
app.state.worker_manager = worker_manager
app.state.session_pool = session_pool
logger.info("✅ 공유 객체를 app.state에 저장 완료")
# 상태 저장 백그라운드 작업 시작
logger.info("🔄 상태 저장 백그라운드 작업 생성 중...")
status_task = asyncio.create_task(save_status_periodically())
logger.info("✅ 상태 저장 백그라운드 작업 생성 완료")
try:
# ONNX Runtime과 RemBG가 자동으로 CUDA 감지
logger.info("🚀 세션 풀 초기화 (CUDA 자동 감지)")
await session_pool.initialize()
logger.info("✅ 세션 풀 초기화 완료")
# 워커 매니저 시작
await worker_manager.start()
logger.info("✅ 워커 매니저 시작 완료")
if settings.USE_MICRO_BATCHING:
await batch_manager.start()
logger.info("✅ 배치 관리자 시작 완료")
# app.state.background_task = asyncio.create_task(manage_state_background(app.state))
# logger.info("✅ 상태 저장 백그라운드 작업 생성 완료")
logger.info("🎉 인페인팅 서버 시작 완료!")
send_discord_notification("✅ 서버가 성공적으로 시작되었습니다.", level="success")
except Exception as e:
logger.error(f"❌ 서버 시작 실패: {e}")
raise
yield
# 서버 종료 시
logger.info("🛑 인페인팅 서버 종료 중...")
# 상태 저장 백그라운드 작업 취소
status_task.cancel()
try:
# 워커 매니저 중지
await worker_manager.stop()
logger.info("✅ 워커 매니저 중지 완료")
if settings.USE_MICRO_BATCHING:
await batch_manager.stop()
logger.info("✅ 배치 관리자 중지 완료")
# if app.state.background_task:
# app.state.background_task.cancel()
# try:
# await app.state.background_task
# except asyncio.CancelledError:
# logger.info("상태 저장 백그라운드 작업이 정상적으로 취소되었습니다.")
logger.info("👋 인페인팅 서버 종료 완료")
send_discord_notification("👋 서버가 종료되었습니다.", level="info")
except Exception as e:
logger.error(f"❌ 서버 종료 중 오류: {e}")
# 메인 애플리케이션 생성
app = FastAPI(
title=settings.APP_NAME,
version=settings.APP_VERSION,
lifespan=lifespan
)
# API 통계 수집 미들웨어
@app.middleware("http")
async def collect_api_stats(request: Request, call_next):
"""API 호출 통계를 수집하는 미들웨어"""
start_time = time.time()
endpoint = f"{request.method} {request.url.path}"
# 요청 시작
api_stats.start_request(endpoint)
try:
# 실제 요청 처리
response = await call_next(request)
# 응답 시간 계산
response_time = time.time() - start_time
# 성공/실패 판단 (2xx, 3xx는 성공)
success = 200 <= response.status_code < 400
# 통계 업데이트
api_stats.end_request(endpoint, success, response_time)
return response
except Exception as e:
# 에러 발생 시
response_time = time.time() - start_time
api_stats.end_request(endpoint, False, response_time, str(e))
raise
# CORS 미들웨어 추가
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# API 라우터 포함
app.include_router(router)
# 모니터링은 start_server.sh를 통해 독립적으로 실행됩니다.
# app.mount("/monitoring", monitor_app, name="monitoring")
# 모니터링은 독립적인 서버(포트 8001)에서 처리됩니다.
# status.json 파일을 통해 데이터를 공유합니다.
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="인페인팅 서버")
parser.add_argument("--dev", action="store_true", help="개발 모드로 실행")
parser.add_argument("--host", default=settings.HOST, help="호스트 주소")
parser.add_argument("--port", type=int, default=settings.PORT, help="포트 번호")
parser.add_argument("--workers", type=int, default=settings.WORKERS, help="워커 수")
args = parser.parse_args()
if args.dev:
logger.info("🔧 개발 모드로 실행합니다")
uvicorn.run(
"main:app",
host=args.host,
port=args.port,
reload=True,
log_level="info"
)
else:
logger.info("🚀 프로덕션 모드로 실행합니다")
uvicorn.run(
"main:app",
host=args.host,
port=args.port,
workers=args.workers,
log_level="info"
)