57 lines
2.0 KiB
Python
57 lines
2.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
import time, logging
|
|
from typing import Optional, Dict
|
|
import pynvml
|
|
|
|
class GPUMemTracker:
    """NVML-based GPU memory snapshot/logging utility.

    Wraps a single NVML device handle and reports its total/used/free memory
    in bytes. If NVML cannot be initialised or the device handle cannot be
    obtained, the tracker degrades to a no-op: ``snapshot()`` returns ``None``
    and ``log_snapshot()`` logs nothing.

    Can be used as a context manager so that ``close()`` is always called::

        with GPUMemTracker(logger) as mem:
            mem.log_snapshot("after-load")
    """

    def __init__(self, logger=None, device_index: int = 0):
        """Initialise NVML and acquire the handle for ``device_index``.

        Args:
            logger: Optional logger. A stdlib ``logging.Logger`` /
                ``logging.LoggerAdapter`` is called as ``log(level, msg)``;
                any other object with a ``log`` method is called as
                ``log(msg, level=level)``. Without a usable logger, messages
                are printed.
            device_index: NVML index of the GPU to track.
        """
        self.logger = logger
        self.device_index = device_index
        # True only once both nvmlInit and the handle lookup have succeeded.
        self.inited = False
        self.handle = None
        try:
            pynvml.nvmlInit()
        except Exception as e:
            self._log(f"[GPUMem] NVML init failed: {e}", level=logging.WARNING)
            return
        try:
            self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        except Exception as e:
            self._log(f"[GPUMem] NVML init failed: {e}", level=logging.WARNING)
            # nvmlInit already succeeded, but close() only shuts NVML down
            # when self.inited is True -- release the init reference here.
            self._nvml_shutdown()
            return
        self.inited = True
        self._log(f"[GPUMem] NVML init ok. device_index={device_index}")

    def _log(self, msg, level=logging.INFO):
        """Emit ``msg`` at ``level`` through the configured logger, or print it."""
        if self.logger and hasattr(self.logger, "log"):
            if isinstance(self.logger, (logging.Logger, logging.LoggerAdapter)):
                # stdlib signature is log(level, msg, ...); also passing
                # level= as a keyword would raise TypeError.
                self.logger.log(level, msg)
            else:
                # Custom (duck-typed) loggers use the log(msg, level=...) form.
                self.logger.log(msg, level=level)
        else:
            print(msg)

    def _nvml_shutdown(self) -> None:
        """Call ``nvmlShutdown``, logging (never raising) any failure."""
        try:
            pynvml.nvmlShutdown()
        except Exception as e:
            self._log(f"[GPUMem] NVML shutdown failed: {e}", level=logging.WARNING)

    @staticmethod
    def _fmt(bytes_val: int) -> str:
        """Format a byte count as mebibytes with one decimal, e.g. ``'1.0MB'``."""
        return f"{bytes_val/1024/1024:.1f}MB"

    def snapshot(self) -> Optional[Dict[str, int]]:
        """Return the device's memory state in bytes.

        Returns:
            ``{"total": ..., "used": ..., "free": ...}``, or ``None`` when NVML
            is not initialised or the memory query fails.
        """
        if not self.inited:
            return None
        try:
            mi = pynvml.nvmlDeviceGetMemoryInfo(self.handle)
        except pynvml.NVMLError as e:
            # Diagnostics are best-effort (as with init failures); a transient
            # NVML error should not crash the caller.
            self._log(f"[GPUMem] memory query failed: {e}", level=logging.WARNING)
            return None
        return {"total": mi.total, "used": mi.used, "free": mi.free}

    def prettify(self, snap: Optional[Dict[str, int]]) -> str:
        """Render a snapshot as ``'used=..MB free=..MB total=..MB'``, or ``'NA'``."""
        if not snap:
            return "NA"
        return f"used={self._fmt(snap['used'])} free={self._fmt(snap['free'])} total={self._fmt(snap['total'])}"

    def log_snapshot(self, tag: str = "", trace: str = ""):
        """Log the current snapshot as ``[GPUMem][trace][tag] used=...``.

        Empty ``trace``/``tag`` are omitted from the prefix. Nothing is logged
        when no snapshot is available.
        """
        s = self.snapshot()
        if not s:
            return
        prefix = "[GPUMem]"
        if trace:
            prefix += f"[{trace}]"
        if tag:
            prefix += f"[{tag}]"
        self._log(f"{prefix} {self.prettify(s)}")

    def diff_used(
        self,
        before: Optional[Dict[str, int]],
        after: Optional[Dict[str, int]],
    ) -> Optional[int]:
        """Return ``after['used'] - before['used']`` in bytes, or ``None`` if either is missing."""
        if not before or not after:
            return None
        return after["used"] - before["used"]

    def close(self):
        """Shut down NVML if this tracker initialised it; safe to call repeatedly."""
        if not self.inited:
            return
        self._nvml_shutdown()
        self.inited = False
        self.handle = None

    def __enter__(self) -> "GPUMemTracker":
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions from the with-body
|