366 lines
12 KiB
Python
366 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
인페인팅 서버 메인 애플리케이션
|
|
iopaint와 호환되는 API를 제공합니다.
|
|
"""
|
|
import time
|
|
import logging
|
|
import json
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
from collections import defaultdict, deque
|
|
from fastapi import FastAPI, Request, Response
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import uvicorn
|
|
|
|
from app.core.config import settings
|
|
from app.core.worker_manager import worker_manager
|
|
from app.core.session_pool import SessionPool, session_pool
|
|
from app.api.endpoints import router
|
|
from app.monitoring.dashboard import monitor_app
|
|
from app.core.batch_manager import batch_manager
|
|
# from app.utils.background_task import manage_state_background # TODO: 경로 확인 필요
|
|
from app.utils.discord_notifier import send_discord_notification
|
|
|
|
# 로깅 설정
|
|
import logging.handlers
|
|
import os
|
|
import logging.config
|
|
|
|
# 로그 디렉토리 생성
|
|
log_dir = "logs"
|
|
os.makedirs(log_dir, exist_ok=True)
|
|
|
|
# 로그 회전 설정: 최대 10MB, 7일 보관
|
|
rotating_handler = logging.handlers.TimedRotatingFileHandler(
|
|
filename=os.path.join(log_dir, "main.log"),
|
|
when="D", # 일별 회전
|
|
interval=1, # 1일마다
|
|
backupCount=7, # 7일 보관
|
|
encoding="utf-8"
|
|
)
|
|
rotating_handler.setFormatter(logging.Formatter(
|
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
))
|
|
|
|
# 파일 크기 제한 핸들러 추가 (10MB)
|
|
size_handler = logging.handlers.RotatingFileHandler(
|
|
filename=os.path.join(log_dir, "main_size.log"),
|
|
maxBytes=10*1024*1024, # 10MB
|
|
backupCount=5, # 최대 5개 파일
|
|
encoding="utf-8"
|
|
)
|
|
size_handler.setFormatter(logging.Formatter(
|
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
))
|
|
|
|
# 콘솔 핸들러
|
|
console_handler = logging.StreamHandler()
|
|
console_handler.setFormatter(logging.Formatter(
|
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
))
|
|
|
|
# 루트 로거 설정
|
|
root_logger = logging.getLogger()
|
|
root_logger.setLevel(logging.INFO)
|
|
root_logger.addHandler(rotating_handler)
|
|
root_logger.addHandler(size_handler)
|
|
root_logger.addHandler(console_handler)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# 서버 시작 시간 기록
|
|
start_time = time.time()
|
|
|
|
# API 통계 수집 클래스
|
|
class APIStatsCollector:
|
|
def __init__(self):
|
|
self.total_requests = 0
|
|
self.successful_requests = 0
|
|
self.failed_requests = 0
|
|
self.endpoint_usage = defaultdict(int)
|
|
self.endpoint_response_times = defaultdict(list)
|
|
self.endpoint_concurrent = defaultdict(int)
|
|
self.max_concurrent = 0
|
|
self.current_concurrent = 0
|
|
self.response_times = deque(maxlen=1000)
|
|
self.recent_errors = deque(maxlen=100)
|
|
self.start_time = time.time()
|
|
|
|
def start_request(self, endpoint: str):
|
|
"""요청 시작 시 호출"""
|
|
self.current_concurrent += 1
|
|
self.endpoint_concurrent[endpoint] += 1
|
|
self.max_concurrent = max(self.max_concurrent, self.current_concurrent)
|
|
|
|
def end_request(self, endpoint: str, success: bool, response_time: float, error: str = None):
|
|
"""요청 완료 시 호출"""
|
|
self.current_concurrent -= 1
|
|
self.endpoint_concurrent[endpoint] -= 1
|
|
self.total_requests += 1
|
|
self.endpoint_usage[endpoint] += 1
|
|
|
|
if success:
|
|
self.successful_requests += 1
|
|
else:
|
|
self.failed_requests += 1
|
|
if error:
|
|
self.recent_errors.append({
|
|
"timestamp": time.time(),
|
|
"endpoint": endpoint,
|
|
"error": error
|
|
})
|
|
|
|
self.response_times.append(response_time)
|
|
self.endpoint_response_times[endpoint].append(response_time)
|
|
|
|
# 최근 100개만 유지
|
|
if len(self.endpoint_response_times[endpoint]) > 100:
|
|
self.endpoint_response_times[endpoint] = self.endpoint_response_times[endpoint][-100:]
|
|
|
|
def get_stats(self):
|
|
"""현재 통계 반환"""
|
|
uptime = time.time() - self.start_time
|
|
|
|
# 전체 응답시간 통계
|
|
if self.response_times:
|
|
avg_response_time = sum(self.response_times) / len(self.response_times)
|
|
min_response_time = min(self.response_times)
|
|
max_response_time = max(self.response_times)
|
|
else:
|
|
avg_response_time = min_response_time = max_response_time = 0
|
|
|
|
# 엔드포인트별 상세 통계
|
|
endpoint_stats = {}
|
|
for endpoint, times in self.endpoint_response_times.items():
|
|
if times:
|
|
endpoint_stats[endpoint] = {
|
|
"count": self.endpoint_usage[endpoint],
|
|
"avg_time": sum(times) / len(times),
|
|
"min_time": min(times),
|
|
"max_time": max(times),
|
|
"current_concurrent": self.endpoint_concurrent[endpoint]
|
|
}
|
|
|
|
return {
|
|
"total_requests": self.total_requests,
|
|
"successful_requests": self.successful_requests,
|
|
"failed_requests": self.failed_requests,
|
|
"success_rate": (self.successful_requests / max(self.total_requests, 1)) * 100,
|
|
"endpoint_usage": dict(self.endpoint_usage),
|
|
"endpoint_stats": endpoint_stats,
|
|
"average_response_time": avg_response_time,
|
|
"min_response_time": min_response_time,
|
|
"max_response_time": max_response_time,
|
|
"current_concurrent": self.current_concurrent,
|
|
"max_concurrent": self.max_concurrent,
|
|
"requests_per_second": self.total_requests / max(uptime, 1),
|
|
"uptime": uptime,
|
|
"recent_errors": list(self.recent_errors)[-10:] # 최근 10개 에러
|
|
}
|
|
|
|
# 전역 통계 수집기
|
|
api_stats = APIStatsCollector()
|
|
|
|
async def save_status_periodically():
|
|
"""주기적으로 워커와 세션 상태를 파일에 저장합니다."""
|
|
logger.info("🔄 상태 저장 백그라운드 작업 시작됨")
|
|
iteration = 0
|
|
while True:
|
|
try:
|
|
iteration += 1
|
|
logger.debug(f"상태 저장 시도 #{iteration}")
|
|
|
|
# 워커 상태 수집
|
|
worker_status = worker_manager.get_status()
|
|
logger.debug(f"워커 상태 수집 완료: {worker_status}")
|
|
|
|
# 세션 상태 수집
|
|
session_status = session_pool.get_status()
|
|
logger.debug(f"세션 상태 수집 완료: {session_status}")
|
|
|
|
# API 통계 수집
|
|
api_statistics = api_stats.get_stats()
|
|
logger.debug(f"API 통계 수집 완료: 총 요청 {api_statistics['total_requests']}개")
|
|
|
|
status = {
|
|
"worker_status": worker_status,
|
|
"session_status": session_status,
|
|
"api_stats": api_statistics,
|
|
"timestamp": time.time()
|
|
}
|
|
|
|
# 파일에 저장
|
|
with open("status.json", "w") as f:
|
|
json.dump(status, f, indent=2)
|
|
|
|
logger.debug(f"상태 저장 완료 #{iteration}: {time.strftime('%H:%M:%S')}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"상태 저장 실패 #{iteration}: {e}")
|
|
import traceback
|
|
logger.error(f"상세 오류: {traceback.format_exc()}")
|
|
|
|
await asyncio.sleep(1)
|
|
|
|
@asynccontextmanager
|
|
async def lifespan(app: FastAPI):
|
|
"""애플리케이션 생명주기 관리"""
|
|
# 서버 시작 시
|
|
logger.info("🚀 인페인팅 서버 시작 중...")
|
|
app.state.start_time = time.time() # settings 대신 app.state에 저장
|
|
|
|
# app.state에 공유 객체 저장
|
|
app.state.worker_manager = worker_manager
|
|
app.state.session_pool = session_pool
|
|
logger.info("✅ 공유 객체를 app.state에 저장 완료")
|
|
|
|
# 상태 저장 백그라운드 작업 시작
|
|
logger.info("🔄 상태 저장 백그라운드 작업 생성 중...")
|
|
status_task = asyncio.create_task(save_status_periodically())
|
|
logger.info("✅ 상태 저장 백그라운드 작업 생성 완료")
|
|
|
|
try:
|
|
# ONNX Runtime과 RemBG가 자동으로 CUDA 감지
|
|
logger.info("🚀 세션 풀 초기화 (CUDA 자동 감지)")
|
|
|
|
await session_pool.initialize()
|
|
logger.info("✅ 세션 풀 초기화 완료")
|
|
|
|
# 워커 매니저 시작
|
|
await worker_manager.start()
|
|
logger.info("✅ 워커 매니저 시작 완료")
|
|
|
|
if settings.USE_MICRO_BATCHING:
|
|
await batch_manager.start()
|
|
logger.info("✅ 배치 관리자 시작 완료")
|
|
|
|
# app.state.background_task = asyncio.create_task(manage_state_background(app.state))
|
|
# logger.info("✅ 상태 저장 백그라운드 작업 생성 완료")
|
|
|
|
logger.info("🎉 인페인팅 서버 시작 완료!")
|
|
send_discord_notification("✅ 서버가 성공적으로 시작되었습니다.", level="success")
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 서버 시작 실패: {e}")
|
|
raise
|
|
|
|
yield
|
|
|
|
# 서버 종료 시
|
|
logger.info("🛑 인페인팅 서버 종료 중...")
|
|
|
|
# 상태 저장 백그라운드 작업 취소
|
|
status_task.cancel()
|
|
|
|
try:
|
|
# 워커 매니저 중지
|
|
await worker_manager.stop()
|
|
logger.info("✅ 워커 매니저 중지 완료")
|
|
|
|
if settings.USE_MICRO_BATCHING:
|
|
await batch_manager.stop()
|
|
logger.info("✅ 배치 관리자 중지 완료")
|
|
|
|
# if app.state.background_task:
|
|
# app.state.background_task.cancel()
|
|
# try:
|
|
# await app.state.background_task
|
|
# except asyncio.CancelledError:
|
|
# logger.info("상태 저장 백그라운드 작업이 정상적으로 취소되었습니다.")
|
|
|
|
logger.info("👋 인페인팅 서버 종료 완료")
|
|
send_discord_notification("👋 서버가 종료되었습니다.", level="info")
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 서버 종료 중 오류: {e}")
|
|
|
|
|
|
# 메인 애플리케이션 생성
|
|
app = FastAPI(
|
|
title=settings.APP_NAME,
|
|
version=settings.APP_VERSION,
|
|
lifespan=lifespan
|
|
)
|
|
|
|
# API 통계 수집 미들웨어
|
|
@app.middleware("http")
|
|
async def collect_api_stats(request: Request, call_next):
|
|
"""API 호출 통계를 수집하는 미들웨어"""
|
|
start_time = time.time()
|
|
endpoint = f"{request.method} {request.url.path}"
|
|
|
|
# 요청 시작
|
|
api_stats.start_request(endpoint)
|
|
|
|
try:
|
|
# 실제 요청 처리
|
|
response = await call_next(request)
|
|
|
|
# 응답 시간 계산
|
|
response_time = time.time() - start_time
|
|
|
|
# 성공/실패 판단 (2xx, 3xx는 성공)
|
|
success = 200 <= response.status_code < 400
|
|
|
|
# 통계 업데이트
|
|
api_stats.end_request(endpoint, success, response_time)
|
|
|
|
return response
|
|
|
|
except Exception as e:
|
|
# 에러 발생 시
|
|
response_time = time.time() - start_time
|
|
api_stats.end_request(endpoint, False, response_time, str(e))
|
|
raise
|
|
|
|
# CORS 미들웨어 추가
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"],
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
# API 라우터 포함
|
|
app.include_router(router)
|
|
|
|
# 모니터링은 start_server.sh를 통해 독립적으로 실행됩니다.
|
|
# app.mount("/monitoring", monitor_app, name="monitoring")
|
|
|
|
|
|
# 모니터링은 독립적인 서버(포트 8001)에서 처리됩니다.
|
|
# status.json 파일을 통해 데이터를 공유합니다.
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser(description="인페인팅 서버")
|
|
parser.add_argument("--dev", action="store_true", help="개발 모드로 실행")
|
|
parser.add_argument("--host", default=settings.HOST, help="호스트 주소")
|
|
parser.add_argument("--port", type=int, default=settings.PORT, help="포트 번호")
|
|
parser.add_argument("--workers", type=int, default=settings.WORKERS, help="워커 수")
|
|
|
|
args = parser.parse_args()
|
|
|
|
if args.dev:
|
|
logger.info("🔧 개발 모드로 실행합니다")
|
|
uvicorn.run(
|
|
"main:app",
|
|
host=args.host,
|
|
port=args.port,
|
|
reload=True,
|
|
log_level="info"
|
|
)
|
|
else:
|
|
logger.info("🚀 프로덕션 모드로 실행합니다")
|
|
uvicorn.run(
|
|
"main:app",
|
|
host=args.host,
|
|
port=args.port,
|
|
workers=args.workers,
|
|
log_level="info"
|
|
)
|