import warnings
from functools import wraps
from typing import Any, Callable

import torch


def monitor_memory(
    warning_threshold_gb: float = 2.0,
    track_stats: bool = True,
    cleanup_on_warning: bool = True,
) -> Callable:
    """Memory monitoring decorator for CUDA operations.

    Args:
        warning_threshold_gb: Free-memory threshold in GB; a warning is issued
            when available CUDA memory drops below this value.
        track_stats: Whether to track and print memory statistics.
        cleanup_on_warning: Whether to attempt memory cleanup when the threshold
            is reached.

    Returns:
        Decorator function that monitors memory usage of the wrapped callable.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            if not torch.cuda.is_available():
                return func(*args, **kwargs)

            # Get initial memory state (mem_get_info returns bytes; convert to GB)
            free_before = torch.cuda.mem_get_info()[0] / 1024**3

            try:
                # Check memory state and clean up if needed
                if free_before < warning_threshold_gb and cleanup_on_warning:
                    torch.cuda.empty_cache()
                    free_after_cleanup = torch.cuda.mem_get_info()[0] / 1024**3

                    if free_after_cleanup < warning_threshold_gb:
                        warnings.warn(
                            f"Low memory in {func.__name__}: {free_after_cleanup:.2f}GB free"
                        )

                result = func(*args, **kwargs)

                # Track memory statistics if enabled
                if track_stats:
                    peak = torch.cuda.max_memory_allocated() / 1024**3
                    free_after = torch.cuda.mem_get_info()[0] / 1024**3
                    print(
                        f"Memory stats for {func.__name__}:\n"
                        f"Peak: {peak:.2f}GB | Delta: {free_before - free_after:.2f}GB"
                    )
                    torch.cuda.reset_peak_memory_stats()

                return result

            except RuntimeError as e:
                if "out of memory" in str(e):
                    free = torch.cuda.mem_get_info()[0] / 1024**3
                    raise RuntimeError(
                        f"OOM in {func.__name__} with {free:.2f}GB free. "
                        "Consider reducing batch size or image resolution."
                    ) from e
                raise

        return wrapper

    return decorator
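

# A minimal usage sketch, not part of the original module: `run_inference`, the
# Linear model, and the tensor shapes below are hypothetical and only illustrate
# how the decorator's threshold and stats arguments are intended to be applied.
@monitor_memory(warning_threshold_gb=4.0, track_stats=True)
def run_inference(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical inference step wrapped with memory monitoring."""
    with torch.no_grad():
        return model(batch.to("cuda"))


if __name__ == "__main__":
    # Only meaningful on a CUDA machine; without a GPU the decorator simply
    # calls the wrapped function and skips all monitoring.
    if torch.cuda.is_available():
        model = torch.nn.Linear(1024, 1024).cuda()
        batch = torch.randn(32, 1024)
        out = run_inference(model, batch)
        print(out.shape)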