inpaintServer/app/core/session_pool.py

291 lines
11 KiB
Python

"""
세션 풀 관리 시스템
각 모델(simple-lama, migan, rembg)의 세션을 효율적으로 관리합니다.
"""
import asyncio
import logging
import time
from typing import Dict, List, Optional, Any
from enum import Enum
from dataclasses import dataclass
from contextlib import asynccontextmanager
logger = logging.getLogger(__name__)
class ModelType(Enum):
SIMPLE_LAMA = "simple_lama"
MIGAN = "migan"
REMBG = "rembg"
@dataclass
class ModelSession:
session_id: str
model_type: ModelType
model: Any
created_at: float
last_used: float
in_use: bool = False
def mark_used(self):
self.last_used = time.time()
def is_expired(self, timeout: int = 3600) -> bool:
return time.time() - self.last_used > timeout
class SessionPool:
def __init__(self,
simple_lama_count: int = 2,
migan_count: int = 2,
rembg_count: int = 1):
self.pools: Dict[ModelType, List[ModelSession]] = {
ModelType.SIMPLE_LAMA: [],
ModelType.MIGAN: [],
ModelType.REMBG: []
}
self.pool_sizes = {
ModelType.SIMPLE_LAMA: simple_lama_count,
ModelType.MIGAN: migan_count,
ModelType.REMBG: rembg_count
}
self.locks: Dict[ModelType, asyncio.Lock] = {
model_type: asyncio.Lock() for model_type in ModelType
}
self._initialized = False
async def initialize(self):
"""모든 모델 세션을 초기화합니다."""
if self._initialized:
return
logger.info("Initializing session pools...")
for model_type, count in self.pool_sizes.items():
await self._initialize_model_pool(model_type, count)
self._initialized = True
logger.info("Session pools initialized successfully")
async def _initialize_model_pool(self, model_type: ModelType, count: int):
"""특정 모델의 세션 풀을 초기화합니다."""
logger.info(f"Initializing {count} sessions for {model_type.value}")
for i in range(count):
try:
session = await self._create_session(model_type, f"{model_type.value}_{i}")
self.pools[model_type].append(session)
logger.info(f"Created session {session.session_id}")
except Exception as e:
logger.error(f"Failed to create session for {model_type.value}: {e}")
async def _create_session(self, model_type: ModelType, session_id: str) -> ModelSession:
"""새로운 모델 세션을 생성합니다."""
model = await self._load_model(model_type)
return ModelSession(
session_id=session_id,
model_type=model_type,
model=model,
created_at=time.time(),
last_used=time.time()
)
async def _load_model(self, model_type: ModelType) -> Any:
"""모델을 로드합니다."""
# 실제 구현에서는 각 모델을 로드하는 로직이 들어갑니다
if model_type == ModelType.SIMPLE_LAMA:
return await self._load_simple_lama_model()
elif model_type == ModelType.MIGAN:
return await self._load_migan_model()
elif model_type == ModelType.REMBG:
return await self._load_rembg_model()
else:
raise ValueError(f"Unknown model type: {model_type}")
async def _load_simple_lama_model(self):
"""Simple LAMA 모델을 로드합니다."""
from ..models.simple_lama import SimpleLamaInpainter
from ..core.config import settings
try:
model = SimpleLamaInpainter(
model_path=settings.SIMPLE_LAMA_MODEL_PATH,
device="cuda" if settings.USE_CUDA else "cpu",
fp16=settings.USE_FP16
)
await model.load_model()
logger.info("Simple LAMA 모델 세션 로드 완료")
return model
except Exception as e:
logger.error(f"Simple LAMA 모델 로드 실패: {e}")
raise
async def _load_migan_model(self):
"""MIGAN 모델을 로드합니다."""
from ..models.migan import MiganInpainter
from ..core.config import settings
try:
# MIGAN 모델 생성 - ONNX Runtime이 자동으로 CUDA 감지
model = MiganInpainter(
model_path=getattr(settings, 'MIGAN_ONNX_PATH', settings.MIGAN_MODEL_PATH),
device="cuda" if settings.USE_CUDA else "cpu",
fp16=settings.USE_FP16,
use_cuda=settings.USE_CUDA
)
await model.load_model()
logger.info("MIGAN 모델 세션 로드 완료")
return model
except Exception as e:
logger.error(f"MIGAN 모델 로드 실패: {e}")
raise
async def _load_rembg_model(self):
"""REMBG 모델을 로드합니다."""
from ..models.rembg_model import RembgProcessor
from ..core.config import settings
try:
# RemBG 모델 생성 - 자동으로 CUDA 감지
model = RembgProcessor(
model_name=getattr(settings, 'REMBG_MODEL_NAME', 'birefnet-general-lite'),
device="cuda" if settings.USE_CUDA else "cpu",
fp16=settings.USE_FP16,
local_rembg_model_path=getattr(settings, 'LOCAL_REMBG_MODEL_PATH', None)
)
# 프리로드 강제: 실패 시 서버 기동 실패로 처리 (원인 파악을 위함)
await model.load_model()
logger.info("REMBG 모델 세션 로드 완료")
return model
except Exception as e:
logger.error(f"REMBG 모델 로드 실패: {e}")
raise
@asynccontextmanager
async def get_session(self, model_type: ModelType):
"""세션을 가져와서 사용 후 반환합니다."""
session = await self._acquire_session(model_type)
try:
yield session
finally:
await self._release_session(session)
async def _acquire_session(self, model_type: ModelType) -> ModelSession:
"""사용 가능한 세션을 획득합니다."""
async with self.locks[model_type]:
# 사용 가능한 세션 찾기
for session in self.pools[model_type]:
if not session.in_use:
session.in_use = True
session.mark_used()
logger.debug(f"Acquired session {session.session_id}")
return session
# 사용 가능한 세션이 없으면 대기
logger.warning(f"No available sessions for {model_type.value}, waiting...")
# 세션이 사용 가능해질 때까지 대기
while True:
await asyncio.sleep(0.1)
async with self.locks[model_type]:
for session in self.pools[model_type]:
if not session.in_use:
session.in_use = True
session.mark_used()
logger.debug(f"Acquired session {session.session_id} after waiting")
return session
async def _release_session(self, session: ModelSession):
"""세션을 반환합니다."""
async with self.locks[session.model_type]:
session.in_use = False
logger.debug(f"Released session {session.session_id}")
async def get_pool_status(self) -> Dict[str, Any]:
"""풀 상태를 반환합니다."""
status = {}
for model_type in ModelType:
pool = self.pools[model_type]
total = len(pool)
in_use = sum(1 for session in pool if session.in_use)
available = total - in_use
status[model_type.value] = {
"total": total,
"in_use": in_use,
"available": available,
"sessions": [
{
"id": session.session_id,
"in_use": session.in_use,
"last_used": session.last_used,
"created_at": session.created_at
}
for session in pool
]
}
return status
async def cleanup_expired_sessions(self, timeout: int = 3600):
"""만료된 세션을 정리합니다."""
for model_type, pool in self.pools.items():
async with self.locks[model_type]:
expired_sessions = [s for s in pool if s.is_expired(timeout) and not s.in_use]
for session in expired_sessions:
pool.remove(session)
logger.info(f"Removed expired session {session.session_id}")
async def scale_pool(self, model_type: ModelType, new_size: int):
"""풀 크기를 조정합니다."""
async with self.locks[model_type]:
current_size = len(self.pools[model_type])
if new_size > current_size:
# 세션 추가
for i in range(current_size, new_size):
session_id = f"{model_type.value}_{i}"
session = await self._create_session(model_type, session_id)
self.pools[model_type].append(session)
logger.info(f"Added session {session_id}")
elif new_size < current_size:
# 세션 제거 (사용 중이지 않은 것만)
sessions_to_remove = []
for session in self.pools[model_type]:
if not session.in_use and len(sessions_to_remove) < (current_size - new_size):
sessions_to_remove.append(session)
for session in sessions_to_remove:
self.pools[model_type].remove(session)
logger.info(f"Removed session {session.session_id}")
self.pool_sizes[model_type] = new_size
def get_status(self) -> dict:
"""세션 풀의 현재 상태를 반환합니다."""
status_by_model = {}
for model_type in ModelType:
pool = self.pools[model_type]
total = len(pool)
in_use = sum(1 for session in pool if session.in_use)
available = total - in_use
status_by_model[model_type.value] = {
"total": total,
"in_use": in_use,
"available": available
}
return status_by_model
# 전역 세션 풀 인스턴스 (설정값으로 초기화)
from ..core.config import settings
session_pool = SessionPool(
simple_lama_count=settings.SIMPLE_LAMA_SESSIONS,
migan_count=settings.MIGAN_SESSIONS,
rembg_count=settings.REMBG_SESSIONS
)