fix: image generate in demo/app_janusflow.py

This commit is contained in:
yangmeng 2025-01-28 23:43:18 +08:00
parent 17bc41245a
commit 877c778c0e

View File

@ -117,8 +117,7 @@ def generate(
if step == 0: if step == 0:
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb, outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
use_cache=True, use_cache=True,
attention_mask=attention_mask, attention_mask=attention_mask)
past_key_values=None)
past_key_values = [] past_key_values = []
for kv_cache in past_key_values: for kv_cache in past_key_values:
k, v = kv_cache[0], kv_cache[1] k, v = kv_cache[0], kv_cache[1]
@ -127,8 +126,7 @@ def generate(
else: else:
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb, outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
use_cache=True, use_cache=True,
attention_mask=attention_mask, attention_mask=attention_mask)
past_key_values=past_key_values)
hidden_states = outputs.last_hidden_state hidden_states = outputs.last_hidden_state
# transform hidden_states back to v # transform hidden_states back to v