142 lines
5.6 KiB
Python
142 lines
5.6 KiB
Python
import asyncio
|
|
import time
|
|
from typing import List, Dict, Any, Tuple
|
|
from asyncio import Future, Queue, Task
|
|
import uuid
|
|
import logging
|
|
|
|
from .config import settings
|
|
from .worker_manager import worker_manager
|
|
from .session_pool import ModelType
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class BatchJob:
|
|
"""배치 처리를 위한 개별 작업 단위"""
|
|
def __init__(self, job_id: str, job_data: Dict[str, Any]):
|
|
self.job_id = job_id
|
|
self.job_data = job_data
|
|
self.future = asyncio.get_running_loop().create_future()
|
|
|
|
class BatchManager:
|
|
"""
|
|
마이크로 배치를 관리하는 클래스.
|
|
- 요청을 큐에 수집합니다.
|
|
- 백그라운드 태스크를 통해 큐를 감시하며 배치를 생성합니다.
|
|
- 생성된 배치를 WorkerManager에 전달하여 처리합니다.
|
|
- 처리 결과를 각 요청에 전달합니다.
|
|
"""
|
|
def __init__(self):
|
|
self._queue: Queue[BatchJob] = Queue()
|
|
self._batch_creation_task: Task | None = None
|
|
self._active = False
|
|
|
|
async def start(self):
|
|
"""배치 관리자를 시작합니다."""
|
|
if self._active:
|
|
logger.warning("BatchManager is already running.")
|
|
return
|
|
|
|
logger.info("Starting BatchManager...")
|
|
self._active = True
|
|
self._batch_creation_task = asyncio.create_task(self._batch_creation_loop())
|
|
logger.info("BatchManager started successfully.")
|
|
|
|
async def stop(self):
|
|
"""배치 관리자를 중지합니다."""
|
|
if not self._active:
|
|
logger.warning("BatchManager is not running.")
|
|
return
|
|
|
|
logger.info("Stopping BatchManager...")
|
|
self._active = False
|
|
if self._batch_creation_task:
|
|
self._batch_creation_task.cancel()
|
|
try:
|
|
await self._batch_creation_task
|
|
except asyncio.CancelledError:
|
|
pass # Task cancellation is expected
|
|
logger.info("BatchManager stopped.")
|
|
|
|
async def add_job(self, job_data: Dict[str, Any]) -> Any:
|
|
"""
|
|
API 엔드포인트에서 호출하는 메서드.
|
|
작업을 큐에 추가하고 결과가 나올 때까지 대기합니다.
|
|
"""
|
|
if not self._active:
|
|
raise RuntimeError("BatchManager is not running.")
|
|
|
|
job_id = str(uuid.uuid4())
|
|
job = BatchJob(job_id=job_id, job_data=job_data)
|
|
|
|
await self._queue.put(job)
|
|
logger.debug(f"Job {job.job_id} added to the batch queue. Queue size: {self._queue.qsize()}")
|
|
|
|
# 작업 결과를 기다립니다.
|
|
result = await job.future
|
|
return result
|
|
|
|
async def _batch_creation_loop(self):
|
|
"""
|
|
백그라운드에서 실행되며 큐를 감시하여 배치를 생성하는 루프.
|
|
"""
|
|
while self._active:
|
|
try:
|
|
# 첫 번째 작업을 기다립니다. 타임아웃이 발생하면 루프를 계속합니다.
|
|
first_job = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
|
except asyncio.TimeoutError:
|
|
continue
|
|
|
|
batch = [first_job]
|
|
batch_size = settings.MICRO_BATCH_SIZE
|
|
timeout = settings.MICRO_BATCH_TIMEOUT_MS / 1000.0 # 초 단위로 변환
|
|
|
|
# 타임아웃까지 또는 배치가 꽉 찰 때까지 작업을 추가로 수집합니다.
|
|
start_time = time.monotonic()
|
|
while len(batch) < batch_size and (time.monotonic() - start_time) < timeout:
|
|
try:
|
|
# 남은 시간만큼만 대기합니다.
|
|
remaining_time = timeout - (time.monotonic() - start_time)
|
|
if remaining_time <= 0:
|
|
break
|
|
job = await asyncio.wait_for(self._queue.get(), timeout=remaining_time)
|
|
batch.append(job)
|
|
except asyncio.TimeoutError:
|
|
break # 대기 시간 초과
|
|
|
|
logger.info(f"Creating a new batch with {len(batch)} jobs.")
|
|
# 배치를 처리할 별도의 태스크를 생성하여 루프가 다른 배치를 만드는 것을 막지 않도록 합니다.
|
|
asyncio.create_task(self._process_batch(batch))
|
|
|
|
async def _process_batch(self, batch: List[BatchJob]):
|
|
"""
|
|
생성된 배치를 WorkerManager에 전달하여 처리하고 결과를 전파합니다.
|
|
"""
|
|
batch_data = [job.job_data for job in batch]
|
|
|
|
try:
|
|
# WorkerManager에 배치 처리를 요청합니다.
|
|
# worker_manager의 process_inpaint는 이제 배치 데이터를 처리할 수 있어야 합니다.
|
|
results = await worker_manager.process_inpaint_batch(batch_data)
|
|
|
|
if len(results) != len(batch):
|
|
raise ValueError(f"Result count ({len(results)}) does not match batch size ({len(batch)}).")
|
|
|
|
# 결과를 각 작업의 Future에 설정합니다.
|
|
for job, result in zip(batch, results):
|
|
if isinstance(result, Exception):
|
|
job.future.set_exception(result)
|
|
else:
|
|
job.future.set_result(result)
|
|
logger.info(f"Successfully processed batch of {len(batch)} jobs.")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to process batch: {e}", exc_info=True)
|
|
# 모든 작업에 예외를 전파합니다.
|
|
for job in batch:
|
|
if not job.future.done():
|
|
job.future.set_exception(e)
|
|
|
|
# 전역 BatchManager 인스턴스
|
|
batch_manager = BatchManager()
|