270 lines
12 KiB
Python
270 lines
12 KiB
Python
"""
|
||
세션 풀 관리 시스템
|
||
각 모델(simple-lama, migan, rembg)의 세션을 효율적으로 관리합니다.
|
||
VRAM 사용량을 고려하여 세션을 동적으로 생성하고, 유휴 세션을 자동으로 제거합니다.
|
||
"""
|
||
import asyncio
|
||
import logging
|
||
import time
|
||
from typing import Dict, List, Optional, Any, Tuple
|
||
from enum import Enum
|
||
from dataclasses import dataclass
|
||
from contextlib import asynccontextmanager
|
||
from collections import defaultdict
|
||
|
||
from ..core.config import settings
|
||
from ..utils.gpu_monitor import gpu_monitor
|
||
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ModelType(Enum):
    """Identifiers of the model kinds that the session pool can host.

    Each member's value is the short string used in session ids and log
    messages.
    """

    # LaMa-based inpainting model.
    SIMPLE_LAMA = "simple_lama"

    # MI-GAN inpainting model.
    MIGAN = "migan"

    # Background-removal model.
    REMBG = "rembg"
|
||
|
||
|
||
@dataclass
class ModelSession:
    """One loaded model instance together with its pool bookkeeping."""

    # Identifier used in log messages.
    session_id: str
    # Which model kind this session hosts.
    model_type: ModelType
    # The loaded model object itself.
    model: Any
    # time.time() timestamp taken when the session was created.
    created_at: float
    # time.time() timestamp of the most recent mark_used() call.
    last_used: float
    # True while a caller holds the session.
    in_use: bool = False

    def mark_used(self):
        """Stamp the session with the current wall-clock time."""
        self.last_used = time.time()

    def is_idle(self, timeout: int) -> bool:
        """Report whether the session has sat unused for more than ``timeout`` seconds.

        A non-positive ``timeout`` disables idle detection, and a session that
        is currently in use is never idle.
        """
        if timeout <= 0 or self.in_use:
            return False
        elapsed = time.time() - self.last_used
        return elapsed > timeout
|
||
|
||
|
||
class SessionPool:
    """Pools of loaded model sessions, one pool per ``ModelType``.

    ``initialize()`` pre-loads each model's minimum number of sessions.
    Further sessions are created on demand, up to the model's maximum, while
    the free-VRAM ratio reported by ``gpu_monitor`` is above
    ``settings.SESSION_VRAM_THRESHOLD``. A background task removes sessions
    that stay idle longer than ``settings.SESSION_IDLE_TIMEOUT`` seconds, never
    shrinking a pool below its minimum.

    Args:
        model_configs: Maps a model type to ``(min_sessions, max_sessions)``.
    """

    # Upper bound, in seconds, for a single ``model.load_model()`` call.
    _MODEL_LOAD_TIMEOUT_SECONDS = 180
    # Seconds between two sweeps of the idle-session reaper.
    _REAP_INTERVAL_SECONDS = 60
    # Seconds an acquirer facing an EMPTY pool waits before re-checking VRAM.
    # With no session in the pool nobody can ever call notify(), so an
    # untimed wait would block forever.
    _EMPTY_POOL_RECHECK_SECONDS = 5.0

    def __init__(self, model_configs: Dict[ModelType, Tuple[int, int]]):
        self.pools: Dict[ModelType, List[ModelSession]] = {mt: [] for mt in ModelType}
        self.model_configs = model_configs

        # One condition per model type: guards that model's pool and wakes
        # acquirers when a session is released or the pool shrinks.
        self.conditions: Dict[ModelType, asyncio.Condition] = {
            mt: asyncio.Condition() for mt in ModelType
        }
        # Monotonically increasing suffix for session ids. Deriving ids from
        # the current pool size would reuse the id of a live session once the
        # reaper has shrunk the pool.
        self._next_session_index: Dict[ModelType, int] = {mt: 0 for mt in ModelType}
        self._initialized = False
        self._reaper_task: Optional[asyncio.Task] = None

    def _limits(self, model_type: ModelType) -> Tuple[int, int]:
        """Return ``(min_sessions, max_sessions)``; ``(0, 0)`` if unconfigured."""
        return self.model_configs.get(model_type, (0, 0))

    def _new_session_id(self, model_type: ModelType) -> str:
        """Return a session id that is unique for the lifetime of this pool."""
        index = self._next_session_index[model_type]
        self._next_session_index[model_type] = index + 1
        return f"{model_type.value}_{index}"

    async def initialize(self):
        """Pre-load the minimum sessions of every configured model.

        Sessions of one model type are created concurrently. A session that
        fails to load is logged and skipped; it does not abort initialization.
        Starts the idle-session reaper when ``settings.SESSION_IDLE_TIMEOUT``
        is positive. Calling this again after it completed is a no-op.
        """
        if self._initialized:
            return

        logger.info("Initializing dynamic session pools...")

        for model_type, (min_sessions, _) in self.model_configs.items():
            if min_sessions <= 0:
                continue

            logger.info(f"Pre-loading {min_sessions} sessions for {model_type.value}")
            session_ids = [self._new_session_id(model_type) for _ in range(min_sessions)]
            # return_exceptions=True: failures come back as result values, so
            # one bad load cannot cancel or hide the others.
            results = await asyncio.gather(
                *(self._create_session(model_type, sid) for sid in session_ids),
                return_exceptions=True,
            )
            for sid, res in zip(session_ids, results):
                if isinstance(res, ModelSession):
                    self.pools[model_type].append(res)
                    self._log_pool_status("create", model_type.value)
                else:
                    logger.error(f"Failed to create initial session {sid}: {res}")

        self._initialized = True

        if settings.SESSION_IDLE_TIMEOUT > 0:
            # Keep a reference so the task is not garbage-collected.
            self._reaper_task = asyncio.create_task(self._reap_idle_sessions_task())

        logger.info("Session pools initialized successfully")

    def _log_pool_status(self, event: str, model_type: str = "", reaped_info: str = ""):
        """Log per-model session counts plus VRAM usage after a pool event.

        Args:
            event: ``"create"``, ``"reap"``, or any other label for a generic
                status line.
            model_type: Model name shown for ``"create"`` events.
            reaped_info: Summary text shown for ``"reap"`` events.

        Never raises: any failure while gathering the data is logged as a
        warning instead.
        """
        try:
            gpu_info = gpu_monitor.get_gpu_memory_info()
            if gpu_info and 'used' in gpu_info:
                # NOTE(review): dividing by 1024 and labelling the result GB
                # suggests the monitor reports MiB - confirm against gpu_monitor.
                vram_usage = f"VRAM: {(gpu_info['used'] / 1024):.1f}/{(gpu_info['total'] / 1024):.1f} GB ({gpu_info['usage_percent']:.1f}%)"
            else:
                vram_usage = "VRAM: N/A"

            session_counts = ", ".join(f"{mt.value}: {len(p)}" for mt, p in self.pools.items())

            if event == "create":
                log_message = f"➕ Session Created ({model_type}). Status -> {session_counts} | {vram_usage}"
            elif event == "reap":
                log_message = f"➖ Session Reaped ({reaped_info}). Status -> {session_counts} | {vram_usage}"
            else:
                log_message = f"ℹ️ Pool Status ({event}) -> {session_counts} | {vram_usage}"

            logger.info(log_message)
        except Exception as e:
            logger.warning(f"Failed to log memory status after session event '{event}': {e}")

    async def _create_session(self, model_type: ModelType, session_id: str) -> ModelSession:
        """Load a model and wrap it in a new, not-yet-pooled ``ModelSession``.

        The caller is responsible for adding the session to the pool and for
        logging the pool status afterwards, so that the logged counts include
        the new session.

        Raises:
            Exception: Whatever ``_load_model`` raised, after logging it.
        """
        logger.info(f"Creating new session {session_id} for {model_type.value}...")
        try:
            model = await self._load_model(model_type)
        except Exception as e:
            logger.error(f"Failed to create session {session_id}: {e}", exc_info=True)
            raise

        now = time.time()
        session = ModelSession(
            session_id=session_id,
            model_type=model_type,
            model=model,
            created_at=now,
            last_used=now,
        )
        logger.info(f"Successfully created session {session_id}")
        return session

    async def _load_model(self, model_type: ModelType) -> Any:
        """Instantiate the wrapper for ``model_type`` and await its ``load_model()``.

        Model modules are imported lazily, inside the branch that needs them.

        Raises:
            ValueError: If ``model_type`` is not a known model type.
            Exception: Any load failure, including ``asyncio.TimeoutError``
                when loading exceeds the load timeout; logged, then re-raised.
        """
        device = "cuda" if settings.USE_CUDA else "cpu"

        if model_type == ModelType.SIMPLE_LAMA:
            from ..models.simple_lama import SimpleLamaInpainter
            model = SimpleLamaInpainter(
                model_path=settings.SIMPLE_LAMA_MODEL_PATH,
                device=device,
                fp16=settings.USE_FP16
            )
        elif model_type == ModelType.MIGAN:
            from ..models.migan import MiganInpainter
            model = MiganInpainter(
                # Uses MIGAN_ONNX_PATH when the settings object defines it.
                model_path=getattr(settings, 'MIGAN_ONNX_PATH', settings.MIGAN_MODEL_PATH),
                device=device,
                fp16=settings.USE_FP16,
                use_cuda=settings.USE_CUDA
            )
        elif model_type == ModelType.REMBG:
            # Uses the BriaAI RMBG 1.4 ONNX processor instead of rembg.
            from ..models.bria_rmbg_onnx import BriaRMBGOnnxProcessor
            model = BriaRMBGOnnxProcessor()
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        try:
            await asyncio.wait_for(model.load_model(), timeout=self._MODEL_LOAD_TIMEOUT_SECONDS)
        except Exception as e:
            logger.error(f"Failed to load {model_type.value} model: {e}", exc_info=True)
            raise

        logger.debug(f"{model_type.value} model instance created and loaded.")
        return model

    @asynccontextmanager
    async def get_session(self, model_type: ModelType):
        """Async context manager that lends out one session of ``model_type``.

        The session is returned to the pool on exit, whether or not the body
        raised.

        Raises:
            Exception: If a session had to be created, creation failed, and the
                pool holds no other session that could be waited for.
        """
        session = await self._acquire_session(model_type)
        try:
            yield session
        finally:
            await self._release_session(session)

    async def _acquire_session(self, model_type: ModelType) -> ModelSession:
        """Return a session marked in-use, creating or waiting for one as needed.

        Order of preference: a free pooled session; a newly created session
        (when the pool is below its maximum and free VRAM is above the
        threshold); otherwise wait until a session is released.

        NOTE: the model's condition lock is held while a new session loads,
        which keeps the pool from exceeding its maximum but delays other
        acquire/release calls for the same model type during the load.

        Raises:
            Exception: The creation error, when creation fails and the pool is
                empty (there is nothing that could be released to wait for).
        """
        condition = self.conditions[model_type]
        _, max_sessions = self.model_configs[model_type]
        pool = self.pools[model_type]

        while True:
            async with condition:
                for session in pool:
                    if not session.in_use:
                        session.in_use = True
                        session.mark_used()
                        logger.debug(f"Acquired existing session {session.session_id}")
                        return session

                if len(pool) < max_sessions:
                    # The monitor may report nothing (see _log_pool_status);
                    # treat that as "no free VRAM" instead of crashing.
                    gpu_mem_info = gpu_monitor.get_gpu_memory_info() or {}
                    free_vram_ratio = gpu_mem_info.get("free_ratio", 0)

                    if free_vram_ratio > settings.SESSION_VRAM_THRESHOLD:
                        current_pool_size = len(pool)
                        session_id = self._new_session_id(model_type)

                        logger.info(f"Attempting to create new session for {model_type.value}. Current size: {current_pool_size}, Max size: {max_sessions}")

                        try:
                            new_session = await self._create_session(model_type, session_id)
                        except Exception:
                            if not pool:
                                # Nothing is pooled, so no release can ever
                                # wake this caller: fail instead of hanging.
                                logger.error(f"New session creation failed for {model_type.value} and no existing session is available to wait for.")
                                raise
                            # On creation failure, fall through and wait for
                            # an existing session to be released.
                            logger.error(f"New session creation failed for {model_type.value}. Will wait for an existing session.")
                        else:
                            # Still under the lock: publish the session
                            # already marked as taken by this caller.
                            new_session.in_use = True
                            new_session.mark_used()
                            pool.append(new_session)
                            self._log_pool_status("create", model_type.value)
                            logger.info(f"Acquired new session {new_session.session_id} as VRAM is sufficient ({free_vram_ratio:.2f} > {settings.SESSION_VRAM_THRESHOLD:.2f})")
                            return new_session
                    else:
                        logger.warning(f"Cannot create new session for {model_type.value}. VRAM threshold not met. (Free: {free_vram_ratio:.2f} <= Threshold: {settings.SESSION_VRAM_THRESHOLD:.2f})")

                logger.debug(f"No available sessions or VRAM for {model_type.value}, waiting...")
                if pool:
                    # A busy session exists; its release will notify us.
                    await condition.wait()
                else:
                    # Empty pool: no release will ever notify, so re-check
                    # VRAM periodically instead of blocking forever.
                    try:
                        await asyncio.wait_for(
                            condition.wait(),
                            timeout=self._EMPTY_POOL_RECHECK_SECONDS,
                        )
                    except asyncio.TimeoutError:
                        pass

    async def _release_session(self, session: ModelSession):
        """Mark ``session`` free again and wake one waiting acquirer."""
        condition = self.conditions[session.model_type]
        async with condition:
            session.in_use = False
            # Idle time counts from release; otherwise a session held longer
            # than the idle timeout would be reapable the moment it is freed.
            session.mark_used()
            logger.debug(f"Released session {session.session_id}")
            condition.notify()

    def get_status(self) -> dict:
        """Return per-model counters keyed by model value.

        Each entry holds ``min``, ``max``, ``total``, ``in_use`` and
        ``available``. Model types without a configuration report a min and
        max of 0.
        """
        status_by_model = {}
        for model_type in ModelType:
            pool = self.pools[model_type]
            min_s, max_s = self._limits(model_type)
            total = len(pool)
            in_use = sum(1 for session in pool if session.in_use)

            status_by_model[model_type.value] = {
                "min": min_s, "max": max_s, "total": total,
                "in_use": in_use, "available": total - in_use
            }
        return status_by_model

    async def _reap_idle_sessions_task(self):
        """Background loop that periodically removes idle surplus sessions.

        Runs until cancelled. A failure in one sweep is logged and the loop
        continues, so a transient error cannot silently stop reaping.
        """
        logger.info(f"Idle session reaper started. Timeout: {settings.SESSION_IDLE_TIMEOUT}s, Check Interval: {self._REAP_INTERVAL_SECONDS}s")
        while True:
            await asyncio.sleep(self._REAP_INTERVAL_SECONDS)

            try:
                reaped_counts = await self._reap_idle_sessions_once()
            except Exception:
                # CancelledError is not an Exception subclass on Python 3.8+,
                # so task cancellation still propagates.
                logger.exception("Idle session reaper sweep failed; retrying at next interval")
                continue

            if reaped_counts:
                reaped_info = ", ".join([f"{count} {model}" for model, count in reaped_counts.items()])
                self._log_pool_status("reap", reaped_info=reaped_info)

    async def _reap_idle_sessions_once(self) -> Dict[str, int]:
        """Remove idle sessions above each pool's minimum; return counts per model value."""
        reaped_counts: Dict[str, int] = defaultdict(int)

        for model_type, pool in self.pools.items():
            min_sessions, _ = self._limits(model_type)
            condition = self.conditions[model_type]

            async with condition:
                surplus = len(pool) - min_sessions
                if surplus <= 0:
                    continue

                idle_sessions = [s for s in pool if s.is_idle(settings.SESSION_IDLE_TIMEOUT)]
                # Never reap more than the surplus above the minimum.
                sessions_to_reap = idle_sessions[:surplus]
                if not sessions_to_reap:
                    continue

                logger.info(f"Reaping {len(sessions_to_reap)} idle session(s) for {model_type.value}.")

                for session in sessions_to_reap:
                    pool.remove(session)
                    reaped_counts[model_type.value] += 1
                    # Drop the session's reference to the model object.
                    del session.model

                # The pool shrank below its maximum; let waiters re-evaluate.
                condition.notify_all()

        return reaped_counts
|
||
|
||
# Global session pool instance.
# Each model's (min, max) limits are read from the settings attributes named
# "<MEMBER>_MIN_SESSIONS" and "<MEMBER>_MAX_SESSIONS", where <MEMBER> is the
# ModelType member name (SIMPLE_LAMA, MIGAN, REMBG).
model_configs = {
    mt: (
        getattr(settings, f"{mt.name}_MIN_SESSIONS"),
        getattr(settings, f"{mt.name}_MAX_SESSIONS"),
    )
    for mt in ModelType
}
session_pool = SessionPool(model_configs=model_configs)
|