mirror of
https://github.com/deepseek-ai/Janus.git
synced 2025-02-23 06:08:59 -05:00
adding memory monitoring decorator
This commit is contained in:
parent
a74a59f8a9
commit
cab2784c20
@ -21,6 +21,9 @@ import torch
|
|||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from janus.models import MultiModalityCausalLM, VLChatProcessor
|
from janus.models import MultiModalityCausalLM, VLChatProcessor
|
||||||
|
from janus.utils.cuda_memory_manager import (
|
||||||
|
monitor_memory,
|
||||||
|
)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
@ -51,6 +54,7 @@ sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
|
|||||||
prompt = sft_format + vl_chat_processor.image_start_tag
|
prompt = sft_format + vl_chat_processor.image_start_tag
|
||||||
|
|
||||||
|
|
||||||
|
@monitor_critical_memory(threshold_gb=2.0)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def generate(
|
def generate(
|
||||||
mmgpt: MultiModalityCausalLM,
|
mmgpt: MultiModalityCausalLM,
|
||||||
|
@ -25,6 +25,9 @@ import torchvision.transforms
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from janus.models.siglip_vit import create_siglip_vit
|
from janus.models.siglip_vit import create_siglip_vit
|
||||||
|
from janus.utils.cuda_memory_manager import (
|
||||||
|
monitor_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionTower(nn.Module):
|
class CLIPVisionTower(nn.Module):
|
||||||
@ -104,6 +107,7 @@ class CLIPVisionTower(nn.Module):
|
|||||||
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
|
@monitor_critical_memory(threshold_gb=2.0)
|
||||||
def forward(self, images):
|
def forward(self, images):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -31,6 +31,9 @@ from transformers.configuration_utils import PretrainedConfig
|
|||||||
|
|
||||||
from janus.models.clip_encoder import CLIPVisionTower
|
from janus.models.clip_encoder import CLIPVisionTower
|
||||||
from janus.models.projector import MlpProjector
|
from janus.models.projector import MlpProjector
|
||||||
|
from janus.utils.cuda_memory_manager import (
|
||||||
|
monitor_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class vision_head(torch.nn.Module):
|
class vision_head(torch.nn.Module):
|
||||||
@ -218,6 +221,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|||||||
language_config = config.language_config
|
language_config = config.language_config
|
||||||
self.language_model = LlamaForCausalLM(language_config)
|
self.language_model = LlamaForCausalLM(language_config)
|
||||||
|
|
||||||
|
@monitor_critical_memory(threshold_gb=2.0)
|
||||||
def prepare_inputs_embeds(
|
def prepare_inputs_embeds(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
@ -259,6 +263,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|||||||
|
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
|
@monitor_critical_memory(threshold_gb=2.0)
|
||||||
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
||||||
return self.gen_aligner(self.gen_embed(image_ids))
|
return self.gen_aligner(self.gen_embed(image_ids))
|
||||||
|
|
||||||
|
@ -27,6 +27,9 @@ from transformers.processing_utils import ProcessorMixin
|
|||||||
|
|
||||||
from janus.models.image_processing_vlm import VLMImageProcessor
|
from janus.models.image_processing_vlm import VLMImageProcessor
|
||||||
from janus.utils.conversation import get_conv_template
|
from janus.utils.conversation import get_conv_template
|
||||||
|
from janus.utils.cuda_memory_manager import (
|
||||||
|
monitor_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DictOutput(object):
|
class DictOutput(object):
|
||||||
@ -354,6 +357,7 @@ class VLChatProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
return prepare
|
return prepare
|
||||||
|
|
||||||
|
@monitor_memory(threshold_gb=2.0)
|
||||||
def batchify(
|
def batchify(
|
||||||
self, prepare_list: List[VLChatProcessorOutput]
|
self, prepare_list: List[VLChatProcessorOutput]
|
||||||
) -> BatchedVLChatProcessorOutput:
|
) -> BatchedVLChatProcessorOutput:
|
||||||
|
107
janus/utils/cuda_memory_manager.py
Normal file
107
janus/utils/cuda_memory_manager.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
# Copyright (c) 2023-2024 DeepSeek.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||||
|
# this software and associated documentation files (the "Software"), to deal in
|
||||||
|
# the Software without restriction, including without limitation the rights to
|
||||||
|
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||||
|
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||||
|
# subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||||
|
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||||
|
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||||
|
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||||
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Optional, Callable, Any, Tuple, List
|
||||||
|
from functools import wraps
|
||||||
|
import warnings
|
||||||
|
import time
|
||||||
|
import math
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
"""Memory monitoring utilities for Janus.
|
||||||
|
|
||||||
|
This module provides essential memory management for multi-modal operations,
|
||||||
|
focusing on preventing OOM issues and optimizing resource usage for
|
||||||
|
vision-language tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class JanusMemoryManager:
    """Memory manager tailored for multi-modal operations.

    Stores threshold settings read from a configuration mapping and reports
    the current CUDA memory status on request.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Read settings from *config*, falling back to defaults.

        Args:
            config: Mapping that may contain ``warning_threshold_gb``
                (default 2.0), ``oom_threshold_gb`` (default 1.0) and
                ``peak_tracking`` (default True). ``None`` is treated as an
                empty mapping.
        """
        config = config or {}
        self.warning_threshold_gb = config.get('warning_threshold_gb', 2.0)
        self.oom_threshold_gb = config.get('oom_threshold_gb', 1.0)
        self.peak_tracking = config.get('peak_tracking', True)

    def check_memory(self) -> Dict[str, float]:
        """Get current CUDA memory status.

        Returns:
            An empty dict when CUDA is unavailable. Otherwise a dict with
            ``'free'`` (first element of ``torch.cuda.mem_get_info()``) and
            ``'peak'`` (``torch.cuda.max_memory_allocated()``), both divided
            by 1024**3 to express them in GiB.
        """
        # Imported locally: this module's top-level imports do not include
        # torch, so referencing it unqualified would raise NameError.
        import torch

        if not torch.cuda.is_available():
            return {}

        bytes_per_gb = 1024**3
        return {
            'free': torch.cuda.mem_get_info()[0] / bytes_per_gb,
            'peak': torch.cuda.max_memory_allocated() / bytes_per_gb,
        }
|
||||||
|
|
||||||
|
def monitor_memory(
    threshold_gb: float = 2.0,
    track_peak: bool = True
) -> Callable:
    """Decorator for monitoring memory in critical paths.

    Designed specifically for multi-modal operations where
    memory usage can spike during modality fusion.

    When CUDA is unavailable the wrapped function is called directly with no
    monitoring. Otherwise the wrapper:

    1. Reads free device memory; if it is below ``threshold_gb`` it calls
       ``torch.cuda.empty_cache()`` and emits a warning if free memory is
       still below the threshold afterwards.
    2. Calls the wrapped function.
    3. If ``track_peak`` is true, prints the peak allocated memory and the
       change in free memory across the call.

    Args:
        threshold_gb: Free-memory level (GiB) below which a cache cleanup is
            attempted and, if that does not help, a warning is issued.
        track_peak: Whether to print peak/delta statistics and reset the
            CUDA peak-memory counters around the call.

    Returns:
        A decorator that wraps a callable with the monitoring logic.

    Raises:
        RuntimeError: Re-raised from the wrapped function. If its message
            contains "out of memory", a new ``RuntimeError`` carrying the
            current free memory is raised, chained to the original.
    """
    # Imported locally: this module's top-level imports do not include torch,
    # so the wrapper would otherwise fail with NameError on its first call.
    import torch

    bytes_per_gb = 1024**3

    def free_gb() -> float:
        """Return currently free device memory in GiB."""
        return torch.cuda.mem_get_info()[0] / bytes_per_gb

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            if not torch.cuda.is_available():
                return func(*args, **kwargs)

            # Track initial state
            free_before = free_gb()

            try:
                if free_before < threshold_gb:
                    torch.cuda.empty_cache()
                    free_after_cleanup = free_gb()
                    if free_after_cleanup < threshold_gb:
                        warnings.warn(
                            f"Critical memory state in {func.__name__}: "
                            f"{free_after_cleanup:.2f}GB free"
                        )

                if track_peak:
                    # Start from a clean counter so the reported peak covers
                    # this call only, not allocations made before it.
                    torch.cuda.reset_peak_memory_stats()

                result = func(*args, **kwargs)

                if track_peak:
                    peak = torch.cuda.max_memory_allocated() / bytes_per_gb
                    free_after = free_gb()
                    print(
                        f"Memory stats for {func.__name__}:\n"
                        f"Peak usage: {peak:.2f}GB\n"
                        f"Memory delta: {free_before - free_after:.2f}GB"
                    )

                return result

            except RuntimeError as e:
                if "out of memory" in str(e):
                    free = free_gb()
                    raise RuntimeError(
                        f"OOM in {func.__name__}. Free memory: {free:.2f}GB\n"
                        f"Consider reducing batch size or image resolution"
                    ) from e
                raise
            finally:
                if track_peak:
                    # Leave the counters reset for whatever runs next, as the
                    # original implementation did.
                    torch.cuda.reset_peak_memory_stats()

        return wrapper
    return decorator
|
Loading…
Reference in New Issue
Block a user