import asyncio import time from typing import List, Dict, Any, Tuple from asyncio import Future, Queue, Task import uuid import logging from .config import settings from .worker_manager import worker_manager from .session_pool import ModelType logger = logging.getLogger(__name__) class BatchJob: """배치 처리를 위한 개별 작업 단위""" def __init__(self, job_id: str, job_data: Dict[str, Any]): self.job_id = job_id self.job_data = job_data self.future = asyncio.get_running_loop().create_future() class BatchManager: """ 마이크로 배치를 관리하는 클래스. - 요청을 큐에 수집합니다. - 백그라운드 태스크를 통해 큐를 감시하며 배치를 생성합니다. - 생성된 배치를 WorkerManager에 전달하여 처리합니다. - 처리 결과를 각 요청에 전달합니다. """ def __init__(self): self._queue: Queue[BatchJob] = Queue() self._batch_creation_task: Task | None = None self._active = False async def start(self): """배치 관리자를 시작합니다.""" if self._active: logger.warning("BatchManager is already running.") return logger.info("Starting BatchManager...") self._active = True self._batch_creation_task = asyncio.create_task(self._batch_creation_loop()) logger.info("BatchManager started successfully.") async def stop(self): """배치 관리자를 중지합니다.""" if not self._active: logger.warning("BatchManager is not running.") return logger.info("Stopping BatchManager...") self._active = False if self._batch_creation_task: self._batch_creation_task.cancel() try: await self._batch_creation_task except asyncio.CancelledError: pass # Task cancellation is expected logger.info("BatchManager stopped.") async def add_job(self, job_data: Dict[str, Any]) -> Any: """ API 엔드포인트에서 호출하는 메서드. 작업을 큐에 추가하고 결과가 나올 때까지 대기합니다. """ if not self._active: raise RuntimeError("BatchManager is not running.") job_id = str(uuid.uuid4()) job = BatchJob(job_id=job_id, job_data=job_data) await self._queue.put(job) logger.debug(f"Job {job.job_id} added to the batch queue. Queue size: {self._queue.qsize()}") # 작업 결과를 기다립니다. result = await job.future return result async def _batch_creation_loop(self): """ 백그라운드에서 실행되며 큐를 감시하여 배치를 생성하는 루프. """ while self._active: try: # 첫 번째 작업을 기다립니다. 타임아웃이 발생하면 루프를 계속합니다. first_job = await asyncio.wait_for(self._queue.get(), timeout=1.0) except asyncio.TimeoutError: continue batch = [first_job] batch_size = settings.MICRO_BATCH_SIZE timeout = settings.MICRO_BATCH_TIMEOUT_MS / 1000.0 # 초 단위로 변환 # 타임아웃까지 또는 배치가 꽉 찰 때까지 작업을 추가로 수집합니다. start_time = time.monotonic() while len(batch) < batch_size and (time.monotonic() - start_time) < timeout: try: # 남은 시간만큼만 대기합니다. remaining_time = timeout - (time.monotonic() - start_time) if remaining_time <= 0: break job = await asyncio.wait_for(self._queue.get(), timeout=remaining_time) batch.append(job) except asyncio.TimeoutError: break # 대기 시간 초과 logger.info(f"Creating a new batch with {len(batch)} jobs.") # 배치를 처리할 별도의 태스크를 생성하여 루프가 다른 배치를 만드는 것을 막지 않도록 합니다. asyncio.create_task(self._process_batch(batch)) async def _process_batch(self, batch: List[BatchJob]): """ 생성된 배치를 WorkerManager에 전달하여 처리하고 결과를 전파합니다. """ batch_data = [job.job_data for job in batch] try: # WorkerManager에 배치 처리를 요청합니다. # worker_manager의 process_inpaint는 이제 배치 데이터를 처리할 수 있어야 합니다. results = await worker_manager.process_inpaint_batch(batch_data) if len(results) != len(batch): raise ValueError(f"Result count ({len(results)}) does not match batch size ({len(batch)}).") # 결과를 각 작업의 Future에 설정합니다. for job, result in zip(batch, results): if isinstance(result, Exception): job.future.set_exception(result) else: job.future.set_result(result) logger.info(f"Successfully processed batch of {len(batch)} jobs.") except Exception as e: logger.error(f"Failed to process batch: {e}", exc_info=True) # 모든 작업에 예외를 전파합니다. for job in batch: if not job.future.done(): job.future.set_exception(e) # 전역 BatchManager 인스턴스 batch_manager = BatchManager()