fix: streamline past_key_values handling in generate function

This commit is contained in:
scifisatan 2025-02-04 20:04:48 +05:45
parent 1daa72fa40
commit 25466b5f40

View File

@ -106,22 +106,17 @@ def generate(
# input to the llm
# we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
if step == 0:
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
use_cache=True,
attention_mask=attention_mask,
past_key_values=None)
past_key_values = []
for kv_cache in past_key_values:
k, v = kv_cache[0], kv_cache[1]
past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
past_key_values = tuple(past_key_values)
past_key_values = None # Ensure it starts as None
else:
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
past_key_values = tuple(past_key_values) if past_key_values else None # Convert only if it's valid
outputs = vl_gpt.language_model.model(
inputs_embeds=llm_emb,
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values # Now correctly assigned
)
hidden_states = outputs.last_hidden_state
# transform hidden_states back to v
hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)