mirror of
https://github.com/deepseek-ai/Janus.git
synced 2025-02-22 13:48:57 -05:00
fix: streamline past_key_values handling in generate function
This commit is contained in:
parent
1daa72fa40
commit
25466b5f40
@ -106,22 +106,17 @@ def generate(
|
||||
# input to the llm
|
||||
# we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
|
||||
if step == 0:
|
||||
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=None)
|
||||
past_key_values = []
|
||||
for kv_cache in past_key_values:
|
||||
k, v = kv_cache[0], kv_cache[1]
|
||||
past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
|
||||
past_key_values = tuple(past_key_values)
|
||||
past_key_values = None # Ensure it starts as None
|
||||
else:
|
||||
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values)
|
||||
past_key_values = tuple(past_key_values) if past_key_values else None # Convert only if it's valid
|
||||
|
||||
outputs = vl_gpt.language_model.model(
|
||||
inputs_embeds=llm_emb,
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values # Now correctly assigned
|
||||
)
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
# transform hidden_states back to v
|
||||
hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
|
||||
hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
|
||||
|
Loading…
Reference in New Issue
Block a user