Mirror of https://github.com/deepseek-ai/Janus.git, synced 2025-02-22 13:48:57 -05:00

adding memory monitoring to janus flow

This commit is contained in:
parent cab2784c20
commit edd64d88f5
@@ -54,7 +54,7 @@ sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
prompt = sft_format + vl_chat_processor.image_start_tag


@monitor_critical_memory(threshold_gb=2.0)
@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
@@ -70,36 +70,46 @@ def generate(
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

    tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda()
    for i in range(parallel_size*2):
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int
    ).cuda()

    for i in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=outputs.past_key_values if i != 0 else None,
        )
        hidden_states = outputs.last_hidden_state

        logits = mmgpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]

        logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)

        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)

        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        next_token = torch.cat(
            [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
        ).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
    dec = mmgpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
    )
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
@@ -107,9 +117,9 @@ def generate(
    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    os.makedirs('generated_samples', exist_ok=True)
    os.makedirs("generated_samples", exist_ok=True)
    for i in range(parallel_size):
        save_path = os.path.join('generated_samples', "img_{}.jpg".format(i))
        save_path = os.path.join("generated_samples", "img_{}.jpg".format(i))
        PIL.Image.fromarray(visual_img[i]).save(save_path)

@@ -117,4 +127,4 @@ generate(
    vl_gpt,
    vl_chat_processor,
    prompt,
)
)
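A minimal sketch (not part of the commit) of the classifier-free-guidance step in the generation loop above: even rows of the interleaved batch carry the conditional logits, odd rows the unconditional ones (their prompt tokens were replaced by pad_id), and the two streams are blended with cfg_weight before sampling. The tensor values below are made up.

    import torch

    cfg_weight, temperature, parallel_size = 5.0, 1.0, 2
    logits = torch.randn(parallel_size * 2, 16)            # interleaved cond/uncond rows
    logit_cond = logits[0::2, :]                           # rows 0, 2, ...: full prompt
    logit_uncond = logits[1::2, :]                         # rows 1, 3, ...: padded prompt
    guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
    probs = torch.softmax(guided / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)   # one token per image
    print(next_token.shape)                                # torch.Size([2, 1])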
@@ -25,6 +25,9 @@ import torchvision.transforms
from einops import rearrange

from janus.janusflow.models.siglip_vit import create_siglip_vit
from janus.utils.cuda_memory_manager import (
    monitor_memory,
)


class CLIPVisionTower(nn.Module):
@@ -104,6 +107,7 @@ class CLIPVisionTower(nn.Module):
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def forward(self, images):
        """
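As a side note on the decorator placement above, a rough sketch (not part of the commit) of why it can sit directly on an nn.Module method: the wrapper only forwards *args/**kwargs, so self passes through unchanged. TinyTower below is a hypothetical stand-in for CLIPVisionTower.

    import torch
    import torch.nn as nn
    from janus.utils.cuda_memory_manager import monitor_memory

    class TinyTower(nn.Module):
        @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
        def forward(self, images):
            return images * 2

    tower = TinyTower()
    out = tower(torch.ones(1, 3, 384, 384))   # monitoring is a no-op without CUDA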
@@ -31,6 +31,7 @@ from transformers import (
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from janus.janusflow.models.clip_encoder import CLIPVisionTower
from janus.janusflow.models.uvit import ShallowUViTEncoder, ShallowUViTDecoder
from janus.utils.cuda_memory_manager import monitor_memory
import torch.nn as nn


@@ -168,6 +169,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
        )
        self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
@@ -27,6 +27,7 @@ from transformers.processing_utils import ProcessorMixin

from janus.janusflow.models.image_processing_vlm import VLMImageProcessor
from janus.utils.conversation import get_conv_template
from janus.utils.cuda_memory_manager import monitor_memory


class DictOutput(object):
@@ -384,6 +385,7 @@ class VLChatProcessor(ProcessorMixin):

        return prepare

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
@@ -107,7 +107,7 @@ class CLIPVisionTower(nn.Module):
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    @monitor_critical_memory(threshold_gb=2.0)
    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def forward(self, images):
        """

@@ -221,7 +221,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    @monitor_critical_memory(threshold_gb=2.0)
    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
@@ -263,7 +263,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):

        return inputs_embeds

    @monitor_critical_memory(threshold_gb=2.0)
    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        return self.gen_aligner(self.gen_embed(image_ids))

@@ -357,7 +357,7 @@ class VLChatProcessor(ProcessorMixin):

        return prepare

    @monitor_memory(threshold_gb=2.0)
    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
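The hunks above appear to migrate the pre-existing call sites from the old keyword name (threshold_gb, and the stacked monitor_critical_memory decorator) to the single refactored decorator. A hedged illustration, not part of the commit; embed_tokens is a made-up function used only to show the signatures.

    from janus.utils.cuda_memory_manager import monitor_memory

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)   # updated signature
    def embed_tokens(ids):
        return ids

    # The old form, monitor_memory(threshold_gb=2.0), would now raise
    # TypeError: monitor_memory() got an unexpected keyword argument 'threshold_gb'.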
@@ -1,93 +1,56 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from dataclasses import dataclass
from typing import Dict, Optional, Callable, Any, Tuple, List
from functools import wraps
from typing import Callable, Any
import torch
import warnings
import time
import math
import logging
from enum import Enum

"""Memory monitoring utilities for Janus.

This module provides essential memory management for multi-modal operations,
focusing on preventing OOM issues and optimizing resource usage for
vision-language tasks.
"""

class JanusMemoryManager:
    """Memory manager tailored for multi-modal operations."""

    def __init__(self, config: Dict[str, Any]):
        self.warning_threshold_gb = config.get('warning_threshold_gb', 2.0)
        self.oom_threshold_gb = config.get('oom_threshold_gb', 1.0)
        self.peak_tracking = config.get('peak_tracking', True)

    def check_memory(self) -> Dict[str, float]:
        """Get current CUDA memory status."""
        if not torch.cuda.is_available():
            return {}
        return {
            'free': torch.cuda.mem_get_info()[0] / 1024**3,
            'peak': torch.cuda.max_memory_allocated() / 1024**3
        }

def monitor_memory(
    threshold_gb: float = 2.0,
    track_peak: bool = True
    warning_threshold_gb: float = 2.0,
    track_stats: bool = True,
    cleanup_on_warning: bool = True,
) -> Callable:
    """Decorator for monitoring memory in critical paths.

    Designed specifically for multi-modal operations where
    memory usage can spike during modality fusion.
    """Memory monitoring decorator for CUDA operations.

    Args:
        warning_threshold_gb: Memory threshold in GB to trigger warnings
        track_stats: Whether to track and print memory statistics
        cleanup_on_warning: Whether to attempt memory cleanup when threshold is reached

    Returns:
        Decorator function that monitors memory usage
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            if not torch.cuda.is_available():
                return func(*args, **kwargs)

            # Track initial state
            # Get initial memory state
            free_before = torch.cuda.mem_get_info()[0] / 1024**3

            try:
                if free_before < threshold_gb:
                # Check memory state and cleanup if needed
                if free_before < warning_threshold_gb and cleanup_on_warning:
                    torch.cuda.empty_cache()
                    free_after_cleanup = torch.cuda.mem_get_info()[0] / 1024**3
                    if free_after_cleanup < threshold_gb:

                    if free_after_cleanup < warning_threshold_gb:
                        warnings.warn(
                            f"Critical memory state in {func.__name__}: "
                            f"{free_after_cleanup:.2f}GB free"
                            f"Low memory in {func.__name__}: {free_after_cleanup:.2f}GB free"
                        )

                result = func(*args, **kwargs)

                if track_peak:
                # Track memory statistics if enabled
                if track_stats:
                    peak = torch.cuda.max_memory_allocated() / 1024**3
                    free_after = torch.cuda.mem_get_info()[0] / 1024**3
                    print(
                        f"Memory stats for {func.__name__}:\n"
                        f"Peak usage: {peak:.2f}GB\n"
                        f"Memory delta: {free_before - free_after:.2f}GB"
                        f"Peak: {peak:.2f}GB | Delta: {free_before - free_after:.2f}GB"
                    )
                    torch.cuda.reset_peak_memory_stats()

                return result

@@ -95,13 +58,11 @@ def monitor_memory(
                if "out of memory" in str(e):
                    free = torch.cuda.mem_get_info()[0] / 1024**3
                    raise RuntimeError(
                        f"OOM in {func.__name__}. Free memory: {free:.2f}GB\n"
                        f"Consider reducing batch size or image resolution"
                        f"OOM in {func.__name__} with {free:.2f}GB free. "
                        "Consider reducing batch size or image resolution."
                    ) from e
                raise
            finally:
                if track_peak:
                    torch.cuda.reset_peak_memory_stats()

        return wrapper
    return decorator

    return decorator
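A minimal usage sketch (not part of the commit), assuming the refactored decorator above is importable from janus.utils.cuda_memory_manager; allocate is a made-up workload used only for illustration.

    import torch
    from janus.utils.cuda_memory_manager import monitor_memory

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def allocate(n):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.zeros(n, n, device=device)   # ~0.06GB for n=4096 (float32)

    x = allocate(4096)
    # With a CUDA device this prints roughly:
    #   Memory stats for allocate:
    #   Peak: 0.06GB | Delta: 0.06GB
    # Without CUDA the wrapper returns the result immediately and prints nothing.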