inpaintServer/app/core/batch_manager.py

142 lines
5.6 KiB
Python

import asyncio
import time
from typing import List, Dict, Any, Tuple
from asyncio import Future, Queue, Task
import uuid
import logging
from .config import settings
from .worker_manager import worker_manager
from .session_pool import ModelType
logger = logging.getLogger(__name__)
class BatchJob:
"""배치 처리를 위한 개별 작업 단위"""
def __init__(self, job_id: str, job_data: Dict[str, Any]):
self.job_id = job_id
self.job_data = job_data
self.future = asyncio.get_running_loop().create_future()
class BatchManager:
"""
마이크로 배치를 관리하는 클래스.
- 요청을 큐에 수집합니다.
- 백그라운드 태스크를 통해 큐를 감시하며 배치를 생성합니다.
- 생성된 배치를 WorkerManager에 전달하여 처리합니다.
- 처리 결과를 각 요청에 전달합니다.
"""
def __init__(self):
self._queue: Queue[BatchJob] = Queue()
self._batch_creation_task: Task | None = None
self._active = False
async def start(self):
"""배치 관리자를 시작합니다."""
if self._active:
logger.warning("BatchManager is already running.")
return
logger.info("Starting BatchManager...")
self._active = True
self._batch_creation_task = asyncio.create_task(self._batch_creation_loop())
logger.info("BatchManager started successfully.")
async def stop(self):
"""배치 관리자를 중지합니다."""
if not self._active:
logger.warning("BatchManager is not running.")
return
logger.info("Stopping BatchManager...")
self._active = False
if self._batch_creation_task:
self._batch_creation_task.cancel()
try:
await self._batch_creation_task
except asyncio.CancelledError:
pass # Task cancellation is expected
logger.info("BatchManager stopped.")
async def add_job(self, job_data: Dict[str, Any]) -> Any:
"""
API 엔드포인트에서 호출하는 메서드.
작업을 큐에 추가하고 결과가 나올 때까지 대기합니다.
"""
if not self._active:
raise RuntimeError("BatchManager is not running.")
job_id = str(uuid.uuid4())
job = BatchJob(job_id=job_id, job_data=job_data)
await self._queue.put(job)
logger.debug(f"Job {job.job_id} added to the batch queue. Queue size: {self._queue.qsize()}")
# 작업 결과를 기다립니다.
result = await job.future
return result
async def _batch_creation_loop(self):
"""
백그라운드에서 실행되며 큐를 감시하여 배치를 생성하는 루프.
"""
while self._active:
try:
# 첫 번째 작업을 기다립니다. 타임아웃이 발생하면 루프를 계속합니다.
first_job = await asyncio.wait_for(self._queue.get(), timeout=1.0)
except asyncio.TimeoutError:
continue
batch = [first_job]
batch_size = settings.MICRO_BATCH_SIZE
timeout = settings.MICRO_BATCH_TIMEOUT_MS / 1000.0 # 초 단위로 변환
# 타임아웃까지 또는 배치가 꽉 찰 때까지 작업을 추가로 수집합니다.
start_time = time.monotonic()
while len(batch) < batch_size and (time.monotonic() - start_time) < timeout:
try:
# 남은 시간만큼만 대기합니다.
remaining_time = timeout - (time.monotonic() - start_time)
if remaining_time <= 0:
break
job = await asyncio.wait_for(self._queue.get(), timeout=remaining_time)
batch.append(job)
except asyncio.TimeoutError:
break # 대기 시간 초과
logger.info(f"Creating a new batch with {len(batch)} jobs.")
# 배치를 처리할 별도의 태스크를 생성하여 루프가 다른 배치를 만드는 것을 막지 않도록 합니다.
asyncio.create_task(self._process_batch(batch))
async def _process_batch(self, batch: List[BatchJob]):
"""
생성된 배치를 WorkerManager에 전달하여 처리하고 결과를 전파합니다.
"""
batch_data = [job.job_data for job in batch]
try:
# WorkerManager에 배치 처리를 요청합니다.
# worker_manager의 process_inpaint는 이제 배치 데이터를 처리할 수 있어야 합니다.
results = await worker_manager.process_inpaint_batch(batch_data)
if len(results) != len(batch):
raise ValueError(f"Result count ({len(results)}) does not match batch size ({len(batch)}).")
# 결과를 각 작업의 Future에 설정합니다.
for job, result in zip(batch, results):
if isinstance(result, Exception):
job.future.set_exception(result)
else:
job.future.set_result(result)
logger.info(f"Successfully processed batch of {len(batch)} jobs.")
except Exception as e:
logger.error(f"Failed to process batch: {e}", exc_info=True)
# 모든 작업에 예외를 전파합니다.
for job in batch:
if not job.future.done():
job.future.set_exception(e)
# 전역 BatchManager 인스턴스
batch_manager = BatchManager()