Update README.md

This commit is contained in:
ZHU QIHAO 2023-10-27 15:19:27 +08:00 committed by GitHub
parent 108504e6a0
commit c6e311e9d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -57,9 +57,8 @@ inputs = tokenizer(input_text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_length=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
This code will output the following result
```python
This code will output the following result:
```
def quick_sort(arr):
if len(arr) <= 1:
return arr
@ -81,7 +80,7 @@ import torch
tokenizer = AutoTokenizer.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True).to(device)
inputtext = """<fim_prefix>def quick_sort(arr):
input_text = """<fim_prefix>def quick_sort(arr):
if len(arr) <= 1:
return arr
pivot = arr[0]
@ -93,18 +92,22 @@ inputtext = """<fim_prefix>def quick_sort(arr):
else:
right.append(arr[i])
return quick_sort(left) + [pivot] + quick_sort(right)<fim_suffix>"""
inputs = tokenizer(inputtext, return_tensors="pt").to(device)
inputs = tokenizer(input_text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_length=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True)[len(inputtext):])
print(tokenizer.decode(outputs[0], skip_special_tokens=True)[len(input_text):])
```
This code will output the following result:
```
This code will output the following result
```python
for i in range(1, len(arr)):
```
#### Repository Level Code Completion
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("deepseek/deepseek-coder-7b")
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True).to(device)
input_text = """#utils.py
import torch
from sklearn import datasets
@ -179,8 +182,8 @@ from model import IrisClassifier as Classifier
def main():
# Model training and evaluation
"""
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_length=128)
inputs = tokenizer(input_text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=140)
print(tokenizer.decode(outputs[0]))
```