Update app.py

This commit is contained in:
Phani-kp 2025-02-21 22:31:12 -06:00 committed by GitHub
parent ae27c2cd96
commit dd2f5c83d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,13 +21,11 @@ This Space demonstrates model [DeepSeek-Coder](https://huggingface.co/deepseek-a
if not torch.cuda.is_available(): if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>" DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available(): if torch.cuda.is_available():
model_id = "deepseek-ai/deepseek-coder-6.7b-instruct" model_id = "deepseek-ai/deepseek-coder-6.7b-instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto") model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False tokenizer.use_default_system_prompt = False
@spaces.GPU @spaces.GPU
@ -56,11 +54,12 @@ def generate(
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict( generate_kwargs = dict(
{"input_ids": input_ids}, input_ids=input_ids,
streamer=streamer, streamer=streamer,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
do_sample=False, temperature=temperature,
num_beams=1, top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
eos_token_id=tokenizer.eos_token_id eos_token_id=tokenizer.eos_token_id
) )
@ -70,7 +69,7 @@ def generate(
outputs = [] outputs = []
for text in streamer: for text in streamer:
outputs.append(text) outputs.append(text)
yield "".join(outputs).replace("<|EOT|>","") yield "".join(outputs).replace("<|EOT|>", "")
chat_interface = gr.ChatInterface( chat_interface = gr.ChatInterface(
@ -84,13 +83,13 @@ chat_interface = gr.ChatInterface(
step=1, step=1,
value=DEFAULT_MAX_NEW_TOKENS, value=DEFAULT_MAX_NEW_TOKENS,
), ),
# gr.Slider( gr.Slider(
# label="Temperature", label="Temperature",
# minimum=0, minimum=0,
# maximum=4.0, maximum=4.0,
# step=0.1, step=0.1,
# value=0, value=0.6,
# ), ),
gr.Slider( gr.Slider(
label="Top-p (nucleus sampling)", label="Top-p (nucleus sampling)",
minimum=0.05, minimum=0.05,