commit f538b1609e
Elkana Baris, 2025-02-01 01:26:27 -08:00 (committed by GitHub)
8 changed files with 116 additions and 13 deletions

View File

@@ -21,6 +21,9 @@ import torch
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.cuda_memory_manager import (
    monitor_memory,
)
import numpy as np
import os
import PIL.Image
@@ -51,6 +54,7 @@ sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
prompt = sft_format + vl_chat_processor.image_start_tag


@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
@@ -66,36 +70,46 @@ def generate(
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int
    ).cuda()

    for i in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=outputs.past_key_values if i != 0 else None,
        )
        hidden_states = outputs.last_hidden_state

        logits = mmgpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]

        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)

        next_token = torch.cat(
            [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
        ).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    dec = mmgpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
    )
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
@@ -103,9 +117,9 @@ def generate(
    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    os.makedirs("generated_samples", exist_ok=True)
    for i in range(parallel_size):
        save_path = os.path.join("generated_samples", "img_{}.jpg".format(i))
        PIL.Image.fromarray(visual_img[i]).save(save_path)
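
For readers unfamiliar with the sampling loop above: the even/odd interleaving of prompt rows (odd rows get their prompt replaced by pad tokens) sets up classifier-free guidance, and the logit_uncond + cfg_weight * (logit_cond - logit_uncond) line is the guidance mix. Below is a minimal, self-contained sketch of that mixing step with toy tensors; all names and sizes are illustrative, not taken from the repository.

import torch

parallel_size, vocab = 2, 8          # toy sizes: 2 samples, 8 possible image tokens
cfg_weight, temperature = 5.0, 1.0

# 2 * parallel_size interleaved rows: even rows conditional, odd rows unconditional.
logits = torch.randn(parallel_size * 2, vocab)
logit_cond = logits[0::2, :]         # rows that kept the prompt
logit_uncond = logits[1::2, :]       # rows whose prompt was padded out

# Classifier-free guidance: move the unconditional logits toward the conditional ones.
guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(guided / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)   # shape (parallel_size, 1)

With the script's default arguments (img_size=384, patch_size=16, not visible in these hunks), each image is decoded from (384 // 16) ** 2 = 576 tokens, which is why image_token_num_per_image defaults to 576.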

View File

@@ -25,6 +25,9 @@ import torchvision.transforms
from einops import rearrange
from janus.janusflow.models.siglip_vit import create_siglip_vit
from janus.utils.cuda_memory_manager import (
    monitor_memory,
)


class CLIPVisionTower(nn.Module):

@@ -104,6 +107,7 @@ class CLIPVisionTower(nn.Module):
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def forward(self, images):
        """

View File

@@ -31,6 +31,7 @@ from transformers import (
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from janus.janusflow.models.clip_encoder import CLIPVisionTower
from janus.janusflow.models.uvit import ShallowUViTEncoder, ShallowUViTDecoder
from janus.utils.cuda_memory_manager import monitor_memory

import torch.nn as nn

@@ -168,6 +169,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
        )
        self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,

View File

@@ -27,6 +27,7 @@ from transformers.processing_utils import ProcessorMixin
from janus.janusflow.models.image_processing_vlm import VLMImageProcessor
from janus.utils.conversation import get_conv_template
from janus.utils.cuda_memory_manager import monitor_memory


class DictOutput(object):

@@ -384,6 +385,7 @@ class VLChatProcessor(ProcessorMixin):
        return prepare

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:

View File

@@ -25,6 +25,9 @@ import torchvision.transforms
from einops import rearrange
from janus.models.siglip_vit import create_siglip_vit
from janus.utils.cuda_memory_manager import (
    monitor_memory,
)


class CLIPVisionTower(nn.Module):

@@ -104,6 +107,7 @@ class CLIPVisionTower(nn.Module):
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def forward(self, images):
        """

View File

@@ -31,6 +31,9 @@ from transformers.configuration_utils import PretrainedConfig
from janus.models.clip_encoder import CLIPVisionTower
from janus.models.projector import MlpProjector
from janus.utils.cuda_memory_manager import (
    monitor_memory,
)


class vision_head(torch.nn.Module):

@@ -218,6 +221,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,

@@ -259,6 +263,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
        return inputs_embeds

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        return self.gen_aligner(self.gen_embed(image_ids))

View File

@@ -27,6 +27,9 @@ from transformers.processing_utils import ProcessorMixin
from janus.models.image_processing_vlm import VLMImageProcessor
from janus.utils.conversation import get_conv_template
from janus.utils.cuda_memory_manager import (
    monitor_memory,
)


class DictOutput(object):

@@ -354,6 +357,7 @@ class VLChatProcessor(ProcessorMixin):
        return prepare

    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:

View File

@@ -0,0 +1,68 @@
from functools import wraps
from typing import Callable, Any

import torch
import warnings


def monitor_memory(
    warning_threshold_gb: float = 2.0,
    track_stats: bool = True,
    cleanup_on_warning: bool = True,
) -> Callable:
    """Memory monitoring decorator for CUDA operations.

    Args:
        warning_threshold_gb: Memory threshold in GB to trigger warnings
        track_stats: Whether to track and print memory statistics
        cleanup_on_warning: Whether to attempt memory cleanup when threshold is reached

    Returns:
        Decorator function that monitors memory usage
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            if not torch.cuda.is_available():
                return func(*args, **kwargs)

            # Get initial memory state
            free_before = torch.cuda.mem_get_info()[0] / 1024**3

            try:
                # Check memory state and cleanup if needed
                if free_before < warning_threshold_gb and cleanup_on_warning:
                    torch.cuda.empty_cache()
                    free_after_cleanup = torch.cuda.mem_get_info()[0] / 1024**3
                    if free_after_cleanup < warning_threshold_gb:
                        warnings.warn(
                            f"Low memory in {func.__name__}: {free_after_cleanup:.2f}GB free"
                        )

                result = func(*args, **kwargs)

                # Track memory statistics if enabled
                if track_stats:
                    peak = torch.cuda.max_memory_allocated() / 1024**3
                    free_after = torch.cuda.mem_get_info()[0] / 1024**3
                    print(
                        f"Memory stats for {func.__name__}:\n"
                        f"Peak: {peak:.2f}GB | Delta: {free_before - free_after:.2f}GB"
                    )
                    torch.cuda.reset_peak_memory_stats()

                return result

            except RuntimeError as e:
                if "out of memory" in str(e):
                    free = torch.cuda.mem_get_info()[0] / 1024**3
                    raise RuntimeError(
                        f"OOM in {func.__name__} with {free:.2f}GB free. "
                        "Consider reducing batch size or image resolution."
                    ) from e
                raise

        return wrapper

    return decorator
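
For reference, a minimal usage sketch of the new decorator outside the patched classes. The workload function below is hypothetical (not part of the commit); only the import path and the decorator's keyword arguments come from the diff above.

import torch

from janus.utils.cuda_memory_manager import monitor_memory


@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
def big_matmul(n: int = 4096) -> torch.Tensor:
    # Hypothetical GPU-heavy workload: two large matrices multiplied on the current device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(n, n, device=device)
    b = torch.randn(n, n, device=device)
    return a @ b


if __name__ == "__main__":
    out = big_matmul()
    # On a CUDA machine with track_stats=True, the decorator prints something like:
    #   Memory stats for big_matmul:
    #   Peak: 0.19GB | Delta: 0.19GB
    # Without CUDA the decorator is a pass-through and simply calls the function.
    print(out.shape)

Note that the cleanup path only runs when free memory is already below warning_threshold_gb before the call: the decorator tries torch.cuda.empty_cache() first and only warns if memory is still low afterwards.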