inpaintServer/app/core/session_pool.py

274 lines
12 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Session pool management system.
Efficiently manages sessions for each model (simple-lama, migan, rembg).
Sessions are created dynamically with VRAM usage in mind, and idle sessions are removed automatically.
"""
import asyncio
import logging
import time
from typing import Dict, List, Optional, Any, Tuple
from enum import Enum
from dataclasses import dataclass
from contextlib import asynccontextmanager
from collections import defaultdict
from ..core.config import settings
from ..utils.gpu_monitor import gpu_monitor
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class ModelType(Enum):
    """Kinds of model a session can wrap.

    Each member's value is the lowercase identifier that this module
    interpolates into session ids and log messages.
    """

    # Value "simple_lama".
    SIMPLE_LAMA = "simple_lama"
    # Value "migan".
    MIGAN = "migan"
    # Value "rembg".
    REMBG = "rembg"
@dataclass
class ModelSession:
    """One loaded model instance together with its bookkeeping state.

    Attributes:
        session_id: Label used to identify the session in log messages.
        model_type: Which kind of model this session wraps.
        model: The loaded model object itself.
        created_at: ``time.time()`` timestamp taken when the session was built.
        last_used: ``time.time()`` timestamp of the most recent use.
        in_use: True while the session is handed out to a caller.
    """

    session_id: str
    model_type: ModelType
    model: Any
    created_at: float
    last_used: float
    in_use: bool = False

    def mark_used(self):
        """Stamp the session with the current wall-clock time."""
        self.last_used = time.time()

    def is_idle(self, timeout: int) -> bool:
        """Tell whether the session has sat unused for longer than *timeout*.

        Args:
            timeout: Idle limit in seconds. A value of zero or less disables
                idle detection, so the result is always False.

        Returns:
            True only when the session is not in use and more than *timeout*
            seconds have elapsed since ``last_used``.
        """
        # A busy session, or a disabled timeout, is never considered idle.
        if self.in_use or timeout <= 0:
            return False
        elapsed = time.time() - self.last_used
        return elapsed > timeout
class SessionPool:
    """Per-model pools of loaded model sessions.

    Each ``ModelType`` has its own list of sessions and its own
    ``asyncio.Condition``. Sessions are pre-loaded up to the configured
    minimum in :meth:`initialize`, created on demand up to the configured
    maximum in :meth:`_acquire_session` (only while the free-VRAM ratio
    exceeds ``settings.SESSION_VRAM_THRESHOLD``), and removed again by a
    background task once idle longer than ``settings.SESSION_IDLE_TIMEOUT``,
    never dropping below the configured minimum.

    Note: on-demand creation runs while the per-model condition lock is
    held, so model loads for one model type are serialized and releases for
    that type wait until the load finishes.
    """

    # Upper bound, in seconds, for a single ``model.load_model()`` call.
    _MODEL_LOAD_TIMEOUT = 180
    # Seconds between two sweeps of the idle-session reaper.
    _REAP_INTERVAL = 60
    # Seconds a waiter sleeps before re-checking whether it may create a
    # session. Only used while the pool is below its maximum, because in that
    # state nothing is guaranteed to notify the condition (for example, the
    # pool may be empty, or VRAM may be freed by another model type).
    _WAIT_RECHECK_SECONDS = 5.0

    def __init__(self, model_configs: Dict[ModelType, Tuple[int, int]]):
        """Create empty pools.

        Args:
            model_configs: Maps each model type to ``(min_sessions,
                max_sessions)``.
        """
        self.pools: Dict[ModelType, List[ModelSession]] = {mt: [] for mt in ModelType}
        self.model_configs = model_configs
        self.conditions: Dict[ModelType, asyncio.Condition] = {
            mt: asyncio.Condition() for mt in ModelType
        }
        # Next numeric suffix per model type. A counter (rather than the
        # current pool length) keeps ids unique after sessions are reaped.
        self._session_counters: Dict[ModelType, int] = defaultdict(int)
        self._initialized = False
        self._reaper_task: Optional[asyncio.Task] = None

    def _next_session_id(self, model_type: ModelType) -> str:
        """Return a fresh ``"<model value>_<n>"`` id; ``n`` never repeats per type."""
        index = self._session_counters[model_type]
        self._session_counters[model_type] = index + 1
        return f"{model_type.value}_{index}"

    async def initialize(self):
        """Pre-load the minimum number of sessions and start the reaper.

        Does nothing when called again after a completed initialization.
        Individual session failures are logged and skipped; they do not abort
        initialization. The reaper task is started only when
        ``settings.SESSION_IDLE_TIMEOUT`` is positive.
        """
        if self._initialized:
            return
        logger.info("Initializing dynamic session pools...")
        for model_type, (min_sessions, _) in self.model_configs.items():
            if min_sessions <= 0:
                continue
            logger.info(f"Pre-loading {min_sessions} sessions for {model_type.value}")
            session_ids = [self._next_session_id(model_type) for _ in range(min_sessions)]
            try:
                # return_exceptions=True: one failed load must not cancel the others.
                results = await asyncio.gather(
                    *(self._create_session(model_type, sid) for sid in session_ids),
                    return_exceptions=True,
                )
            except Exception as e:
                logger.error(f"Error during concurrent session creation for {model_type.value}: {e}", exc_info=True)
                continue
            for session_id, res in zip(session_ids, results):
                if isinstance(res, ModelSession):
                    self.pools[model_type].append(res)
                    # Logged after the append so the reported counts include it.
                    self._log_pool_status("create", model_type.value)
                else:
                    logger.error(f"Failed to create initial session {session_id}: {res}")
        self._initialized = True
        if settings.SESSION_IDLE_TIMEOUT > 0:
            self._reaper_task = asyncio.create_task(self._reap_idle_sessions_task())
        logger.info("Session pools initialized successfully")

    def _log_pool_status(self, event: str, model_type: str = "", reaped_info: str = ""):
        """Log per-model session counts plus GPU memory usage.

        Args:
            event: ``"create"`` or ``"reap"`` select a dedicated message; any
                other value is reported as a generic pool status.
            model_type: Model name shown for ``"create"`` events.
            reaped_info: Summary text shown for ``"reap"`` events.

        Never raises: any failure while gathering or formatting the status is
        downgraded to a warning.
        """
        try:
            gpu_info = gpu_monitor.get_gpu_memory_info()
            if not gpu_info or 'used' not in gpu_info:
                vram_usage = "VRAM: N/A"
            else:
                # NOTE(review): the /1024 with a "GB" label suggests the
                # monitor reports MiB - confirm against gpu_monitor.
                vram_usage = f"VRAM: {(gpu_info['used'] / 1024):.1f}/{(gpu_info['total'] / 1024):.1f} GB ({gpu_info['usage_percent']:.1f}%)"
            session_counts = ", ".join([f"{mt.value}: {len(p)}" for mt, p in self.pools.items()])
            if event == "create":
                log_message = f" Session Created ({model_type}). Status -> {session_counts} | {vram_usage}"
            elif event == "reap":
                log_message = f" Session Reaped ({reaped_info}). Status -> {session_counts} | {vram_usage}"
            else:
                log_message = f" Pool Status ({event}) -> {session_counts} | {vram_usage}"
            logger.info(log_message)
        except Exception as e:
            logger.warning(f"Failed to log memory status after session event '{event}': {e}")

    async def _create_session(self, model_type: ModelType, session_id: str) -> ModelSession:
        """Load a model and wrap it in a new, not-yet-pooled ``ModelSession``.

        The caller is responsible for adding the session to the pool and for
        logging the pool status afterwards.

        Raises:
            Exception: Whatever :meth:`_load_model` raised, after logging it.
        """
        logger.info(f"Creating new session {session_id} for {model_type.value}...")
        try:
            model = await self._load_model(model_type)
            session = ModelSession(
                session_id=session_id,
                model_type=model_type,
                model=model,
                created_at=time.time(),
                last_used=time.time()
            )
            logger.info(f"Successfully created session {session_id}")
            return session
        except Exception as e:
            logger.error(f"Failed to create session {session_id}: {e}", exc_info=True)
            raise

    async def _load_model(self, model_type: ModelType) -> Any:
        """Instantiate the wrapper for *model_type* and await its ``load_model()``.

        Raises:
            ValueError: If *model_type* is not one of the handled members.
            Exception: Any load failure, including the timeout raised when
                loading exceeds ``_MODEL_LOAD_TIMEOUT`` seconds.
        """
        device = "cuda" if settings.USE_CUDA else "cpu"
        # Model wrappers are imported lazily, only for the type being loaded.
        if model_type == ModelType.SIMPLE_LAMA:
            from ..models.simple_lama import SimpleLamaInpainter
            model = SimpleLamaInpainter(
                model_path=settings.SIMPLE_LAMA_MODEL_PATH,
                device=device,
                fp16=settings.USE_FP16
            )
        elif model_type == ModelType.MIGAN:
            from ..models.migan import MiganInpainter
            model = MiganInpainter(
                model_path=getattr(settings, 'MIGAN_ONNX_PATH', settings.MIGAN_MODEL_PATH),
                device=device,
                fp16=settings.USE_FP16,
                use_cuda=settings.USE_CUDA
            )
        elif model_type == ModelType.REMBG:
            from ..models.rembg_model import RembgProcessor
            model = RembgProcessor(
                model_name=getattr(settings, 'REMBG_MODEL_NAME', 'birefnet-general-lite'),
                device=device,
                fp16=settings.USE_FP16,
                local_rembg_model_path=getattr(settings, 'LOCAL_REMBG_MODEL_PATH', None)
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}")
        try:
            await asyncio.wait_for(model.load_model(), timeout=self._MODEL_LOAD_TIMEOUT)
            logger.debug(f"{model_type.value} model instance created and loaded.")
            return model
        except Exception as e:
            logger.error(f"Failed to load {model_type.value} model: {e}", exc_info=True)
            raise

    @asynccontextmanager
    async def get_session(self, model_type: ModelType):
        """Async context manager that lends out a session of *model_type*.

        The session is marked in-use for the duration of the ``async with``
        body and is always released afterwards, even if the body raises.
        """
        session = await self._acquire_session(model_type)
        try:
            yield session
        finally:
            await self._release_session(session)

    async def _acquire_session(self, model_type: ModelType) -> ModelSession:
        """Return a session marked in-use, waiting until one is available.

        Order of preference on every pass: reuse a free pooled session;
        otherwise, if the pool is below its maximum and the free-VRAM ratio
        is above ``settings.SESSION_VRAM_THRESHOLD``, create a new one;
        otherwise wait on the model's condition and try again.
        """
        condition = self.conditions[model_type]
        _, max_sessions = self.model_configs[model_type]
        pool = self.pools[model_type]
        while True:
            async with condition:
                for session in pool:
                    if not session.in_use:
                        session.in_use = True
                        session.mark_used()
                        logger.debug(f"Acquired existing session {session.session_id}")
                        return session

                can_grow = len(pool) < max_sessions
                if can_grow:
                    # NOTE(review): assumes the monitor returns a mapping with
                    # a "free_ratio" entry - confirm against gpu_monitor.
                    gpu_mem_info = gpu_monitor.get_gpu_memory_info()
                    free_vram_ratio = gpu_mem_info.get("free_ratio", 0)
                    if free_vram_ratio > settings.SESSION_VRAM_THRESHOLD:
                        current_pool_size = len(pool)
                        session_id = self._next_session_id(model_type)
                        logger.info(f"Attempting to create new session for {model_type.value}. Current size: {current_pool_size}, Max size: {max_sessions}")
                        try:
                            # The condition lock stays held during the load,
                            # so the pool cannot exceed max_sessions.
                            new_session = await self._create_session(model_type, session_id)
                        except Exception:
                            # Creation failed: fall through and wait, then retry.
                            logger.error(f"New session creation failed for {model_type.value}. Will wait for an existing session.")
                        else:
                            new_session.in_use = True
                            new_session.mark_used()
                            pool.append(new_session)
                            self._log_pool_status("create", model_type.value)
                            logger.info(f"Acquired new session {new_session.session_id} as VRAM is sufficient ({free_vram_ratio:.2f} > {settings.SESSION_VRAM_THRESHOLD:.2f})")
                            return new_session
                    else:
                        logger.warning(f"Cannot create new session for {model_type.value}. VRAM threshold not met. (Free: {free_vram_ratio:.2f} <= Threshold: {settings.SESSION_VRAM_THRESHOLD:.2f})")

                logger.debug(f"No available sessions or VRAM for {model_type.value}, waiting...")
                await self._wait_for_change(condition, recheck=can_grow)

    async def _wait_for_change(self, condition: asyncio.Condition, recheck: bool) -> None:
        """Wait on *condition*, whose lock the caller must already hold.

        Args:
            condition: The per-model condition to wait on.
            recheck: When True, give up after ``_WAIT_RECHECK_SECONDS`` so the
                caller can re-evaluate VRAM and retry creation; when False,
                wait until notified.
        """
        if not recheck:
            await condition.wait()
            return
        try:
            await asyncio.wait_for(condition.wait(), timeout=self._WAIT_RECHECK_SECONDS)
        except asyncio.TimeoutError:
            # Timing out is the expected way to trigger a periodic re-check.
            pass

    async def _release_session(self, session: ModelSession):
        """Mark *session* free and wake one waiter for its model type."""
        condition = self.conditions[session.model_type]
        async with condition:
            session.in_use = False
            logger.debug(f"Released session {session.session_id}")
            condition.notify()

    def get_status(self) -> dict:
        """Return per-model counters keyed by the model type's value.

        Each entry holds ``min``, ``max``, ``total``, ``in_use`` and
        ``available`` (``total - in_use``).
        """
        status_by_model = {}
        for model_type, pool in self.pools.items():
            min_s, max_s = self.model_configs[model_type]
            total = len(pool)
            in_use = sum(1 for session in pool if session.in_use)
            status_by_model[model_type.value] = {
                "min": min_s, "max": max_s, "total": total,
                "in_use": in_use, "available": total - in_use
            }
        return status_by_model

    async def _reap_idle_sessions_task(self):
        """Background loop that drops idle sessions above each pool's minimum.

        Sweeps every ``_REAP_INTERVAL`` seconds. For each model type it
        removes idle sessions, in pool order, without shrinking the pool
        below its configured minimum, and then notifies all waiters of that
        type. Runs until the task is cancelled.
        """
        logger.info(f"Idle session reaper started. Timeout: {settings.SESSION_IDLE_TIMEOUT}s, Check Interval: {self._REAP_INTERVAL}s")
        while True:
            await asyncio.sleep(self._REAP_INTERVAL)
            reaped_counts = defaultdict(int)
            for model_type, pool in self.pools.items():
                min_sessions, _ = self.model_configs[model_type]
                condition = self.conditions[model_type]
                async with condition:
                    surplus = len(pool) - min_sessions
                    if surplus <= 0:
                        continue
                    idle_sessions = [s for s in pool if s.is_idle(settings.SESSION_IDLE_TIMEOUT)]
                    sessions_to_reap = idle_sessions[:surplus]
                    if not sessions_to_reap:
                        continue
                    logger.info(f"Reaping {len(sessions_to_reap)} idle session(s) for {model_type.value}.")
                    for session in sessions_to_reap:
                        pool.remove(session)
                        reaped_counts[session.model_type.value] += 1
                        # Drop the session's reference to the loaded model.
                        del session.model
                    condition.notify_all()
            if reaped_counts:
                reaped_info = ", ".join([f"{count} {model}" for model, count in reaped_counts.items()])
                self._log_pool_status("reap", reaped_info=reaped_info)
# Per-model (min_sessions, max_sessions) bounds, read from application settings.
model_configs: Dict[ModelType, Tuple[int, int]] = {
    ModelType.SIMPLE_LAMA: (
        settings.SIMPLE_LAMA_MIN_SESSIONS,
        settings.SIMPLE_LAMA_MAX_SESSIONS,
    ),
    ModelType.MIGAN: (
        settings.MIGAN_MIN_SESSIONS,
        settings.MIGAN_MAX_SESSIONS,
    ),
    ModelType.REMBG: (
        settings.REMBG_MIN_SESSIONS,
        settings.REMBG_MAX_SESSIONS,
    ),
}

# Global session pool instance, constructed once when this module is imported.
session_pool = SessionPool(model_configs=model_configs)