mirror of
https://github.com/deepseek-ai/DeepSeek-VL2.git
synced 2025-04-19 18:19:03 -04:00
Merge cef3ab5e28
into ef9f91e2b6
This commit is contained in:
commit
93bb5022bb
@ -146,6 +146,7 @@ inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
|
|||||||
|
|
||||||
# run the model to get the response
|
# run the model to get the response
|
||||||
outputs = vl_gpt.language.generate(
|
outputs = vl_gpt.language.generate(
|
||||||
|
input_ids = prepare_inputs["input_ids"].to(vl_gpt.device),
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
attention_mask=prepare_inputs.attention_mask,
|
attention_mask=prepare_inputs.attention_mask,
|
||||||
pad_token_id=tokenizer.eos_token_id,
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
@ -1742,13 +1742,16 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
|
|||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
return (loss,) + output if loss is not None else output
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
output = CausalLMOutputWithPast(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
past_key_values=outputs.past_key_values,
|
past_key_values=outputs.past_key_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
output['logits'] = output['logits'].to(device)
|
||||||
|
return output
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user