Janus/janus/utils/cuda_memory_manager.py
2025-01-29 10:56:28 +02:00

69 lines
2.4 KiB
Python

from functools import wraps
from typing import Callable, Any
import torch
import warnings
def monitor_memory(
    warning_threshold_gb: float = 2.0,
    track_stats: bool = True,
    cleanup_on_warning: bool = True,
) -> Callable:
    """Build a decorator that monitors free CUDA memory around a call.

    When CUDA is not available the decorated function is invoked directly
    and no monitoring takes place.

    Args:
        warning_threshold_gb: Free-memory threshold (bytes / 1024**3). A
            warning is issued when less than this is free before the call.
        track_stats: Whether to print the peak allocation and the change in
            free memory after the call. Enabling this resets the CUDA peak
            memory counters before and after the wrapped call.
        cleanup_on_warning: Whether to call ``torch.cuda.empty_cache()``
            when free memory is below the threshold. The warning is only
            issued if memory is still below the threshold afterwards.

    Returns:
        A decorator that wraps a callable with the memory checks above.

    Raises:
        RuntimeError: Raised by the wrapped callable. If its message contains
            "out of memory" it is re-raised (chained) with the amount of free
            memory and a hint; any other ``RuntimeError`` propagates unchanged.
    """
    bytes_per_gb = 1024**3

    def free_gb() -> float:
        # mem_get_info() returns a (free, total) tuple in bytes.
        return torch.cuda.mem_get_info()[0] / bytes_per_gb

    def decorator(func: Callable) -> Callable:
        # Callables such as functools.partial have no __name__.
        name = getattr(func, "__name__", repr(func))

        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            if not torch.cuda.is_available():
                return func(*args, **kwargs)

            baseline = free_gb()
            if baseline < warning_threshold_gb:
                if cleanup_on_warning:
                    torch.cuda.empty_cache()
                    # Re-measure so both the warning and the reported delta
                    # are relative to the state the function starts from.
                    baseline = free_gb()
                if baseline < warning_threshold_gb:
                    warnings.warn(
                        f"Low memory in {name}: {baseline:.2f}GB free",
                        stacklevel=2,  # attribute the warning to the caller
                    )

            if track_stats:
                # Start from a clean peak so the reported value reflects this
                # call rather than allocations made earlier in the process.
                torch.cuda.reset_peak_memory_stats()

            try:
                result = func(*args, **kwargs)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    free = free_gb()
                    raise RuntimeError(
                        f"OOM in {name} with {free:.2f}GB free. "
                        "Consider reducing batch size or image resolution."
                    ) from e
                raise

            if track_stats:
                peak = torch.cuda.max_memory_allocated() / bytes_per_gb
                free_after = free_gb()
                print(
                    f"Memory stats for {name}:\n"
                    f"Peak: {peak:.2f}GB | Delta: {baseline - free_after:.2f}GB"
                )
                torch.cuda.reset_peak_memory_stats()
            return result

        return wrapper

    return decorator