inpaintServer/app/core/worker_manager.py

390 lines
14 KiB
Python

"""
Dynamic worker management system.
Adjusts the number of workers dynamically based on VRAM usage.
"""
import asyncio
import logging
import time
import uuid
from typing import Dict, List, Optional, Callable, Any, Tuple
from dataclasses import dataclass
from enum import Enum
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from ..utils.gpu_monitor import gpu_monitor
from ..core.config import settings
from ..core.stats_manager import stats_manager
logger = logging.getLogger(__name__)
class WorkerStatus(Enum):
    """Lifecycle states a worker can be in.

    The string values are what ``WorkerManager.get_status`` reports and uses
    as grouping keys.
    """

    # Ready to accept a task.
    IDLE = "idle"
    # Currently executing a task.
    BUSY = "busy"
    # Declared for completeness; not assigned anywhere in this module.
    STARTING = "starting"
    # Being removed from the pool.
    STOPPING = "stopping"
    # The last task raised; the monitor loop replaces workers in this state.
    ERROR = "error"
@dataclass
class Worker:
    """Bookkeeping record for a single worker slot.

    The record only tracks state and counters; it does not run anything
    itself. State transitions are performed through the ``mark_*`` helpers.
    """

    # Unique identifier of this worker.
    worker_id: str
    # Current lifecycle state.
    status: WorkerStatus
    # ``time.time()`` timestamp supplied by the creator.
    created_at: float
    # ``time.time()`` timestamp of the most recent task start, if any.
    last_task_at: Optional[float] = None
    # Id of the task being executed; reset to None on successful completion.
    current_task: Optional[str] = None
    # Number of tasks completed successfully.
    task_count: int = 0
    # Number of tasks that ended in an error.
    error_count: int = 0

    def mark_task_start(self, task_id: str):
        """Record that ``task_id`` has started running on this worker."""
        self.last_task_at = time.time()
        self.current_task = task_id
        self.status = WorkerStatus.BUSY

    def mark_task_complete(self):
        """Record a successful completion and return the worker to IDLE."""
        self.task_count = self.task_count + 1
        self.current_task = None
        self.status = WorkerStatus.IDLE

    def mark_error(self):
        """Record a failed task and move the worker to ERROR.

        ``current_task`` is intentionally left untouched here.
        """
        self.error_count = self.error_count + 1
        self.status = WorkerStatus.ERROR
class WorkerManager:
    """Manage a pool of worker records and scale it with VRAM usage.

    Workers are bookkeeping records (``Worker``), not threads or processes:
    coroutine tasks are awaited directly on the event loop, while synchronous
    tasks run on a shared ``ThreadPoolExecutor``. A task submitted while no
    worker is IDLE waits in ``task_queue`` and is dispatched as soon as a
    worker becomes available (task completion, scale-up, or replacement of an
    errored worker).
    """

    def __init__(self):
        """Create a stopped manager with no workers; call ``start()`` to run it."""
        self.workers: Dict[str, "Worker"] = {}
        # Entries are (task_id, task_func, args, kwargs, future) tuples.
        # NOTE(review): the module-level instance builds this queue at import
        # time; on Python < 3.10 asyncio queues bind to the event loop that is
        # current at construction — confirm the deployment Python version.
        self.task_queue: asyncio.Queue = asyncio.Queue()
        self.executor = ThreadPoolExecutor(max_workers=settings.MAX_WORKERS)
        self.running = False
        self.monitor_task: Optional[asyncio.Task] = None
        self.worker_tasks: Dict[str, asyncio.Task] = {}
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._background_tasks = set()
        # Scaling settings
        self.last_scale_time = time.time()
        self.scale_cooldown = 60  # seconds between two scaling actions

    async def start(self):
        """Start the manager: create the initial workers and the monitor loop.

        Does nothing if the manager is already running.
        """
        if self.running:
            return
        self.running = True
        logger.info("Starting worker manager...")
        # Create the initial workers.
        await self._scale_workers(settings.MIN_WORKERS)
        # Start the monitoring task.
        self.monitor_task = asyncio.create_task(self._monitor_loop())
        logger.info(f"Worker manager started with {len(self.workers)} workers")

    async def stop(self):
        """Stop the manager: cancel monitoring, stop workers, shut the pool down.

        Does nothing if the manager is not running.
        """
        if not self.running:
            return
        self.running = False
        logger.info("Stopping worker manager...")
        # Stop the monitoring task.
        if self.monitor_task:
            self.monitor_task.cancel()
            try:
                await self.monitor_task
            except asyncio.CancelledError:
                pass
            self.monitor_task = None
        # Stop all workers.
        await self._stop_all_workers()
        # Shut down the thread pool (blocks until running threads finish).
        self.executor.shutdown(wait=True)
        logger.info("Worker manager stopped")

    async def submit_task(self, task_func: Callable, *args, **kwargs) -> Any:
        """Run ``task_func(*args, **kwargs)`` on a worker and return its result.

        If no worker is IDLE the task is queued and this coroutine waits until
        a worker picks it up and finishes it.

        Args:
            task_func: Coroutine function or plain callable to execute.
            *args: Positional arguments forwarded to ``task_func``.
            **kwargs: Keyword arguments forwarded to ``task_func``.

        Returns:
            Whatever ``task_func`` returns.

        Raises:
            Exception: Any exception raised by ``task_func`` is propagated.
            RuntimeError: If the task was still queued when the manager stopped.
        """
        task_id = str(uuid.uuid4())
        # Look for an available worker.
        worker = await self._get_available_worker()
        if worker is None:
            logger.warning("No available workers, queuing task")
            # Enqueue and wait for a worker to resolve the future.
            future = asyncio.get_running_loop().create_future()
            await self.task_queue.put((task_id, task_func, args, kwargs, future))
            return await future
        # Assign the task to the worker.
        return await self._execute_task(worker, task_id, task_func, *args, **kwargs)

    async def _get_available_worker(self) -> Optional["Worker"]:
        """Return the first IDLE worker, or None if every worker is occupied."""
        return next(
            (w for w in self.workers.values() if w.status == WorkerStatus.IDLE),
            None,
        )

    async def _execute_task(self, worker: "Worker", task_id: str,
                            task_func: Callable, *args, **kwargs) -> Any:
        """Execute one task on ``worker`` and keep the worker's state in sync.

        Coroutine functions are awaited directly; anything else runs in the
        thread pool. After the task ends (success, failure or cancellation)
        queued tasks are given a chance to run.

        Raises:
            Exception: Re-raises whatever ``task_func`` raised, after marking
                the worker as errored.
            asyncio.CancelledError: Re-raised after releasing the worker.
        """
        worker.mark_task_start(task_id)
        logger.debug(f"Executing task {task_id} on worker {worker.worker_id}")
        try:
            if asyncio.iscoroutinefunction(task_func):
                result = await task_func(*args, **kwargs)
            else:
                # run_in_executor() forwards positional arguments only, so the
                # keyword arguments are bound in a closure instead.
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    self.executor, lambda: task_func(*args, **kwargs)
                )
        except asyncio.CancelledError:
            # CancelledError is not an Exception subclass on Python 3.8+, so it
            # needs its own handler; otherwise the worker would stay BUSY
            # forever. A synchronous call already running in the thread pool
            # is not interrupted by this.
            worker.status = WorkerStatus.IDLE
            worker.current_task = None
            self._schedule_queue_processing()
            raise
        except Exception as e:
            worker.mark_error()
            logger.error(f"Task {task_id} failed on worker {worker.worker_id}: {e}")
            # Let queued tasks run on any other idle worker.
            self._schedule_queue_processing()
            raise
        worker.mark_task_complete()
        logger.debug(f"Task {task_id} completed successfully")
        # Process queued tasks, if any.
        self._schedule_queue_processing()
        return result

    def _schedule_queue_processing(self):
        """Spawn background dispatchers for queued tasks that can run now.

        One ``_process_queue`` task is started per (queued task, idle worker)
        pair. Must be called while the event loop is running.
        """
        pending = self.task_queue.qsize()
        if not pending:
            return
        idle = sum(1 for w in self.workers.values() if w.status == WorkerStatus.IDLE)
        for _ in range(min(pending, idle)):
            task = asyncio.create_task(self._process_queue())
            # Hold a reference until the task is done (see __init__).
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

    def _pop_pending_entry(self):
        """Pop the next queue entry whose future is still waiting, else None.

        Entries whose caller already gave up (future cancelled or otherwise
        done) are discarded.
        """
        while True:
            try:
                entry = self.task_queue.get_nowait()
            except asyncio.QueueEmpty:
                return None
            if not entry[4].done():
                return entry

    async def _process_queue(self):
        """Run one queued task on an idle worker and resolve its future."""
        if self.task_queue.empty():
            return
        worker = await self._get_available_worker()
        if worker is None:
            return
        entry = self._pop_pending_entry()
        if entry is None:
            return
        task_id, task_func, args, kwargs, future = entry
        try:
            result = await self._execute_task(worker, task_id, task_func, *args, **kwargs)
        except asyncio.CancelledError:
            # This dispatcher was cancelled: propagate that to the waiter.
            future.cancel()
            raise
        except Exception as e:
            # The waiter may have been cancelled while the task was running.
            if not future.done():
                future.set_exception(e)
        else:
            if not future.done():
                future.set_result(result)

    async def _monitor_loop(self):
        """Periodically check scaling and replace errored workers.

        Runs until ``running`` becomes False or the task is cancelled. Errors
        inside one iteration are logged and followed by a 5 second back-off.
        """
        while self.running:
            try:
                await self._check_scaling()
                await self._cleanup_error_workers()
                await asyncio.sleep(settings.VRAM_CHECK_INTERVAL)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in monitor loop: {e}")
                await asyncio.sleep(5)

    async def _check_scaling(self):
        """Scale the pool up or down by one worker when conditions call for it.

        Scale up when VRAM usage is below ``VRAM_THRESHOLD_LOW``, there is
        demand (queued tasks or no idle worker) and the pool is below
        ``MAX_WORKERS``. Scale down when VRAM usage exceeds
        ``VRAM_THRESHOLD_HIGH``, or when more than half of the workers are idle
        and the pool is above ``MIN_WORKERS``. The cooldown is only restarted
        when the number of workers actually changed.
        """
        current_time = time.time()
        if current_time - self.last_scale_time < self.scale_cooldown:
            return
        # Check VRAM usage (reported as a percentage; converted to 0..1).
        gpu_info = gpu_monitor.get_gpu_memory_info()
        vram_usage = gpu_info['usage_percent'] / 100.0
        # Analyse the current worker states.
        total_workers = len(self.workers)
        idle_workers = sum(1 for w in self.workers.values() if w.status == WorkerStatus.IDLE)
        busy_workers = sum(1 for w in self.workers.values() if w.status == WorkerStatus.BUSY)
        queue_size = self.task_queue.qsize()
        logger.debug(f"Scaling check - VRAM: {vram_usage:.2f}, Workers: {total_workers}, "
                     f"Idle: {idle_workers}, Busy: {busy_workers}, Queue: {queue_size}")
        # Scale-up condition.
        should_scale_up = (
            vram_usage < settings.VRAM_THRESHOLD_LOW and
            (queue_size > 0 or idle_workers == 0) and
            total_workers < settings.MAX_WORKERS
        )
        # Scale-down condition.
        should_scale_down = (
            vram_usage > settings.VRAM_THRESHOLD_HIGH or
            (idle_workers > total_workers * 0.5 and total_workers > settings.MIN_WORKERS)
        )
        if should_scale_up:
            target_count = min(total_workers + 1, settings.MAX_WORKERS)
        elif should_scale_down:
            target_count = max(total_workers - 1, settings.MIN_WORKERS)
        else:
            return
        await self._scale_workers(target_count)
        actual_count = len(self.workers)
        if actual_count == total_workers:
            # Nothing changed (already at the limit, or no idle worker could be
            # removed): do not report a scaling action or restart the cooldown.
            return
        self.last_scale_time = current_time
        direction = "up" if actual_count > total_workers else "down"
        logger.info(f"Scaled {direction} to {actual_count} workers (VRAM: {vram_usage:.2f})")

    def _create_worker(self) -> "Worker":
        """Create a new IDLE worker, register it in ``workers`` and return it."""
        worker_id = f"worker_{uuid.uuid4().hex[:8]}"
        worker = Worker(
            worker_id=worker_id,
            status=WorkerStatus.IDLE,
            created_at=time.time()
        )
        self.workers[worker_id] = worker
        return worker

    async def _scale_workers(self, target_count: int):
        """Move the pool size towards ``target_count``.

        Growing always succeeds. Shrinking only removes IDLE workers, so the
        pool may stay above ``target_count`` while workers are occupied.
        """
        current_count = len(self.workers)
        if target_count > current_count:
            # Add workers.
            for _ in range(target_count - current_count):
                worker = self._create_worker()
                logger.debug(f"Created worker {worker.worker_id}")
            # New capacity: let queued tasks start right away.
            self._schedule_queue_processing()
        elif target_count < current_count:
            # Remove workers (idle ones only).
            excess = current_count - target_count
            idle = [w for w in self.workers.values() if w.status == WorkerStatus.IDLE]
            for worker in idle[:excess]:
                worker.status = WorkerStatus.STOPPING
                del self.workers[worker.worker_id]
                logger.debug(f"Removed worker {worker.worker_id}")

    async def _cleanup_error_workers(self):
        """Replace every worker in ERROR state with a fresh IDLE worker."""
        error_workers = [w for w in self.workers.values() if w.status == WorkerStatus.ERROR]
        for worker in error_workers:
            # Remove the errored worker, then create its replacement.
            del self.workers[worker.worker_id]
            logger.warning(f"Removed error worker {worker.worker_id}")
            new_worker = self._create_worker()
            logger.info(f"Created replacement worker {new_worker.worker_id}")
        if error_workers:
            # Replacement workers are idle: let queued tasks use them.
            self._schedule_queue_processing()

    def _fail_queued_tasks(self):
        """Fail every still-waiting queued task so its caller stops waiting."""
        while True:
            entry = self._pop_pending_entry()
            if entry is None:
                return
            entry[4].set_exception(RuntimeError("Worker manager stopped"))

    async def _stop_all_workers(self):
        """Stop all workers, giving running tasks up to 30 seconds to finish.

        Queued tasks are failed first so that no new work starts during the
        grace period. Workers that are BUSY keep that status until their task
        ends; this is what the wait loop observes.
        """
        self._fail_queued_tasks()
        for worker in self.workers.values():
            # Leave BUSY workers untouched, otherwise the wait below could
            # never see them.
            if worker.status != WorkerStatus.BUSY:
                worker.status = WorkerStatus.STOPPING
        # Wait until running tasks have completed.
        max_wait = 30  # maximum wait in seconds
        wait_start = time.time()
        while time.time() - wait_start < max_wait:
            if not any(w.status == WorkerStatus.BUSY for w in self.workers.values()):
                break
            await asyncio.sleep(1)
        self.workers.clear()

    def get_status(self) -> dict:
        """Return a snapshot of the manager state.

        Returns:
            A dict with ``running``, ``total_workers``, ``queue_size`` and
            ``workers_by_status`` (status value -> list of per-worker dicts
            with ``id``, ``status``, ``task_count``, ``error_count`` and
            ``last_task_at``).
        """
        workers_by_status = {
            "idle": [],
            "busy": [],
            "starting": [],
            "stopping": [],
            "error": []
        }
        for worker in self.workers.values():
            workers_by_status[worker.status.value].append({
                "id": worker.worker_id,
                "status": worker.status.value,
                "task_count": worker.task_count,
                "error_count": worker.error_count,
                "last_task_at": worker.last_task_at
            })
        return {
            "running": self.running,
            "total_workers": len(self.workers),
            "queue_size": self.task_queue.qsize(),
            "workers_by_status": workers_by_status
        }

    async def process_inpaint(self, **kwargs) -> Optional[np.ndarray]:
        """Run an inpainting job on a pooled model session.

        Args:
            **kwargs: Must contain ``image`` and ``mask``; ``model_name`` is
                optional and defaults to ``'simple-lama'`` (``'migan'`` selects
                the MIGAN session type, anything else SIMPLE_LAMA).

        Returns:
            The value returned by ``session.model.inpaint`` (annotated as an
            ndarray), or None if anything fails; failures are logged with a
            traceback instead of being raised.
        """
        try:
            from ..core.session_pool import session_pool, ModelType
            model_name = kwargs.get('model_name', 'simple-lama')
            # Pick the session type from the model name (SIMPLE_LAMA is the default).
            model_type = ModelType.MIGAN if model_name == 'migan' else ModelType.SIMPLE_LAMA
            # Borrow a model session from the pool and process with it.
            async with session_pool.get_session(model_type) as session:
                start_time = time.time()
                # The method must be called on the actual model object, session.model.
                result = await session.model.inpaint(
                    image=kwargs['image'],
                    mask=kwargs['mask']
                )
                duration = time.time() - start_time
                stats_manager.record_time(model_name, duration)
                logger.info(f"'{model_name}' inpainting processed in {duration:.3f}s")
                return result
        except Exception as e:
            logger.error(f"인페인팅 처리 실패: {e}", exc_info=True)
            return None

    async def process_remove_bg(self, **kwargs) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Run a background-removal job on a pooled REMBG session.

        Args:
            **kwargs: Must contain ``image``; ``model_name`` is optional and
                defaults to ``'birefnet-general-lite'``.

        Returns:
            The value returned by ``session.model.remove_background``
            (presumably a pair of arrays — confirm against the model class), or
            ``(None, None)`` if anything fails; failures are logged with a
            traceback instead of being raised.
        """
        try:
            from ..core.session_pool import session_pool, ModelType
            model_name = kwargs.get('model_name', 'birefnet-general-lite')
            # Borrow a REMBG model session from the pool and process with it.
            async with session_pool.get_session(ModelType.REMBG) as session:
                start_time = time.time()
                # The method must be called on the actual model object, session.model.
                result = await session.model.remove_background(
                    image=kwargs['image'],
                    model_name=model_name
                )
                duration = time.time() - start_time
                stats_manager.record_time('rembg', duration)
                logger.info(f"'rembg ({model_name})' processed in {duration:.3f}s")
                return result
        except Exception as e:
            logger.error(f"배경 제거 처리 실패: {e}", exc_info=True)
            return None, None
# Global worker manager instance (constructed when this module is imported).
worker_manager = WorkerManager()