Update README.md

This commit is contained in:
ZHU QIHAO 2023-10-27 14:16:44 +08:00 committed by GitHub
parent 5d772a085c
commit 108504e6a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -52,13 +52,13 @@ import torch
tokenizer = AutoTokenizer.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True)
device = 2 if torch.cuda.is_available() else -1
model = AutoModelForCausalLM.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True).to(device)
input_text = "#write a quick sort algorithm"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_length=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
This code will output the following result
```python
#write a quick sort algorithm
def quick_sort(arr):
    if len(arr) <= 1:
@@ -77,13 +77,29 @@ def quick_sort(arr):
#### Code Insertion
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
tokenizer = AutoTokenizer.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True)
device = 2 if torch.cuda.is_available() else -1
model = AutoModelForCausalLM.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True).to(device)
inputtext = """<fim_prefix>def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[0]
    left = []
    right = []
<fim_middle>
        if arr[i] < pivot:
            left.append(arr[i])
        else:
            right.append(arr[i])
    return quick_sort(left) + [pivot] + quick_sort(right)<fim_suffix>"""
inputs = tokenizer(inputtext, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_length=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True)[len(inputtext):])
```
This code will output the following result
```python
    for i in range(1, len(arr)):
```
#### Repository Level Code Completion
```python