Janus/janus/utils/cuda_memory_manager.py

from functools import wraps
from typing import Any, Callable
import warnings

import torch


def monitor_memory(
    warning_threshold_gb: float = 2.0,
    track_stats: bool = True,
    cleanup_on_warning: bool = True,
) -> Callable:
    """Memory monitoring decorator for CUDA operations.

    Args:
        warning_threshold_gb: Memory threshold in GB to trigger warnings
        track_stats: Whether to track and print memory statistics
        cleanup_on_warning: Whether to attempt memory cleanup when threshold is reached

    Returns:
        Decorator function that monitors memory usage
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            if not torch.cuda.is_available():
                return func(*args, **kwargs)

            # Get initial memory state
            free_before = torch.cuda.mem_get_info()[0] / 1024**3
            try:
                # Check memory state and clean up if needed
                if free_before < warning_threshold_gb and cleanup_on_warning:
                    torch.cuda.empty_cache()
                    free_after_cleanup = torch.cuda.mem_get_info()[0] / 1024**3
                    if free_after_cleanup < warning_threshold_gb:
                        warnings.warn(
                            f"Low memory in {func.__name__}: "
                            f"{free_after_cleanup:.2f}GB free"
                        )

                result = func(*args, **kwargs)

                # Track memory statistics if enabled
                if track_stats:
                    peak = torch.cuda.max_memory_allocated() / 1024**3
                    free_after = torch.cuda.mem_get_info()[0] / 1024**3
                    print(
                        f"Memory stats for {func.__name__}:\n"
                        f"Peak: {peak:.2f}GB | Delta: {free_before - free_after:.2f}GB"
                    )
                    torch.cuda.reset_peak_memory_stats()

                return result
            except RuntimeError as e:
                if "out of memory" in str(e):
                    free = torch.cuda.mem_get_info()[0] / 1024**3
                    raise RuntimeError(
                        f"OOM in {func.__name__} with {free:.2f}GB free. "
                        "Consider reducing batch size or image resolution."
                    ) from e
                raise

        return wrapper

    return decorator
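

if __name__ == "__main__":
    # Usage sketch (hypothetical, not part of the Janus module): `dummy_forward`
    # is a stand-in for a real CUDA workload such as a model forward pass. The
    # decorator is a no-op without CUDA, so the call is guarded explicitly; the
    # threshold value here is illustrative only.
    @monitor_memory(warning_threshold_gb=4.0, track_stats=True)
    def dummy_forward(batch: torch.Tensor) -> torch.Tensor:
        return batch @ batch.transpose(-1, -2)

    if torch.cuda.is_available():
        dummy_forward(torch.randn(8, 1024, 1024, device="cuda"))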