ImageProcessor_MainServer/worker/gpu_mem_monitor.py

57 lines
2.0 KiB
Python

# -*- coding: utf-8 -*-
import time, logging
from typing import Optional, Dict
import pynvml
class GPUMemTracker:
    """NVML-based GPU memory snapshot / logging utility.

    Wraps one NVML device handle and provides helpers to take a memory
    snapshot (``total`` / ``used`` / ``free`` as reported by
    ``nvmlDeviceGetMemoryInfo``), format it in MB, log it, and diff the
    ``used`` value between two snapshots.

    Initialisation is best-effort: if NVML cannot be initialised or the
    device handle cannot be obtained, a warning is logged, ``inited`` stays
    ``False`` and ``snapshot()`` returns ``None``.

    Also usable as a context manager; ``close()`` is called on exit.
    """

    def __init__(self, logger=None, device_index: int = 0):
        """Initialise NVML and acquire the handle for ``device_index``.

        Args:
            logger: Optional logger. Either a stdlib ``logging.Logger`` /
                ``logging.LoggerAdapter``, or any object exposing
                ``log(msg, level=...)``. Without one, messages are printed.
            device_index: NVML index of the GPU to monitor.
        """
        self.logger = logger
        self.device_index = device_index
        self.inited = False
        self.handle = None
        try:
            pynvml.nvmlInit()
        except Exception as e:
            self._log(f"[GPUMem] NVML init failed: {e}", level=logging.WARNING)
            return
        try:
            self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        except Exception as e:
            # nvmlInit() succeeded, so it must be balanced by nvmlShutdown();
            # close() would skip the shutdown because inited stays False,
            # leaking NVML's init reference count.
            self._shutdown_quietly()
            self._log(f"[GPUMem] NVML init failed: {e}", level=logging.WARNING)
            return
        self.inited = True
        self._log(f"[GPUMem] NVML init ok. device_index={device_index}")

    def __enter__(self):
        """Return self so the tracker can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Release NVML on leaving the ``with`` block; never suppress errors."""
        self.close()
        return False

    def _log(self, msg, level=logging.INFO):
        """Emit ``msg`` at ``level`` via the configured logger, or print it."""
        logger = self.logger
        if isinstance(logger, (logging.Logger, logging.LoggerAdapter)):
            # The stdlib signature is log(level, msg); calling it as
            # log(msg, level=level) raises TypeError.
            logger.log(level, msg)
        elif logger and hasattr(logger, "log"):
            # Project-style logger exposing log(msg, level=...).
            logger.log(msg, level=level)
        else:
            print(msg)

    def _shutdown_quietly(self):
        """Call ``nvmlShutdown()``, logging (not raising) any failure."""
        try:
            pynvml.nvmlShutdown()
        except Exception as e:
            self._log(f"[GPUMem] NVML shutdown failed: {e}", level=logging.WARNING)

    @staticmethod
    def _fmt(bytes_val: int) -> str:
        """Format a byte count as mebibytes with one decimal, e.g. ``1.0MB``."""
        return f"{bytes_val/1024/1024:.1f}MB"

    def snapshot(self) -> Optional[Dict[str, int]]:
        """Return ``{"total", "used", "free"}`` (bytes) for the device.

        Returns ``None`` when NVML is not initialised or the query fails
        (the failure is logged as a warning instead of propagating).
        """
        if not self.inited:
            return None
        try:
            mi = pynvml.nvmlDeviceGetMemoryInfo(self.handle)
        except pynvml.NVMLError as e:
            self._log(f"[GPUMem] memory query failed: {e}", level=logging.WARNING)
            return None
        return {"total": mi.total, "used": mi.used, "free": mi.free}

    def prettify(self, snap: Optional[Dict[str, int]]) -> str:
        """Render a snapshot as ``used=..MB free=..MB total=..MB`` (``NA`` if empty)."""
        if not snap:
            return "NA"
        return f"used={self._fmt(snap['used'])} free={self._fmt(snap['free'])} total={self._fmt(snap['total'])}"

    def log_snapshot(self, tag: str = "", trace: str = ""):
        """Take a snapshot and log it, prefixed by optional ``[trace]`` and ``[tag]``.

        Logs nothing when no snapshot is available.
        """
        s = self.snapshot()
        if not s:
            return
        trace_part = f"[{trace}]" if trace else ""
        tag_part = f"[{tag}]" if tag else ""
        self._log(f"[GPUMem]{trace_part}{tag_part} {self.prettify(s)}")

    def diff_used(
        self,
        before: Optional[Dict[str, int]],
        after: Optional[Dict[str, int]],
    ) -> Optional[int]:
        """Return ``after['used'] - before['used']`` in bytes, or ``None`` if either is missing."""
        if not before or not after:
            return None
        return after["used"] - before["used"]

    def close(self):
        """Shut NVML down and drop the device handle; safe to call repeatedly."""
        if not self.inited:
            return
        self._shutdown_quietly()
        self.inited = False
        self.handle = None