246 lines
9.4 KiB
Python
246 lines
9.4 KiB
Python
"""
|
|
세션 풀 관리 시스템
|
|
각 모델(simple-lama, migan, rembg)의 세션을 효율적으로 관리합니다.
|
|
"""
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from typing import Dict, List, Optional, Any
|
|
from enum import Enum
|
|
from dataclasses import dataclass
|
|
from contextlib import asynccontextmanager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ModelType(Enum):
|
|
SIMPLE_LAMA = "simple_lama"
|
|
MIGAN = "migan"
|
|
REMBG = "rembg"
|
|
|
|
|
|
@dataclass
|
|
class ModelSession:
|
|
session_id: str
|
|
model_type: ModelType
|
|
model: Any
|
|
created_at: float
|
|
last_used: float
|
|
in_use: bool = False
|
|
|
|
def mark_used(self):
|
|
self.last_used = time.time()
|
|
|
|
def is_expired(self, timeout: int = 3600) -> bool:
|
|
return time.time() - self.last_used > timeout
|
|
|
|
|
|
class SessionPool:
|
|
def __init__(self,
|
|
simple_lama_count: int = 2,
|
|
migan_count: int = 2,
|
|
rembg_count: int = 1):
|
|
self.pools: Dict[ModelType, List[ModelSession]] = {
|
|
ModelType.SIMPLE_LAMA: [],
|
|
ModelType.MIGAN: [],
|
|
ModelType.REMBG: []
|
|
}
|
|
self.pool_sizes = {
|
|
ModelType.SIMPLE_LAMA: simple_lama_count,
|
|
ModelType.MIGAN: migan_count,
|
|
ModelType.REMBG: rembg_count
|
|
}
|
|
self.locks: Dict[ModelType, asyncio.Lock] = {
|
|
model_type: asyncio.Lock() for model_type in ModelType
|
|
}
|
|
self._initialized = False
|
|
|
|
async def initialize(self):
|
|
"""모든 모델 세션을 초기화합니다."""
|
|
if self._initialized:
|
|
return
|
|
|
|
logger.info("Initializing session pools...")
|
|
|
|
for model_type, count in self.pool_sizes.items():
|
|
await self._initialize_model_pool(model_type, count)
|
|
|
|
self._initialized = True
|
|
logger.info("Session pools initialized successfully")
|
|
|
|
async def _initialize_model_pool(self, model_type: ModelType, count: int):
|
|
"""특정 모델의 세션 풀을 초기화합니다."""
|
|
logger.info(f"Initializing {count} sessions for {model_type.value}")
|
|
|
|
for i in range(count):
|
|
try:
|
|
session = await self._create_session(model_type, f"{model_type.value}_{i}")
|
|
self.pools[model_type].append(session)
|
|
logger.info(f"Created session {session.session_id}")
|
|
except Exception as e:
|
|
logger.error(f"Failed to create session for {model_type.value}: {e}")
|
|
|
|
async def _create_session(self, model_type: ModelType, session_id: str) -> ModelSession:
|
|
"""새로운 모델 세션을 생성합니다."""
|
|
model = await self._load_model(model_type)
|
|
return ModelSession(
|
|
session_id=session_id,
|
|
model_type=model_type,
|
|
model=model,
|
|
created_at=time.time(),
|
|
last_used=time.time()
|
|
)
|
|
|
|
async def _load_model(self, model_type: ModelType) -> Any:
|
|
"""모델을 로드합니다."""
|
|
# 실제 구현에서는 각 모델을 로드하는 로직이 들어갑니다
|
|
if model_type == ModelType.SIMPLE_LAMA:
|
|
return await self._load_simple_lama_model()
|
|
elif model_type == ModelType.MIGAN:
|
|
return await self._load_migan_model()
|
|
elif model_type == ModelType.REMBG:
|
|
return await self._load_rembg_model()
|
|
else:
|
|
raise ValueError(f"Unknown model type: {model_type}")
|
|
|
|
async def _load_simple_lama_model(self):
|
|
"""Simple LAMA 모델을 로드합니다."""
|
|
# Placeholder - 실제 모델 로딩 로직으로 대체
|
|
await asyncio.sleep(0.1) # 시뮬레이션
|
|
return {"model": "simple_lama", "loaded": True}
|
|
|
|
async def _load_migan_model(self):
|
|
"""MIGAN 모델을 로드합니다."""
|
|
# Placeholder - 실제 모델 로딩 로직으로 대체
|
|
await asyncio.sleep(0.1) # 시뮬레이션
|
|
return {"model": "migan", "loaded": True}
|
|
|
|
async def _load_rembg_model(self):
|
|
"""REMBG 모델을 로드합니다."""
|
|
# Placeholder - 실제 모델 로딩 로직으로 대체
|
|
await asyncio.sleep(0.1) # 시뮬레이션
|
|
return {"model": "rembg", "loaded": True}
|
|
|
|
@asynccontextmanager
|
|
async def get_session(self, model_type: ModelType):
|
|
"""세션을 가져와서 사용 후 반환합니다."""
|
|
session = await self._acquire_session(model_type)
|
|
try:
|
|
yield session
|
|
finally:
|
|
await self._release_session(session)
|
|
|
|
async def _acquire_session(self, model_type: ModelType) -> ModelSession:
|
|
"""사용 가능한 세션을 획득합니다."""
|
|
async with self.locks[model_type]:
|
|
# 사용 가능한 세션 찾기
|
|
for session in self.pools[model_type]:
|
|
if not session.in_use:
|
|
session.in_use = True
|
|
session.mark_used()
|
|
logger.debug(f"Acquired session {session.session_id}")
|
|
return session
|
|
|
|
# 사용 가능한 세션이 없으면 대기
|
|
logger.warning(f"No available sessions for {model_type.value}, waiting...")
|
|
|
|
# 세션이 사용 가능해질 때까지 대기
|
|
while True:
|
|
await asyncio.sleep(0.1)
|
|
async with self.locks[model_type]:
|
|
for session in self.pools[model_type]:
|
|
if not session.in_use:
|
|
session.in_use = True
|
|
session.mark_used()
|
|
logger.debug(f"Acquired session {session.session_id} after waiting")
|
|
return session
|
|
|
|
async def _release_session(self, session: ModelSession):
|
|
"""세션을 반환합니다."""
|
|
async with self.locks[session.model_type]:
|
|
session.in_use = False
|
|
logger.debug(f"Released session {session.session_id}")
|
|
|
|
async def get_pool_status(self) -> Dict[str, Any]:
|
|
"""풀 상태를 반환합니다."""
|
|
status = {}
|
|
for model_type in ModelType:
|
|
pool = self.pools[model_type]
|
|
total = len(pool)
|
|
in_use = sum(1 for session in pool if session.in_use)
|
|
available = total - in_use
|
|
|
|
status[model_type.value] = {
|
|
"total": total,
|
|
"in_use": in_use,
|
|
"available": available,
|
|
"sessions": [
|
|
{
|
|
"id": session.session_id,
|
|
"in_use": session.in_use,
|
|
"last_used": session.last_used,
|
|
"created_at": session.created_at
|
|
}
|
|
for session in pool
|
|
]
|
|
}
|
|
return status
|
|
|
|
async def cleanup_expired_sessions(self, timeout: int = 3600):
|
|
"""만료된 세션을 정리합니다."""
|
|
for model_type, pool in self.pools.items():
|
|
async with self.locks[model_type]:
|
|
expired_sessions = [s for s in pool if s.is_expired(timeout) and not s.in_use]
|
|
for session in expired_sessions:
|
|
pool.remove(session)
|
|
logger.info(f"Removed expired session {session.session_id}")
|
|
|
|
async def scale_pool(self, model_type: ModelType, new_size: int):
|
|
"""풀 크기를 조정합니다."""
|
|
async with self.locks[model_type]:
|
|
current_size = len(self.pools[model_type])
|
|
|
|
if new_size > current_size:
|
|
# 세션 추가
|
|
for i in range(current_size, new_size):
|
|
session_id = f"{model_type.value}_{i}"
|
|
session = await self._create_session(model_type, session_id)
|
|
self.pools[model_type].append(session)
|
|
logger.info(f"Added session {session_id}")
|
|
|
|
elif new_size < current_size:
|
|
# 세션 제거 (사용 중이지 않은 것만)
|
|
sessions_to_remove = []
|
|
for session in self.pools[model_type]:
|
|
if not session.in_use and len(sessions_to_remove) < (current_size - new_size):
|
|
sessions_to_remove.append(session)
|
|
|
|
for session in sessions_to_remove:
|
|
self.pools[model_type].remove(session)
|
|
logger.info(f"Removed session {session.session_id}")
|
|
|
|
self.pool_sizes[model_type] = new_size
|
|
|
|
def get_status(self) -> dict:
|
|
"""세션 풀의 현재 상태를 반환합니다."""
|
|
status_by_model = {}
|
|
|
|
all_sessions = list(self.pools.values()) # Flatten all sessions from all models
|
|
|
|
for model_type in ModelType:
|
|
model_sessions = [s for s in all_sessions if s.model_type == model_type]
|
|
in_use_count = sum(1 for s in model_sessions if s.in_use)
|
|
available_count = len(model_sessions) - in_use_count
|
|
|
|
status_by_model[model_type.value] = {
|
|
"total": len(model_sessions),
|
|
"in_use": in_use_count,
|
|
"available": available_count
|
|
}
|
|
|
|
return status_by_model
|
|
|
|
|
|
# 전역 세션 풀 인스턴스
|
|
session_pool = SessionPool()
|