# Listing metadata: 390 lines, 14 KiB, Python
"""
|
|
동적 워커 관리 시스템
|
|
VRAM 사용량에 따라 워커 수를 동적으로 조정합니다.
|
|
"""
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from typing import Dict, List, Optional, Callable, Any, Tuple
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
import numpy as np
|
|
|
|
from ..utils.gpu_monitor import gpu_monitor
|
|
from ..core.config import settings
|
|
from ..core.stats_manager import stats_manager
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class WorkerStatus(Enum):
    """Lifecycle states a worker can be in; values are lowercase labels."""

    # Ready to accept a task.
    IDLE = "idle"
    # Currently running a task.
    BUSY = "busy"
    # Declared for completeness; nothing in this module assigns it.
    STARTING = "starting"
    # Being shut down or removed from the pool.
    STOPPING = "stopping"
    # The last task raised an exception.
    ERROR = "error"
|
|
|
|
|
|
@dataclass
class Worker:
    """Bookkeeping record for one logical worker slot.

    The record only tracks state; it does not run anything itself.
    """

    worker_id: str  # unique identifier of this worker
    status: WorkerStatus  # current lifecycle state
    created_at: float  # creation timestamp
    last_task_at: Optional[float] = None  # time.time() of the most recent task start
    current_task: Optional[str] = None  # id of the task being run, if any
    task_count: int = 0  # number of tasks completed successfully
    error_count: int = 0  # number of tasks that failed

    def mark_task_start(self, task_id: str):
        """Record that ``task_id`` has started running on this worker."""
        self.last_task_at = time.time()
        self.current_task = task_id
        self.status = WorkerStatus.BUSY

    def mark_task_complete(self):
        """Record a successful completion and return the worker to IDLE."""
        self.task_count = self.task_count + 1
        self.current_task = None
        self.status = WorkerStatus.IDLE

    def mark_error(self):
        """Record a failed task and put the worker into the ERROR state."""
        self.error_count = self.error_count + 1
        self.status = WorkerStatus.ERROR
|
|
|
|
|
|
class WorkerManager:
    """Manage a pool of logical workers whose size follows VRAM usage.

    Workers are bookkeeping slots (``Worker`` records). Coroutine tasks run on
    the event loop; plain callables run in ``self.executor``. Tasks submitted
    while no worker is idle wait in ``self.task_queue`` until a worker frees up.
    """

    def __init__(self):
        """Create the bookkeeping structures; no worker exists until ``start()``."""
        self.workers: Dict[str, Worker] = {}
        self.task_queue: asyncio.Queue = asyncio.Queue()
        self.executor = ThreadPoolExecutor(max_workers=settings.MAX_WORKERS)
        self.running = False
        self.monitor_task: Optional[asyncio.Task] = None
        self.worker_tasks: Dict[str, asyncio.Task] = {}

        # Scaling settings.
        self.last_scale_time = time.time()
        self.scale_cooldown = 60  # one-minute cooldown between scaling actions

        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be collected mid-run.
        self._background_tasks: set = set()
        # True once stop() has shut the executor down, so start() can rebuild it.
        self._executor_closed = False

    async def start(self):
        """Start the manager: create the initial workers and the monitor task.

        Calling it while already running is a no-op.
        """
        if self.running:
            return

        if self._executor_closed:
            # A previous stop() shut the pool down; a shut-down executor
            # rejects new work, so build a fresh one for this run.
            self.executor = ThreadPoolExecutor(max_workers=settings.MAX_WORKERS)
            self._executor_closed = False

        self.running = True
        logger.info("Starting worker manager...")

        # Create the initial workers.
        await self._scale_workers(settings.MIN_WORKERS)

        # Start the monitoring task.
        self.monitor_task = asyncio.create_task(self._monitor_loop())

        logger.info(f"Worker manager started with {len(self.workers)} workers")

    async def stop(self):
        """Stop the manager: cancel monitoring, wait for busy workers, shut down.

        Tasks still waiting in the queue are failed with ``RuntimeError`` so
        their callers do not wait forever. Calling it while not running is a
        no-op.
        """
        if not self.running:
            return

        self.running = False
        logger.info("Stopping worker manager...")

        # Stop the monitoring task.
        if self.monitor_task:
            self.monitor_task.cancel()
            try:
                await self.monitor_task
            except asyncio.CancelledError:
                pass
            self.monitor_task = None

        # Stop all workers.
        await self._stop_all_workers()

        # Nothing will drain the queue any more; release the waiting callers.
        self._fail_pending_tasks()

        # Shut the thread pool down.
        self.executor.shutdown(wait=True)
        self._executor_closed = True

        logger.info("Worker manager stopped")

    async def submit_task(self, task_func: Callable, *args, **kwargs) -> Any:
        """Run ``task_func(*args, **kwargs)`` on a worker and return its result.

        If no worker is idle the task is queued and this call waits until a
        worker has executed it. Exceptions raised by the task propagate to the
        caller.
        """
        task_id = str(uuid.uuid4())

        # Look for an available worker.
        worker = await self._get_available_worker()
        if worker is None:
            logger.warning("No available workers, queuing task")
            # Enqueue and wait for the result.
            future = asyncio.get_running_loop().create_future()
            await self.task_queue.put((task_id, task_func, args, kwargs, future))
            return await future

        # Assign the task to the worker.
        return await self._execute_task(worker, task_id, task_func, *args, **kwargs)

    async def _get_available_worker(self) -> Optional[Worker]:
        """Return the first IDLE worker, or ``None`` when there is none."""
        for worker in self.workers.values():
            if worker.status == WorkerStatus.IDLE:
                return worker
        return None

    async def _execute_task(self, worker: Worker, task_id: str,
                            task_func: Callable, *args, **kwargs) -> Any:
        """Execute one task on ``worker`` and keep the worker's state in sync.

        Coroutine functions are awaited directly; other callables run in the
        thread pool. The worker ends up IDLE on success or cancellation and
        ERROR on failure; the original exception is re-raised.
        """
        worker.mark_task_start(task_id)
        logger.debug(f"Executing task {task_id} on worker {worker.worker_id}")

        try:
            if asyncio.iscoroutinefunction(task_func):
                result = await task_func(*args, **kwargs)
            else:
                # Synchronous functions run in the thread pool. run_in_executor
                # accepts positional arguments only, so the call is wrapped in
                # a closure to carry the keyword arguments.
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    self.executor, lambda: task_func(*args, **kwargs)
                )
        except asyncio.CancelledError:
            # The awaiting caller was cancelled. Release the slot so it is not
            # left BUSY forever. (A function already running in the thread
            # pool is not interrupted by this.)
            worker.status = WorkerStatus.IDLE
            worker.current_task = None
            self._schedule_queue_processing()
            raise
        except Exception as e:
            worker.mark_error()
            logger.error(f"Task {task_id} failed on worker {worker.worker_id}: {e}")

            # Let queued tasks proceed on whatever workers are idle.
            self._schedule_queue_processing()

            raise

        worker.mark_task_complete()
        logger.debug(f"Task {task_id} completed successfully")

        # Process tasks waiting in the queue, if any.
        self._schedule_queue_processing()

        return result

    def _schedule_queue_processing(self) -> None:
        """Spawn background ``_process_queue`` tasks for queued work.

        One task is spawned per (idle worker, queued item) pair. Does nothing
        when the manager is not running or the queue is empty.
        """
        if not self.running or self.task_queue.empty():
            return

        idle_count = sum(
            1 for w in self.workers.values() if w.status == WorkerStatus.IDLE
        )
        for _ in range(min(idle_count, self.task_queue.qsize())):
            task = asyncio.create_task(self._process_queue())
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

    async def _process_queue(self):
        """Take one queued task and run it on an idle worker.

        The outcome is delivered through the future stored with the queued
        item. Items whose future is already done (for example because the
        waiting caller was cancelled) are discarded.
        """
        if not self.running or self.task_queue.empty():
            return

        worker = await self._get_available_worker()
        if worker is None:
            return

        # Skip entries nobody is waiting for any more.
        while True:
            try:
                task_id, task_func, args, kwargs, future = self.task_queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            if not future.done():
                break

        try:
            result = await self._execute_task(worker, task_id, task_func, *args, **kwargs)
        except asyncio.CancelledError:
            if not future.done():
                future.cancel()
            raise
        except Exception as e:
            # The future may have been cancelled while the task was running.
            if not future.done():
                future.set_exception(e)
        else:
            if not future.done():
                future.set_result(result)

    def _fail_pending_tasks(self) -> None:
        """Fail every task still in the queue with ``RuntimeError``."""
        while True:
            try:
                _task_id, _task_func, _args, _kwargs, future = self.task_queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            if not future.done():
                future.set_exception(
                    RuntimeError("Worker manager stopped before the task could run")
                )

    async def _monitor_loop(self):
        """Periodically check scaling, replace errored workers, drain the queue.

        Runs until ``self.running`` becomes false or the task is cancelled.
        Other exceptions are logged and the loop retries after five seconds.
        """
        while self.running:
            try:
                await self._check_scaling()
                await self._cleanup_error_workers()
                # Workers may just have been added or replaced; without this,
                # queued tasks would wait until some other task finishes.
                self._schedule_queue_processing()
                await asyncio.sleep(settings.VRAM_CHECK_INTERVAL)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in monitor loop: {e}")
                await asyncio.sleep(5)

    async def _check_scaling(self):
        """Scale the pool up or down by one worker when conditions call for it.

        Nothing happens during the cooldown period. The cooldown is restarted
        and the change is logged only when the worker count actually changed.
        """
        current_time = time.time()
        if current_time - self.last_scale_time < self.scale_cooldown:
            return

        # Check VRAM usage. 'usage_percent' is divided by 100 here, so it is
        # presumably a 0-100 percentage -- confirm against gpu_monitor.
        gpu_info = gpu_monitor.get_gpu_memory_info()
        vram_usage = gpu_info['usage_percent'] / 100.0

        # Analyse the current worker states.
        total_workers = len(self.workers)
        idle_workers = sum(1 for w in self.workers.values() if w.status == WorkerStatus.IDLE)
        busy_workers = sum(1 for w in self.workers.values() if w.status == WorkerStatus.BUSY)
        queue_size = self.task_queue.qsize()

        logger.debug(f"Scaling check - VRAM: {vram_usage:.2f}, Workers: {total_workers}, "
                     f"Idle: {idle_workers}, Busy: {busy_workers}, Queue: {queue_size}")

        # Scale-up condition.
        should_scale_up = (
            vram_usage < settings.VRAM_THRESHOLD_LOW and
            (queue_size > 0 or idle_workers == 0) and
            total_workers < settings.MAX_WORKERS
        )

        # Scale-down condition.
        should_scale_down = (
            vram_usage > settings.VRAM_THRESHOLD_HIGH or
            (idle_workers > total_workers * 0.5 and total_workers > settings.MIN_WORKERS)
        )

        if should_scale_up:
            new_count = min(total_workers + 1, settings.MAX_WORKERS)
        elif should_scale_down:
            new_count = max(total_workers - 1, settings.MIN_WORKERS)
        else:
            return

        await self._scale_workers(new_count)

        # _scale_workers only removes idle workers and the target may equal the
        # current size, so report what really happened instead of the target.
        actual_count = len(self.workers)
        if actual_count == total_workers:
            return

        self.last_scale_time = current_time
        if actual_count > total_workers:
            logger.info(f"Scaled up to {actual_count} workers (VRAM: {vram_usage:.2f})")
        else:
            logger.info(f"Scaled down to {actual_count} workers (VRAM: {vram_usage:.2f})")

    def _create_worker(self) -> Worker:
        """Create a new IDLE worker with a random id, register and return it."""
        worker_id = f"worker_{uuid.uuid4().hex[:8]}"
        worker = Worker(
            worker_id=worker_id,
            status=WorkerStatus.IDLE,
            created_at=time.time(),
        )
        self.workers[worker_id] = worker
        return worker

    async def _scale_workers(self, target_count: int):
        """Move the pool size towards ``target_count``.

        Growing always succeeds. Shrinking removes idle workers only, so the
        pool may stay above the target when too few workers are idle.
        """
        current_count = len(self.workers)

        if target_count > current_count:
            # Add workers.
            for _ in range(target_count - current_count):
                worker = self._create_worker()
                logger.debug(f"Created worker {worker.worker_id}")

        elif target_count < current_count:
            # Remove workers (idle ones only).
            surplus = current_count - target_count
            idle = [w for w in self.workers.values() if w.status == WorkerStatus.IDLE]
            for worker in idle[:surplus]:
                worker.status = WorkerStatus.STOPPING
                del self.workers[worker.worker_id]
                logger.debug(f"Removed worker {worker.worker_id}")

    async def _cleanup_error_workers(self):
        """Replace every worker in the ERROR state with a fresh IDLE worker."""
        error_workers = [w for w in self.workers.values() if w.status == WorkerStatus.ERROR]

        for worker in error_workers:
            # Remove the errored worker, then create its replacement.
            del self.workers[worker.worker_id]
            logger.warning(f"Removed error worker {worker.worker_id}")

            replacement = self._create_worker()
            logger.info(f"Created replacement worker {replacement.worker_id}")

    async def _stop_all_workers(self):
        """Mark workers as STOPPING, wait up to 30 s for busy ones, then clear.

        Busy workers keep their BUSY status while their task runs so that the
        wait below can actually observe them; they are marked STOPPING once
        they are no longer busy.
        """
        max_wait = 30  # wait at most 30 seconds
        deadline = time.monotonic() + max_wait

        while True:
            still_busy = False
            for worker in self.workers.values():
                if worker.status == WorkerStatus.BUSY:
                    still_busy = True
                else:
                    worker.status = WorkerStatus.STOPPING

            if not still_busy or time.monotonic() >= deadline:
                break
            await asyncio.sleep(1)

        self.workers.clear()

    def get_status(self) -> dict:
        """Return a snapshot of the manager: running flag, sizes, and workers.

        Workers are grouped under ``workers_by_status`` by their status value.
        """
        workers_by_status = {status.value: [] for status in WorkerStatus}

        for worker in self.workers.values():
            workers_by_status[worker.status.value].append({
                "id": worker.worker_id,
                "status": worker.status.value,
                "task_count": worker.task_count,
                "error_count": worker.error_count,
                "last_task_at": worker.last_task_at,
            })

        return {
            "running": self.running,
            "total_workers": len(self.workers),
            "queue_size": self.task_queue.qsize(),
            "workers_by_status": workers_by_status,
        }

    async def process_inpaint(self, **kwargs) -> Optional[np.ndarray]:
        """Run an inpainting job through a pooled model session.

        Reads ``image`` and ``mask`` (required) and ``model_name`` (optional,
        default ``'simple-lama'``) from ``kwargs``.

        Returns:
            Whatever the model's ``inpaint`` coroutine returns, or ``None``
            when any exception occurs (the error is logged).
        """
        try:
            from ..core.session_pool import session_pool, ModelType

            model_name = kwargs.get('model_name', 'simple-lama')

            # Pick the session type from the model name.
            if model_name == 'migan':
                model_type = ModelType.MIGAN
            else:
                model_type = ModelType.SIMPLE_LAMA  # default

            # Borrow a model session from the pool and run the job.
            async with session_pool.get_session(model_type) as session:
                start_time = time.perf_counter()
                # The actual model object's method is reached via session.model.
                result = await session.model.inpaint(
                    image=kwargs['image'],
                    mask=kwargs['mask']
                )
                duration = time.perf_counter() - start_time
                stats_manager.record_time(model_name, duration)
                logger.info(f"'{model_name}' inpainting processed in {duration:.3f}s")

                return result

        except Exception as e:
            logger.error(f"인페인팅 처리 실패: {e}", exc_info=True)
            return None

    async def process_remove_bg(self, **kwargs) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Run a background-removal job through a pooled REMBG session.

        Reads ``image`` (required) and ``model_name`` (optional, default
        ``'birefnet-general-lite'``) from ``kwargs``.

        Returns:
            Whatever the model's ``remove_background`` coroutine returns
            (presumably a pair of arrays -- confirm against the model), or
            ``(None, None)`` when any exception occurs (the error is logged).
        """
        try:
            from ..core.session_pool import session_pool, ModelType

            model_name = kwargs.get('model_name', 'birefnet-general-lite')

            # Borrow a REMBG model session from the pool and run the job.
            async with session_pool.get_session(ModelType.REMBG) as session:
                start_time = time.perf_counter()
                # The actual model object's method is reached via session.model.
                result = await session.model.remove_background(
                    image=kwargs['image'],
                    model_name=model_name
                )
                duration = time.perf_counter() - start_time
                stats_manager.record_time('rembg', duration)
                logger.info(f"'rembg ({model_name})' processed in {duration:.3f}s")

                return result

        except Exception as e:
            logger.error(f"배경 제거 처리 실패: {e}", exc_info=True)
            return None, None
|
|
|
|
|
|
# Global worker manager instance
|
|
# NOTE: instantiated at import time, so WorkerManager.__init__ (which builds an
# asyncio.Queue and a ThreadPoolExecutor) runs as soon as this module is imported.
worker_manager = WorkerManager()
|