mirror of
https://github.com/deepseek-ai/Janus.git
synced 2025-04-20 02:28:58 -04:00
Merge edd64d88f5
into 1daa72fa40
This commit is contained in:
commit
f538b1609e
@ -21,6 +21,9 @@ import torch
|
|||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from janus.models import MultiModalityCausalLM, VLChatProcessor
|
from janus.models import MultiModalityCausalLM, VLChatProcessor
|
||||||
|
from janus.utils.cuda_memory_manager import (
|
||||||
|
monitor_memory,
|
||||||
|
)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
@ -51,6 +54,7 @@ sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
|
|||||||
prompt = sft_format + vl_chat_processor.image_start_tag
|
prompt = sft_format + vl_chat_processor.image_start_tag
|
||||||
|
|
||||||
|
|
||||||
|
@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def generate(
|
def generate(
|
||||||
mmgpt: MultiModalityCausalLM,
|
mmgpt: MultiModalityCausalLM,
|
||||||
@ -66,36 +70,46 @@ def generate(
|
|||||||
input_ids = vl_chat_processor.tokenizer.encode(prompt)
|
input_ids = vl_chat_processor.tokenizer.encode(prompt)
|
||||||
input_ids = torch.LongTensor(input_ids)
|
input_ids = torch.LongTensor(input_ids)
|
||||||
|
|
||||||
tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda()
|
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
|
||||||
for i in range(parallel_size*2):
|
for i in range(parallel_size * 2):
|
||||||
tokens[i, :] = input_ids
|
tokens[i, :] = input_ids
|
||||||
if i % 2 != 0:
|
if i % 2 != 0:
|
||||||
tokens[i, 1:-1] = vl_chat_processor.pad_id
|
tokens[i, 1:-1] = vl_chat_processor.pad_id
|
||||||
|
|
||||||
inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)
|
inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)
|
||||||
|
|
||||||
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
|
generated_tokens = torch.zeros(
|
||||||
|
(parallel_size, image_token_num_per_image), dtype=torch.int
|
||||||
|
).cuda()
|
||||||
|
|
||||||
for i in range(image_token_num_per_image):
|
for i in range(image_token_num_per_image):
|
||||||
outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
|
outputs = mmgpt.language_model.model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=True,
|
||||||
|
past_key_values=outputs.past_key_values if i != 0 else None,
|
||||||
|
)
|
||||||
hidden_states = outputs.last_hidden_state
|
hidden_states = outputs.last_hidden_state
|
||||||
|
|
||||||
logits = mmgpt.gen_head(hidden_states[:, -1, :])
|
logits = mmgpt.gen_head(hidden_states[:, -1, :])
|
||||||
logit_cond = logits[0::2, :]
|
logit_cond = logits[0::2, :]
|
||||||
logit_uncond = logits[1::2, :]
|
logit_uncond = logits[1::2, :]
|
||||||
|
|
||||||
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
|
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
|
||||||
probs = torch.softmax(logits / temperature, dim=-1)
|
probs = torch.softmax(logits / temperature, dim=-1)
|
||||||
|
|
||||||
next_token = torch.multinomial(probs, num_samples=1)
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
generated_tokens[:, i] = next_token.squeeze(dim=-1)
|
generated_tokens[:, i] = next_token.squeeze(dim=-1)
|
||||||
|
|
||||||
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
|
next_token = torch.cat(
|
||||||
|
[next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
|
||||||
|
).view(-1)
|
||||||
img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
|
img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
|
||||||
inputs_embeds = img_embeds.unsqueeze(dim=1)
|
inputs_embeds = img_embeds.unsqueeze(dim=1)
|
||||||
|
|
||||||
|
dec = mmgpt.gen_vision_model.decode_code(
|
||||||
dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
|
generated_tokens.to(dtype=torch.int),
|
||||||
|
shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
|
||||||
|
)
|
||||||
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
|
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
|
||||||
|
|
||||||
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
|
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
|
||||||
@ -103,9 +117,9 @@ def generate(
|
|||||||
visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
|
visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
|
||||||
visual_img[:, :, :] = dec
|
visual_img[:, :, :] = dec
|
||||||
|
|
||||||
os.makedirs('generated_samples', exist_ok=True)
|
os.makedirs("generated_samples", exist_ok=True)
|
||||||
for i in range(parallel_size):
|
for i in range(parallel_size):
|
||||||
save_path = os.path.join('generated_samples', "img_{}.jpg".format(i))
|
save_path = os.path.join("generated_samples", "img_{}.jpg".format(i))
|
||||||
PIL.Image.fromarray(visual_img[i]).save(save_path)
|
PIL.Image.fromarray(visual_img[i]).save(save_path)
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,6 +25,9 @@ import torchvision.transforms
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from janus.janusflow.models.siglip_vit import create_siglip_vit
|
from janus.janusflow.models.siglip_vit import create_siglip_vit
|
||||||
|
from janus.utils.cuda_memory_manager import (
|
||||||
|
monitor_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionTower(nn.Module):
|
class CLIPVisionTower(nn.Module):
|
||||||
@ -104,6 +107,7 @@ class CLIPVisionTower(nn.Module):
|
|||||||
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
|
@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
|
||||||
def forward(self, images):
|
def forward(self, images):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ from transformers import (
|
|||||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
from janus.janusflow.models.clip_encoder import CLIPVisionTower
|
from janus.janusflow.models.clip_encoder import CLIPVisionTower
|
||||||
from janus.janusflow.models.uvit import ShallowUViTEncoder, ShallowUViTDecoder
|
from janus.janusflow.models.uvit import ShallowUViTEncoder, ShallowUViTDecoder
|
||||||
|
from janus.utils.cuda_memory_manager import monitor_memory
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
@ -168,6 +169,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)
|
self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)
|
||||||
|
|
||||||
|
@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
|
||||||
def prepare_inputs_embeds(
|
def prepare_inputs_embeds(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
|
@ -27,6 +27,7 @@ from transformers.processing_utils import ProcessorMixin
|
|||||||
|
|
||||||
from janus.janusflow.models.image_processing_vlm import VLMImageProcessor
|
from janus.janusflow.models.image_processing_vlm import VLMImageProcessor
|
||||||
from janus.utils.conversation import get_conv_template
|
from janus.utils.conversation import get_conv_template
|
||||||
|
from janus.utils.cuda_memory_manager import monitor_memory
|
||||||
|
|
||||||
|
|
||||||
class DictOutput(object):
|
class DictOutput(object):
|
||||||
@ -384,6 +385,7 @@ class VLChatProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
return prepare
|
return prepare
|
||||||
|
|
||||||
|
@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
|
||||||
def batchify(
|
def batchify(
|
||||||
self, prepare_list: List[VLChatProcessorOutput]
|
self, prepare_list: List[VLChatProcessorOutput]
|
||||||
) -> BatchedVLChatProcessorOutput:
|
) -> BatchedVLChatProcessorOutput:
|
||||||
|
@ -25,6 +25,9 @@ import torchvision.transforms
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from janus.models.siglip_vit import create_siglip_vit
|
from janus.models.siglip_vit import create_siglip_vit
|
||||||
|
from janus.utils.cuda_memory_manager import (
|
||||||
|
monitor_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionTower(nn.Module):
|
class CLIPVisionTower(nn.Module):
|
||||||
@ -104,6 +107,7 @@ class CLIPVisionTower(nn.Module):
|
|||||||
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
|
@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
|
||||||
def forward(self, images):
|
def forward(self, images):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -31,6 +31,9 @@ from transformers.configuration_utils import PretrainedConfig
|
|||||||
|
|
||||||
from janus.models.clip_encoder import CLIPVisionTower
|
from janus.models.clip_encoder import CLIPVisionTower
|
||||||
from janus.models.projector import MlpProjector
|
from janus.models.projector import MlpProjector
|
||||||
|
from janus.utils.cuda_memory_manager import (
|
||||||
|
monitor_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class vision_head(torch.nn.Module):
|
class vision_head(torch.nn.Module):
|
||||||
@ -218,6 +221,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|||||||
language_config = config.language_config
|
language_config = config.language_config
|
||||||
self.language_model = LlamaForCausalLM(language_config)
|
self.language_model = LlamaForCausalLM(language_config)
|
||||||
|
|
||||||
|
@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
|
||||||
def prepare_inputs_embeds(
|
def prepare_inputs_embeds(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
@ -259,6 +263,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|||||||
|
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
|
@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
|
||||||
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
||||||
return self.gen_aligner(self.gen_embed(image_ids))
|
return self.gen_aligner(self.gen_embed(image_ids))
|
||||||
|
|
||||||
|
@ -27,6 +27,9 @@ from transformers.processing_utils import ProcessorMixin
|
|||||||
|
|
||||||
from janus.models.image_processing_vlm import VLMImageProcessor
|
from janus.models.image_processing_vlm import VLMImageProcessor
|
||||||
from janus.utils.conversation import get_conv_template
|
from janus.utils.conversation import get_conv_template
|
||||||
|
from janus.utils.cuda_memory_manager import (
|
||||||
|
monitor_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DictOutput(object):
|
class DictOutput(object):
|
||||||
@ -354,6 +357,7 @@ class VLChatProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
return prepare
|
return prepare
|
||||||
|
|
||||||
|
@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
|
||||||
def batchify(
|
def batchify(
|
||||||
self, prepare_list: List[VLChatProcessorOutput]
|
self, prepare_list: List[VLChatProcessorOutput]
|
||||||
) -> BatchedVLChatProcessorOutput:
|
) -> BatchedVLChatProcessorOutput:
|
||||||
|
68
janus/utils/cuda_memory_manager.py
Normal file
68
janus/utils/cuda_memory_manager.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
from functools import wraps
|
||||||
|
from typing import Callable, Any
|
||||||
|
import torch
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
def monitor_memory(
|
||||||
|
warning_threshold_gb: float = 2.0,
|
||||||
|
track_stats: bool = True,
|
||||||
|
cleanup_on_warning: bool = True,
|
||||||
|
) -> Callable:
|
||||||
|
"""Memory monitoring decorator for CUDA operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
warning_threshold_gb: Memory threshold in GB to trigger warnings
|
||||||
|
track_stats: Whether to track and print memory statistics
|
||||||
|
cleanup_on_warning: Whether to attempt memory cleanup when threshold is reached
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorator function that monitors memory usage
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs) -> Any:
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Get initial memory state
|
||||||
|
free_before = torch.cuda.mem_get_info()[0] / 1024**3
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check memory state and cleanup if needed
|
||||||
|
if free_before < warning_threshold_gb and cleanup_on_warning:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
free_after_cleanup = torch.cuda.mem_get_info()[0] / 1024**3
|
||||||
|
|
||||||
|
if free_after_cleanup < warning_threshold_gb:
|
||||||
|
warnings.warn(
|
||||||
|
f"Low memory in {func.__name__}: {free_after_cleanup:.2f}GB free"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Track memory statistics if enabled
|
||||||
|
if track_stats:
|
||||||
|
peak = torch.cuda.max_memory_allocated() / 1024**3
|
||||||
|
free_after = torch.cuda.mem_get_info()[0] / 1024**3
|
||||||
|
print(
|
||||||
|
f"Memory stats for {func.__name__}:\n"
|
||||||
|
f"Peak: {peak:.2f}GB | Delta: {free_before - free_after:.2f}GB"
|
||||||
|
)
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "out of memory" in str(e):
|
||||||
|
free = torch.cuda.mem_get_info()[0] / 1024**3
|
||||||
|
raise RuntimeError(
|
||||||
|
f"OOM in {func.__name__} with {free:.2f}GB free. "
|
||||||
|
"Consider reducing batch size or image resolution."
|
||||||
|
) from e
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
Loading…
Reference in New Issue
Block a user