"""
Session pool management system.

Manages the sessions of each model (simple-lama, migan, rembg).
Sessions are created dynamically with VRAM usage taken into account, and
idle sessions are removed automatically.
"""

import asyncio
import logging
import time
from typing import Dict, List, Optional, Any, Tuple
from enum import Enum
from dataclasses import dataclass
from contextlib import asynccontextmanager
from collections import defaultdict

from ..core.config import settings
from ..utils.gpu_monitor import gpu_monitor

logger = logging.getLogger(__name__)

# Upper bound (seconds) on a single wait for a free session. Waiters are only
# notified when a session of the same model type is released or reaped, so a
# bounded wait lets them notice VRAM that was freed elsewhere.
_ACQUIRE_RECHECK_SECONDS = 5.0


class ModelType(Enum):
    """Model kinds that have their own session pool."""

    SIMPLE_LAMA = "simple_lama"
    MIGAN = "migan"
    REMBG = "rembg"


@dataclass
class ModelSession:
    """A loaded model instance plus the bookkeeping the pool needs."""

    session_id: str
    model_type: ModelType
    model: Any
    created_at: float
    last_used: float
    in_use: bool = False

    def mark_used(self):
        """Record the current time as the moment of last use."""
        self.last_used = time.time()

    def is_idle(self, timeout: int) -> bool:
        """Return True when the session is free and unused for over *timeout* seconds.

        A non-positive *timeout* disables idle detection and always yields False.
        """
        if timeout <= 0:
            return False
        return not self.in_use and (time.time() - self.last_used > timeout)


class SessionPool:
    """Per-model pools of sessions, grown on demand and shrunk when idle.

    ``model_configs`` maps each model type to ``(min_sessions, max_sessions)``.
    Each model type has its own ``asyncio.Condition`` guarding its pool.
    """

    def __init__(self, model_configs: Dict[ModelType, Tuple[int, int]]):
        self.pools: Dict[ModelType, List[ModelSession]] = {mt: [] for mt in ModelType}
        self.model_configs = model_configs
        self.conditions: Dict[ModelType, asyncio.Condition] = {
            mt: asyncio.Condition() for mt in ModelType
        }
        self._initialized = False
        self._reaper_task: Optional[asyncio.Task] = None

    async def initialize(self):
        """Pre-load the minimum number of sessions and start the idle reaper.

        Calling it again after a first successful run is a no-op. Sessions
        that fail to load are logged and skipped; initialization continues.
        """
        if self._initialized:
            return

        logger.info("Initializing dynamic session pools...")

        for model_type, (min_sessions, _) in self.model_configs.items():
            if min_sessions > 0:
                logger.info(f"Pre-loading {min_sessions} sessions for {model_type.value}")
                tasks = [
                    self._create_session(model_type, f"{model_type.value}_{i}")
                    for i in range(min_sessions)
                ]
                try:
                    results = await asyncio.gather(*tasks, return_exceptions=True)
                    for i, res in enumerate(results):
                        if isinstance(res, ModelSession):
                            self.pools[model_type].append(res)
                            # Logged after the append so the counts include
                            # the session that was just created.
                            self._log_pool_status("create", model_type.value)
                        else:
                            logger.error(f"Failed to create initial session {model_type.value}_{i}: {res}")
                except Exception as e:
                    logger.error(f"Error during concurrent session creation for {model_type.value}: {e}",
                                 exc_info=True)

        self._initialized = True

        if settings.SESSION_IDLE_TIMEOUT > 0:
            self._reaper_task = asyncio.create_task(self._reap_idle_sessions_task())

        logger.info("Session pools initialized successfully")

    def _log_pool_status(self, event: str, model_type: str = "", reaped_info: str = ""):
        """Log per-model session counts together with GPU memory usage.

        Best effort: any failure while collecting the data is logged as a
        warning and never propagated.
        """
        try:
            gpu_info = gpu_monitor.get_gpu_memory_info()
            if not gpu_info or 'used' not in gpu_info:
                vram_usage = "VRAM: N/A"
            else:
                # Values are divided by 1024 and reported as GB
                # (presumably the monitor reports MB -- TODO confirm).
                vram_usage = f"VRAM: {(gpu_info['used'] / 1024):.1f}/{(gpu_info['total'] / 1024):.1f} GB ({gpu_info['usage_percent']:.1f}%)"

            session_counts = ", ".join([f"{mt.value}: {len(p)}" for mt, p in self.pools.items()])

            if event == "create":
                log_message = f"➕ Session Created ({model_type}). Status -> {session_counts} | {vram_usage}"
            elif event == "reap":
                log_message = f"➖ Session Reaped ({reaped_info}). Status -> {session_counts} | {vram_usage}"
            else:
                log_message = f"ℹ️ Pool Status ({event}) -> {session_counts} | {vram_usage}"

            logger.info(log_message)
        except Exception as e:
            logger.warning(f"Failed to log memory status after session event '{event}': {e}")

    async def _create_session(self, model_type: ModelType, session_id: str) -> ModelSession:
        """Load a model and wrap it in a new, not-in-use ``ModelSession``.

        The session is NOT added to any pool here; the caller does that (and
        logs the pool status afterwards). Load errors are logged and re-raised.
        """
        logger.info(f"Creating new session {session_id} for {model_type.value}...")
        try:
            model = await self._load_model(model_type)
            session = ModelSession(
                session_id=session_id,
                model_type=model_type,
                model=model,
                created_at=time.time(),
                last_used=time.time()
            )
            logger.info(f"Successfully created session {session_id}")
            return session
        except Exception as e:
            logger.error(f"Failed to create session {session_id}: {e}", exc_info=True)
            raise

    async def _load_model(self, model_type: ModelType) -> Any:
        """Instantiate the model for *model_type* and await its ``load_model()``.

        Raises:
            ValueError: if *model_type* is not a known model type.
            Exception: whatever loading raises, including the timeout error
                raised when loading takes longer than 180 seconds.
        """
        # Model modules are imported lazily, only for the type being loaded.
        if model_type == ModelType.SIMPLE_LAMA:
            from ..models.simple_lama import SimpleLamaInpainter
            model = SimpleLamaInpainter(
                model_path=settings.SIMPLE_LAMA_MODEL_PATH,
                device="cuda" if settings.USE_CUDA else "cpu",
                fp16=settings.USE_FP16
            )
        elif model_type == ModelType.MIGAN:
            from ..models.migan import MiganInpainter
            model = MiganInpainter(
                model_path=getattr(settings, 'MIGAN_ONNX_PATH', settings.MIGAN_MODEL_PATH),
                device="cuda" if settings.USE_CUDA else "cpu",
                fp16=settings.USE_FP16,
                use_cuda=settings.USE_CUDA
            )
        elif model_type == ModelType.REMBG:
            # Uses the BriaAI RMBG 1.4 ONNX processor instead of rembg.
            from ..models.bria_rmbg_onnx import BriaRMBGOnnxProcessor
            model = BriaRMBGOnnxProcessor()
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        try:
            await asyncio.wait_for(model.load_model(), timeout=180)
            logger.debug(f"{model_type.value} model instance created and loaded.")
            return model
        except Exception as e:
            logger.error(f"Failed to load {model_type.value} model: {e}", exc_info=True)
            raise

    @asynccontextmanager
    async def get_session(self, model_type: ModelType):
        """Async context manager yielding a session reserved for the caller.

        The session is released back to the pool when the context exits,
        whether normally or through an exception.
        """
        session = await self._acquire_session(model_type)
        try:
            yield session
        finally:
            if session:
                await self._release_session(session)

    def _next_session_id(self, model_type: ModelType) -> str:
        """Return a session id for *model_type* that no pooled session uses.

        Starts at the current pool size (the historical naming scheme) and
        skips forward past ids that are still alive, so ids stay unique after
        the reaper has removed sessions from the middle of the numbering.
        """
        pool = self.pools[model_type]
        used_ids = {session.session_id for session in pool}
        index = len(pool)
        while f"{model_type.value}_{index}" in used_ids:
            index += 1
        return f"{model_type.value}_{index}"

    async def _acquire_session(self, model_type: ModelType) -> ModelSession:
        """Reserve a session for *model_type*, creating one if allowed.

        Order of preference on every pass:
          1. reuse a pooled session that is not in use;
          2. create a new session when the pool is below its maximum and the
             free-VRAM ratio exceeds ``settings.SESSION_VRAM_THRESHOLD``;
          3. otherwise wait for a notification (bounded by
             ``_ACQUIRE_RECHECK_SECONDS``) and try again.

        Raises:
            KeyError: if *model_type* has no entry in ``model_configs``.
            Exception: the creation error, when creating a session fails and
                the pool holds no session that could ever be released.
        """
        condition = self.conditions[model_type]
        _, max_sessions = self.model_configs[model_type]

        while True:
            async with condition:
                pool = self.pools[model_type]

                for session in pool:
                    if not session.in_use:
                        session.in_use = True
                        session.mark_used()
                        logger.debug(f"Acquired existing session {session.session_id}")
                        return session

                if len(pool) < max_sessions:
                    # The monitor may return nothing; treat that as "no free
                    # VRAM" instead of failing on ``None.get``.
                    gpu_mem_info = gpu_monitor.get_gpu_memory_info() or {}
                    free_vram_ratio = gpu_mem_info.get("free_ratio", 0)

                    if free_vram_ratio > settings.SESSION_VRAM_THRESHOLD:
                        current_pool_size = len(pool)
                        session_id = self._next_session_id(model_type)
                        logger.info(f"Attempting to create new session for {model_type.value}. Current size: {current_pool_size}, Max size: {max_sessions}")
                        try:
                            # NOTE: the model is loaded while the condition's
                            # lock is held, so releases for this model type
                            # block until loading finishes.
                            new_session = await self._create_session(model_type, session_id)
                        except Exception:
                            if not pool:
                                # Nothing is pooled, so no release can ever
                                # hand us a session: surface the failure
                                # instead of waiting.
                                logger.error(f"New session creation failed for {model_type.value} and no existing session is available to wait for.")
                                raise
                            # Fall through and wait for an existing session.
                            logger.error(f"New session creation failed for {model_type.value}. Will wait for an existing session.")
                        else:
                            new_session.in_use = True
                            new_session.mark_used()
                            pool.append(new_session)
                            self._log_pool_status("create", model_type.value)
                            logger.info(f"Acquired new session {new_session.session_id} as VRAM is sufficient ({free_vram_ratio:.2f} > {settings.SESSION_VRAM_THRESHOLD:.2f})")
                            return new_session
                    else:
                        logger.warning(f"Cannot create new session for {model_type.value}. VRAM threshold not met. (Free: {free_vram_ratio:.2f} <= Threshold: {settings.SESSION_VRAM_THRESHOLD:.2f})")

                logger.debug(f"No available sessions or VRAM for {model_type.value}, waiting...")
                try:
                    await asyncio.wait_for(condition.wait(), timeout=_ACQUIRE_RECHECK_SECONDS)
                except asyncio.TimeoutError:
                    # Not an error: just re-evaluate the pool and VRAM state.
                    pass

    async def _release_session(self, session: ModelSession):
        """Mark *session* as free and wake one waiter for its model type."""
        condition = self.conditions[session.model_type]
        async with condition:
            session.in_use = False
            logger.debug(f"Released session {session.session_id}")
            condition.notify()

    def get_status(self) -> dict:
        """Return min/max/total/in_use/available counts keyed by model value.

        Model types missing from ``model_configs`` are reported with a
        min and max of 0.
        """
        status_by_model = {}
        for model_type in ModelType:
            pool = self.pools[model_type]
            min_s, max_s = self.model_configs.get(model_type, (0, 0))
            total = len(pool)
            in_use = sum(1 for session in pool if session.in_use)
            available = total - in_use
            status_by_model[model_type.value] = {
                "min": min_s,
                "max": max_s,
                "total": total,
                "in_use": in_use,
                "available": available
            }
        return status_by_model

    async def _reap_idle_sessions_once(self) -> Dict[str, int]:
        """Run one reaping pass; return reaped-session counts per model value.

        A pool is never shrunk below its configured minimum (0 when the model
        type has no configuration entry).
        """
        reaped_counts: Dict[str, int] = defaultdict(int)

        for model_type, pool in self.pools.items():
            min_sessions, _ = self.model_configs.get(model_type, (0, 0))
            condition = self.conditions[model_type]

            async with condition:
                if len(pool) <= min_sessions:
                    continue

                idle_sessions = [s for s in pool if s.is_idle(settings.SESSION_IDLE_TIMEOUT)]
                num_to_reap = min(len(idle_sessions), len(pool) - min_sessions)
                if num_to_reap <= 0:
                    continue

                sessions_to_reap = idle_sessions[:num_to_reap]
                logger.info(f"Reaping {len(sessions_to_reap)} idle session(s) for {model_type.value}.")
                for session in sessions_to_reap:
                    pool.remove(session)
                    reaped_counts[session.model_type.value] += 1
                    # Drop the pool's reference to the loaded model.
                    del session.model

                # The pool now has room; let waiters retry.
                condition.notify_all()

        return reaped_counts

    async def _reap_idle_sessions_task(self):
        """Background loop that reaps idle sessions every 60 seconds.

        A failure in one pass is logged and the loop continues, so a single
        error does not permanently disable reaping. Task cancellation still
        propagates.
        """
        logger.info(f"Idle session reaper started. Timeout: {settings.SESSION_IDLE_TIMEOUT}s, Check Interval: 60s")
        while True:
            await asyncio.sleep(60)
            try:
                reaped_counts = await self._reap_idle_sessions_once()
            except Exception:
                logger.exception("Idle session reaper pass failed; will retry at the next interval")
                continue

            if reaped_counts:
                reaped_info = ", ".join([f"{count} {model}" for model, count in reaped_counts.items()])
                self._log_pool_status("reap", reaped_info=reaped_info)


# Global session pool instance
model_configs = {
    ModelType.SIMPLE_LAMA: (settings.SIMPLE_LAMA_MIN_SESSIONS, settings.SIMPLE_LAMA_MAX_SESSIONS),
    ModelType.MIGAN: (settings.MIGAN_MIN_SESSIONS, settings.MIGAN_MAX_SESSIONS),
    ModelType.REMBG: (settings.REMBG_MIN_SESSIONS, settings.REMBG_MAX_SESSIONS),
}
session_pool = SessionPool(model_configs=model_configs)