diff --git a/demo/app_janusflow.py b/demo/app_janusflow.py
index 4777196..736359e 100644
--- a/demo/app_janusflow.py
+++ b/demo/app_janusflow.py
@@ -106,22 +106,17 @@ def generate(
         # input to the llm
         # we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
         if step == 0:
-            outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
-                                                  use_cache=True,
-                                                  attention_mask=attention_mask,
-                                                  past_key_values=None)
-            past_key_values = []
-            for kv_cache in past_key_values:
-                k, v = kv_cache[0], kv_cache[1]
-                past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
-            past_key_values = tuple(past_key_values)
+            past_key_values = None  # Ensure it starts as None
         else:
-            outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
-                                                  use_cache=True,
-                                                  attention_mask=attention_mask,
-                                                  past_key_values=past_key_values)
+            past_key_values = tuple(past_key_values) if past_key_values else None  # Convert only if it's valid
+
+        outputs = vl_gpt.language_model.model(
+            inputs_embeds=llm_emb,
+            use_cache=True,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values  # Now correctly assigned
+        )
         hidden_states = outputs.last_hidden_state
-
         # transform hidden_states back to v
         hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
         hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
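For context, here is a minimal, self-contained sketch of the control flow the hunk converges on: decide the cache argument first, then make a single model call instead of duplicating it in both branches. `toy_forward` is a hypothetical stand-in for `vl_gpt.language_model.model` (it only mimics the `last_hidden_state` / `past_key_values` interface the hunk relies on), and the final cache-capture assignment is an assumption about code outside the hunk shown above, not part of this patch.

```python
from typing import Optional, Tuple

def toy_forward(inputs_embeds, use_cache, attention_mask, past_key_values):
    # Hypothetical stand-in for vl_gpt.language_model.model: returns dummy
    # hidden states and a dummy KV cache with the interface the hunk uses.
    hidden = [2.0 * x for x in inputs_embeds]
    cache = ((tuple(inputs_embeds), tuple(hidden)),) if use_cache else None
    return {"last_hidden_state": hidden, "past_key_values": cache}

past_key_values: Optional[Tuple] = None
for step in range(3):
    if step == 0:
        past_key_values = None  # first step: no cache exists yet
    else:
        # Convert only if a valid cache exists; otherwise keep passing None.
        past_key_values = tuple(past_key_values) if past_key_values else None

    # Single call site, as in the patched code.
    outputs = toy_forward(
        inputs_embeds=[float(step)],
        use_cache=True,
        attention_mask=None,
        past_key_values=past_key_values,
    )

    # Assumption: somewhere outside the hunk the returned cache is stored for
    # the next iteration; without an assignment like this, every step would
    # pass None again and the cache would never be reused.
    past_key_values = outputs["past_key_values"]
```

The design point of the patch is that both branches now only compute the `past_key_values` argument, so the model call itself exists on exactly one code path.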