From cab2784c2040c6f273c8c27bab00fd2cd079729a Mon Sep 17 00:00:00 2001
From: Elkana Baris <104759389+ElkanaBaris@users.noreply.github.com>
Date: Tue, 28 Jan 2025 19:17:07 +0200
Subject: [PATCH] add memory monitoring decorator

---
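Notes: a minimal usage sketch of the additions in this patch. The function
name encode_images and its body are hypothetical stand-ins for any
CUDA-heavy call path:

    import torch
    from janus.utils.cuda_memory_manager import JanusMemoryManager, monitor_memory

    @monitor_memory(threshold_gb=2.0, track_peak=True)
    def encode_images(pixel_values: torch.Tensor) -> torch.Tensor:
        # Stand-in for real work; the wrapper warns when free CUDA memory
        # drops below the threshold, prints peak usage after the call, and
        # re-raises CUDA OOM errors annotated with the free-memory figure.
        return pixel_values.float() * 2

    # The manager can also be polled directly between pipeline stages.
    manager = JanusMemoryManager({'warning_threshold_gb': 2.0})
    print(manager.check_memory())  # returns {} on CPU-only hosts

When torch.cuda.is_available() is False, the decorator is a transparent
pass-through, so decorated code still runs on CPU-only machines.
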
 generation_inference.py            |   4 ++
 janus/models/clip_encoder.py       |   4 ++
 janus/models/modeling_vlm.py       |   5 ++
 janus/models/processing_vlm.py     |   4 ++
 janus/utils/cuda_memory_manager.py | 109 ++++++++++++++++++++++++++++++++++
 5 files changed, 126 insertions(+)
 create mode 100644 janus/utils/cuda_memory_manager.py

diff --git a/generation_inference.py b/generation_inference.py
index 31cc61e..56cf7cb 100644
--- a/generation_inference.py
+++ b/generation_inference.py
@@ -21,6 +21,9 @@ import torch
 from transformers import AutoModelForCausalLM
 
 from janus.models import MultiModalityCausalLM, VLChatProcessor
+from janus.utils.cuda_memory_manager import (
+    monitor_memory,
+)
 import numpy as np
 import os
 import PIL.Image
@@ -51,6 +54,7 @@ sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
 prompt = sft_format + vl_chat_processor.image_start_tag
 
 
+@monitor_memory(threshold_gb=2.0)
 @torch.inference_mode()
 def generate(
     mmgpt: MultiModalityCausalLM,
diff --git a/janus/models/clip_encoder.py b/janus/models/clip_encoder.py
index 8ea8fcf..26bc070 100644
--- a/janus/models/clip_encoder.py
+++ b/janus/models/clip_encoder.py
@@ -25,6 +25,9 @@ import torchvision.transforms
 from einops import rearrange
 
 from janus.models.siglip_vit import create_siglip_vit
+from janus.utils.cuda_memory_manager import (
+    monitor_memory,
+)
 
 
 class CLIPVisionTower(nn.Module):
@@ -104,6 +107,7 @@ class CLIPVisionTower(nn.Module):
             raise ValueError(f"Unexpected select feature: {self.select_feature}")
         return image_features
 
+    @monitor_memory(threshold_gb=2.0)
     def forward(self, images):
         """
 
diff --git a/janus/models/modeling_vlm.py b/janus/models/modeling_vlm.py
index 8d23bc0..962d9f1 100644
--- a/janus/models/modeling_vlm.py
+++ b/janus/models/modeling_vlm.py
@@ -31,6 +31,9 @@ from transformers.configuration_utils import PretrainedConfig
 
 from janus.models.clip_encoder import CLIPVisionTower
 from janus.models.projector import MlpProjector
+from janus.utils.cuda_memory_manager import (
+    monitor_memory,
+)
 
 
 class vision_head(torch.nn.Module):
@@ -218,6 +221,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         language_config = config.language_config
         self.language_model = LlamaForCausalLM(language_config)
 
+    @monitor_memory(threshold_gb=2.0)
     def prepare_inputs_embeds(
         self,
         input_ids: torch.LongTensor,
@@ -259,6 +263,7 @@
 
         return inputs_embeds
 
+    @monitor_memory(threshold_gb=2.0)
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))
 
diff --git a/janus/models/processing_vlm.py b/janus/models/processing_vlm.py
index eba6895..f8aae88 100644
--- a/janus/models/processing_vlm.py
+++ b/janus/models/processing_vlm.py
@@ -27,6 +27,9 @@ from transformers.processing_utils import ProcessorMixin
 
 from janus.models.image_processing_vlm import VLMImageProcessor
 from janus.utils.conversation import get_conv_template
+from janus.utils.cuda_memory_manager import (
+    monitor_memory,
+)
 
 
 class DictOutput(object):
@@ -354,6 +357,7 @@ class VLChatProcessor(ProcessorMixin):
 
         return prepare
 
+    @monitor_memory(threshold_gb=2.0)
     def batchify(
         self, prepare_list: List[VLChatProcessorOutput]
     ) -> BatchedVLChatProcessorOutput:
diff --git a/janus/utils/cuda_memory_manager.py b/janus/utils/cuda_memory_manager.py
new file mode 100644
index 0000000..0d7c644
--- /dev/null
+++ b/janus/utils/cuda_memory_manager.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+"""Memory monitoring utilities for Janus.
+
+This module provides essential memory management for multi-modal operations,
+focusing on preventing OOM issues and optimizing resource usage for
+vision-language tasks.
+"""
+
+from functools import wraps
+from typing import Any, Callable, Dict
+import warnings
+
+import torch
+
+
+class JanusMemoryManager:
+    """Memory manager tailored for multi-modal operations."""
+
+    def __init__(self, config: Dict[str, Any]):
+        self.warning_threshold_gb = config.get('warning_threshold_gb', 2.0)
+        self.oom_threshold_gb = config.get('oom_threshold_gb', 1.0)
+        self.peak_tracking = config.get('peak_tracking', True)
+
+    def check_memory(self) -> Dict[str, float]:
+        """Get current CUDA memory status in GB."""
+        if not torch.cuda.is_available():
+            return {}
+        return {
+            'free': torch.cuda.mem_get_info()[0] / 1024**3,
+            'peak': torch.cuda.max_memory_allocated() / 1024**3
+        }
+
+
+def monitor_memory(
+    threshold_gb: float = 2.0,
+    track_peak: bool = True
+) -> Callable:
+    """Decorator for monitoring memory in critical paths.
+
+    Designed specifically for multi-modal operations where
+    memory usage can spike during modality fusion.
+    """
+
+    def decorator(func: Callable) -> Callable:
+        @wraps(func)
+        def wrapper(*args, **kwargs) -> Any:
+            if not torch.cuda.is_available():
+                return func(*args, **kwargs)
+
+            # Track initial state
+            free_before = torch.cuda.mem_get_info()[0] / 1024**3
+
+            try:
+                if free_before < threshold_gb:
+                    # Try to reclaim cached blocks before warning.
+                    torch.cuda.empty_cache()
+                    free_after_cleanup = torch.cuda.mem_get_info()[0] / 1024**3
+                    if free_after_cleanup < threshold_gb:
+                        warnings.warn(
+                            f"Critical memory state in {func.__name__}: "
+                            f"{free_after_cleanup:.2f}GB free"
+                        )
+
+                result = func(*args, **kwargs)
+
+                if track_peak:
+                    peak = torch.cuda.max_memory_allocated() / 1024**3
+                    free_after = torch.cuda.mem_get_info()[0] / 1024**3
+                    print(
+                        f"Memory stats for {func.__name__}:\n"
+                        f"Peak usage: {peak:.2f}GB\n"
+                        f"Memory delta: {free_before - free_after:.2f}GB"
+                    )
+
+                return result
+
+            except RuntimeError as e:
+                if "out of memory" in str(e):
+                    free = torch.cuda.mem_get_info()[0] / 1024**3
+                    raise RuntimeError(
+                        f"OOM in {func.__name__}. Free memory: {free:.2f}GB\n"
+                        f"Consider reducing batch size or image resolution"
+                    ) from e
+                raise
+            finally:
+                if track_peak:
+                    torch.cuda.reset_peak_memory_stats()
+
+        return wrapper
+
+    return decorator