From edd64d88f5d02a691818c19472a91f889ba17322 Mon Sep 17 00:00:00 2001 From: Elkana Baris <104759389+ElkanaBaris@users.noreply.github.com> Date: Wed, 29 Jan 2025 10:56:28 +0200 Subject: [PATCH] adding memory monitoring to janus flow --- generation_inference.py | 38 ++++++---- janus/janusflow/models/clip_encoder.py | 4 + janus/janusflow/models/modeling_vlm.py | 2 + janus/janusflow/models/processing_vlm.py | 2 + janus/models/clip_encoder.py | 2 +- janus/models/modeling_vlm.py | 4 +- janus/models/processing_vlm.py | 2 +- janus/utils/cuda_memory_manager.py | 97 +++++++----------------- 8 files changed, 65 insertions(+), 86 deletions(-) diff --git a/generation_inference.py b/generation_inference.py index 56cf7cb..9383f6d 100644 --- a/generation_inference.py +++ b/generation_inference.py @@ -54,7 +54,7 @@ sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts( prompt = sft_format + vl_chat_processor.image_start_tag -@monitor_critical_memory(threshold_gb=2.0) +@monitor_memory(warning_threshold_gb=1.5, track_stats=True) @torch.inference_mode() def generate( mmgpt: MultiModalityCausalLM, @@ -70,36 +70,46 @@ def generate( input_ids = vl_chat_processor.tokenizer.encode(prompt) input_ids = torch.LongTensor(input_ids) - tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda() - for i in range(parallel_size*2): + tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda() + for i in range(parallel_size * 2): tokens[i, :] = input_ids if i % 2 != 0: tokens[i, 1:-1] = vl_chat_processor.pad_id inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens) - generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda() + generated_tokens = torch.zeros( + (parallel_size, image_token_num_per_image), dtype=torch.int + ).cuda() for i in range(image_token_num_per_image): - outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None) + outputs = mmgpt.language_model.model( + inputs_embeds=inputs_embeds, + use_cache=True, + past_key_values=outputs.past_key_values if i != 0 else None, + ) hidden_states = outputs.last_hidden_state - + logits = mmgpt.gen_head(hidden_states[:, -1, :]) logit_cond = logits[0::2, :] logit_uncond = logits[1::2, :] - - logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond) + + logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond) probs = torch.softmax(logits / temperature, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated_tokens[:, i] = next_token.squeeze(dim=-1) - next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1) + next_token = torch.cat( + [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1 + ).view(-1) img_embeds = mmgpt.prepare_gen_img_embeds(next_token) inputs_embeds = img_embeds.unsqueeze(dim=1) - - dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size]) + dec = mmgpt.gen_vision_model.decode_code( + generated_tokens.to(dtype=torch.int), + shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size], + ) dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) dec = np.clip((dec + 1) / 2 * 255, 0, 255) @@ -107,9 +117,9 @@ def generate( visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8) visual_img[:, :, :] = dec - os.makedirs('generated_samples', exist_ok=True) + os.makedirs("generated_samples", exist_ok=True) for i in range(parallel_size): - save_path = os.path.join('generated_samples', "img_{}.jpg".format(i)) + save_path = os.path.join("generated_samples", "img_{}.jpg".format(i)) PIL.Image.fromarray(visual_img[i]).save(save_path) @@ -117,4 +127,4 @@ generate( vl_gpt, vl_chat_processor, prompt, -) \ No newline at end of file +) diff --git a/janus/janusflow/models/clip_encoder.py b/janus/janusflow/models/clip_encoder.py index ffdaba6..76901ab 100644 --- a/janus/janusflow/models/clip_encoder.py +++ b/janus/janusflow/models/clip_encoder.py @@ -25,6 +25,9 @@ import torchvision.transforms from einops import rearrange from janus.janusflow.models.siglip_vit import create_siglip_vit +from janus.utils.cuda_memory_manager import ( + monitor_memory, +) class CLIPVisionTower(nn.Module): @@ -104,6 +107,7 @@ class CLIPVisionTower(nn.Module): raise ValueError(f"Unexpected select feature: {self.select_feature}") return image_features + @monitor_memory(warning_threshold_gb=1.5, track_stats=True) def forward(self, images): """ diff --git a/janus/janusflow/models/modeling_vlm.py b/janus/janusflow/models/modeling_vlm.py index 3fc119b..7f0d504 100644 --- a/janus/janusflow/models/modeling_vlm.py +++ b/janus/janusflow/models/modeling_vlm.py @@ -31,6 +31,7 @@ from transformers import ( from transformers.models.llama.modeling_llama import LlamaRMSNorm from janus.janusflow.models.clip_encoder import CLIPVisionTower from janus.janusflow.models.uvit import ShallowUViTEncoder, ShallowUViTDecoder +from janus.utils.cuda_memory_manager import monitor_memory import torch.nn as nn @@ -168,6 +169,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel): ) self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True) + @monitor_memory(warning_threshold_gb=1.5, track_stats=True) def prepare_inputs_embeds( self, input_ids: torch.LongTensor, diff --git a/janus/janusflow/models/processing_vlm.py b/janus/janusflow/models/processing_vlm.py index a6fe494..05e8d58 100644 --- a/janus/janusflow/models/processing_vlm.py +++ b/janus/janusflow/models/processing_vlm.py @@ -27,6 +27,7 @@ from transformers.processing_utils import ProcessorMixin from janus.janusflow.models.image_processing_vlm import VLMImageProcessor from janus.utils.conversation import get_conv_template +from janus.utils.cuda_memory_manager import monitor_memory class DictOutput(object): @@ -384,6 +385,7 @@ class VLChatProcessor(ProcessorMixin): return prepare + @monitor_memory(warning_threshold_gb=1.5, track_stats=True) def batchify( self, prepare_list: List[VLChatProcessorOutput] ) -> BatchedVLChatProcessorOutput: diff --git a/janus/models/clip_encoder.py b/janus/models/clip_encoder.py index 26bc070..93f458a 100644 --- a/janus/models/clip_encoder.py +++ b/janus/models/clip_encoder.py @@ -107,7 +107,7 @@ class CLIPVisionTower(nn.Module): raise ValueError(f"Unexpected select feature: {self.select_feature}") return image_features - @monitor_critical_memory(threshold_gb=2.0) + @monitor_memory(warning_threshold_gb=1.5, track_stats=True) def forward(self, images): """ diff --git a/janus/models/modeling_vlm.py b/janus/models/modeling_vlm.py index 962d9f1..e3239d6 100644 --- a/janus/models/modeling_vlm.py +++ b/janus/models/modeling_vlm.py @@ -221,7 +221,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel): language_config = config.language_config self.language_model = LlamaForCausalLM(language_config) - @monitor_critical_memory(threshold_gb=2.0) + @monitor_memory(warning_threshold_gb=1.5, track_stats=True) def prepare_inputs_embeds( self, input_ids: torch.LongTensor, @@ -263,7 +263,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel): return inputs_embeds - @monitor_critical_memory(threshold_gb=2.0) + @monitor_memory(warning_threshold_gb=1.5, track_stats=True) def prepare_gen_img_embeds(self, image_ids: torch.LongTensor): return self.gen_aligner(self.gen_embed(image_ids)) diff --git a/janus/models/processing_vlm.py b/janus/models/processing_vlm.py index f8aae88..2f0dfe8 100644 --- a/janus/models/processing_vlm.py +++ b/janus/models/processing_vlm.py @@ -357,7 +357,7 @@ class VLChatProcessor(ProcessorMixin): return prepare - @monitor_memory(threshold_gb=2.0) + @monitor_memory(warning_threshold_gb=1.5, track_stats=True) def batchify( self, prepare_list: List[VLChatProcessorOutput] ) -> BatchedVLChatProcessorOutput: diff --git a/janus/utils/cuda_memory_manager.py b/janus/utils/cuda_memory_manager.py index 0d7c644..f3b35ec 100644 --- a/janus/utils/cuda_memory_manager.py +++ b/janus/utils/cuda_memory_manager.py @@ -1,93 +1,56 @@ -# Copyright (c) 2023-2024 DeepSeek. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy of -# this software and associated documentation files (the "Software"), to deal in -# the Software without restriction, including without limitation the rights to -# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -# the Software, and to permit persons to whom the Software is furnished to do so, -# subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -from dataclasses import dataclass -from typing import Dict, Optional, Callable, Any, Tuple, List from functools import wraps +from typing import Callable, Any +import torch import warnings -import time -import math -import logging -from enum import Enum -"""Memory monitoring utilities for Janus. - -This module provides essential memory management for multi-modal operations, -focusing on preventing OOM issues and optimizing resource usage for -vision-language tasks. -""" - -class JanusMemoryManager: - """Memory manager tailored for multi-modal operations.""" - - def __init__(self, config: Dict[str, Any]): - self.warning_threshold_gb = config.get('warning_threshold_gb', 2.0) - self.oom_threshold_gb = config.get('oom_threshold_gb', 1.0) - self.peak_tracking = config.get('peak_tracking', True) - - def check_memory(self) -> Dict[str, float]: - """Get current CUDA memory status.""" - if not torch.cuda.is_available(): - return {} - return { - 'free': torch.cuda.mem_get_info()[0] / 1024**3, - 'peak': torch.cuda.max_memory_allocated() / 1024**3 - } def monitor_memory( - threshold_gb: float = 2.0, - track_peak: bool = True + warning_threshold_gb: float = 2.0, + track_stats: bool = True, + cleanup_on_warning: bool = True, ) -> Callable: - """Decorator for monitoring memory in critical paths. - - Designed specifically for multi-modal operations where - memory usage can spike during modality fusion. + """Memory monitoring decorator for CUDA operations. + + Args: + warning_threshold_gb: Memory threshold in GB to trigger warnings + track_stats: Whether to track and print memory statistics + cleanup_on_warning: Whether to attempt memory cleanup when threshold is reached + + Returns: + Decorator function that monitors memory usage """ + def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs) -> Any: if not torch.cuda.is_available(): return func(*args, **kwargs) - # Track initial state + # Get initial memory state free_before = torch.cuda.mem_get_info()[0] / 1024**3 try: - if free_before < threshold_gb: + # Check memory state and cleanup if needed + if free_before < warning_threshold_gb and cleanup_on_warning: torch.cuda.empty_cache() free_after_cleanup = torch.cuda.mem_get_info()[0] / 1024**3 - if free_after_cleanup < threshold_gb: + + if free_after_cleanup < warning_threshold_gb: warnings.warn( - f"Critical memory state in {func.__name__}: " - f"{free_after_cleanup:.2f}GB free" + f"Low memory in {func.__name__}: {free_after_cleanup:.2f}GB free" ) result = func(*args, **kwargs) - if track_peak: + # Track memory statistics if enabled + if track_stats: peak = torch.cuda.max_memory_allocated() / 1024**3 free_after = torch.cuda.mem_get_info()[0] / 1024**3 print( f"Memory stats for {func.__name__}:\n" - f"Peak usage: {peak:.2f}GB\n" - f"Memory delta: {free_before - free_after:.2f}GB" + f"Peak: {peak:.2f}GB | Delta: {free_before - free_after:.2f}GB" ) + torch.cuda.reset_peak_memory_stats() return result @@ -95,13 +58,11 @@ def monitor_memory( if "out of memory" in str(e): free = torch.cuda.mem_get_info()[0] / 1024**3 raise RuntimeError( - f"OOM in {func.__name__}. Free memory: {free:.2f}GB\n" - f"Consider reducing batch size or image resolution" + f"OOM in {func.__name__} with {free:.2f}GB free. " + "Consider reducing batch size or image resolution." ) from e raise - finally: - if track_peak: - torch.cuda.reset_peak_memory_stats() return wrapper - return decorator \ No newline at end of file + + return decorator