diff --git a/demo/app_janusflow.py b/demo/app_janusflow.py index 2caf9a9..1836b99 100644 --- a/demo/app_janusflow.py +++ b/demo/app_janusflow.py @@ -117,8 +117,7 @@ def generate( if step == 0: outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb, use_cache=True, - attention_mask=attention_mask, - past_key_values=None) + attention_mask=attention_mask) past_key_values = [] for kv_cache in past_key_values: k, v = kv_cache[0], kv_cache[1] @@ -127,8 +126,7 @@ def generate( else: outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb, use_cache=True, - attention_mask=attention_mask, - past_key_values=past_key_values) + attention_mask=attention_mask) hidden_states = outputs.last_hidden_state # transform hidden_states back to v