From 25466b5f408655d4d99e08e8b42b087e60af2f6f Mon Sep 17 00:00:00 2001
From: scifisatan <brawlershere@gmail.com>
Date: Tue, 4 Feb 2025 20:04:48 +0545
Subject: [PATCH] fix: streamline past_key_values handling in generate function

---
 demo/app_janusflow.py | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/demo/app_janusflow.py b/demo/app_janusflow.py
index 4777196..736359e 100644
--- a/demo/app_janusflow.py
+++ b/demo/app_janusflow.py
@@ -106,22 +106,17 @@ def generate(
         # input to the llm
         # we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
         if step == 0:
-            outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb, 
-                                             use_cache=True, 
-                                             attention_mask=attention_mask,
-                                             past_key_values=None)
-            past_key_values = []
-            for kv_cache in past_key_values:
-                k, v = kv_cache[0], kv_cache[1]
-                past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
-            past_key_values = tuple(past_key_values)
+            past_key_values = None  # Ensure it starts as None
         else:
-            outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb, 
-                                             use_cache=True, 
-                                             attention_mask=attention_mask,
-                                             past_key_values=past_key_values)
+            past_key_values = tuple(past_key_values) if past_key_values else None  # Convert only if it's valid
+
+        outputs = vl_gpt.language_model.model(
+            inputs_embeds=llm_emb, 
+            use_cache=True, 
+            attention_mask=attention_mask,
+            past_key_values=past_key_values  # Now correctly assigned
+        )
         hidden_states = outputs.last_hidden_state
-        
         # transform hidden_states back to v
         hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
         hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)