# -*- coding: utf-8 -*-
import logging
import time  # NOTE(review): unused in the visible code; kept in case external code relies on it.
from typing import Dict, Optional

import pynvml


class GPUMemTracker:
    """NVML-based GPU memory snapshot and logging utility.

    NVML is initialized on construction and a handle to one device is kept.
    If NVML initialization or the device lookup fails, the tracker stays
    usable but inert: ``snapshot()`` returns ``None`` and ``log_snapshot()``
    logs nothing.

    Can be used as a context manager; ``close()`` is called on exit.
    """

    def __init__(self, logger=None, device_index: int = 0):
        """Initialize NVML and look up the device handle.

        Args:
            logger: Optional message sink. A ``logging.Logger`` or
                ``logging.LoggerAdapter`` is called as ``log(level, msg)``.
                Any other truthy object with a ``log`` attribute is called
                as ``log(msg, level=level)``. Otherwise messages are printed.
            device_index: NVML index of the GPU to track.
        """
        self.logger = logger
        self.device_index = device_index
        # True only when nvmlInit succeeded AND the device handle was obtained.
        self.inited = False
        self.handle = None

        try:
            pynvml.nvmlInit()
        except Exception as e:
            self._log(f"[GPUMem] NVML init failed: {e}", level=logging.WARNING)
            return

        try:
            self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        except Exception as e:
            # nvmlInit succeeded, but close() only shuts NVML down when
            # `inited` is True. Undo the init here so every nvmlInit call
            # is paired with an nvmlShutdown call.
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass  # best effort; the lookup error below is what gets reported
            self._log(f"[GPUMem] NVML init failed: {e}", level=logging.WARNING)
            return

        self.inited = True
        self._log(f"[GPUMem] NVML init ok. device_index={device_index}")

    def _log(self, msg: str, level: int = logging.INFO) -> None:
        """Emit ``msg`` at ``level`` through the configured logger, or print it."""
        logger = self.logger
        if isinstance(logger, (logging.Logger, logging.LoggerAdapter)):
            # The stdlib signature is log(level, msg). Calling
            # log(msg, level=level) on it raises TypeError.
            logger.log(level, msg)
        elif logger and hasattr(logger, "log"):
            # Non-stdlib logger taking log(msg, level=...). This is presumably
            # a project-specific logger; confirm the signature against callers.
            logger.log(msg, level=level)
        else:
            print(msg)

    @staticmethod
    def _fmt(bytes_val: int) -> str:
        """Format a byte count as mebibytes with one decimal, e.g. ``'512.0MB'``."""
        return f"{bytes_val/1024/1024:.1f}MB"

    def snapshot(self) -> Optional[Dict[str, int]]:
        """Return the device's memory info in bytes, or ``None`` if unavailable.

        Returns:
            ``{"total": ..., "used": ..., "free": ...}`` as reported by
            ``nvmlDeviceGetMemoryInfo``. Returns ``None`` when NVML is not
            initialized or the query raises an ``NVMLError``; in the latter
            case a warning is logged.
        """
        if not self.inited:
            return None
        try:
            mi = pynvml.nvmlDeviceGetMemoryInfo(self.handle)
        except pynvml.NVMLError as e:
            # A monitoring failure should not take down the caller; report it
            # and fall back to the same None used when NVML is unavailable.
            self._log(f"[GPUMem] memory query failed: {e}", level=logging.WARNING)
            return None
        return {"total": mi.total, "used": mi.used, "free": mi.free}

    def prettify(self, snap: Optional[Dict[str, int]]) -> str:
        """Render a snapshot as ``'used=..MB free=..MB total=..MB'``, or ``'NA'`` if empty/None."""
        if not snap:
            return "NA"
        return f"used={self._fmt(snap['used'])} free={self._fmt(snap['free'])} total={self._fmt(snap['total'])}"

    def log_snapshot(self, tag: str = "", trace: str = "") -> None:
        """Log current memory usage, prefixed by optional ``[trace]`` and ``[tag]`` labels.

        Nothing is logged when no snapshot is available.
        """
        s = self.snapshot()
        if not s:
            return
        trace_part = f"[{trace}]" if trace else ""
        tag_part = f"[{tag}]" if tag else ""
        self._log(f"[GPUMem]{trace_part}{tag_part} {self.prettify(s)}")

    def diff_used(
        self,
        before: Optional[Dict[str, int]],
        after: Optional[Dict[str, int]],
    ) -> Optional[int]:
        """Return ``after['used'] - before['used']`` in bytes.

        Returns ``None`` if either snapshot is ``None`` or empty.
        """
        if not before or not after:
            return None
        return after["used"] - before["used"]

    def close(self) -> None:
        """Shut down NVML if this tracker initialized it; safe to call repeatedly."""
        if not self.inited:
            return
        try:
            pynvml.nvmlShutdown()
        except Exception as e:
            # Best effort: report instead of silently swallowing, never raise.
            self._log(f"[GPUMem] NVML shutdown failed: {e}", level=logging.WARNING)
        finally:
            self.inited = False
            self.handle = None

    def __enter__(self) -> "GPUMemTracker":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None lets any in-flight exception propagate.
        self.close()