diff --git a/README.md b/README.md index 43d8b75..a578abe 100644 --- a/README.md +++ b/README.md @@ -52,13 +52,13 @@ import torch tokenizer = AutoTokenizer.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True) device = 2 if torch.cuda.is_available() else -1 model = AutoModelForCausalLM.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True).to(device) -inputs = tokenizer("#write a quick sort algorithm", return_tensors="pt").to(device) +input_text = "#write a quick sort algorithm" +inputs = tokenizer(input_text, return_tensors="pt").to(device) outputs = model.generate(**inputs, max_length=128) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` -This code will output +This code will output the following result: ```python -#write a quick sort algorithm def quick_sort(arr): if len(arr) <= 1: @@ -77,13 +77,29 @@ def quick_sort(arr): #### Code Insertion ```python from transformers import AutoTokenizer, AutoModelForCausalLM -tokenizer = AutoTokenizer.from_pretrained("deepseek/deepseek-coder-7b") -device = 0 if torch.cuda.is_available() else -1 -model = AutoModelForCausalLM.from_pretrained("deepseek/deepseek-coder-7b").to(device) -input_text = "def print_hello_world():\n \n print('Hello world!')" -inputs = tokenizer.encode(input_text, return_tensors="pt").to(device) +import torch +tokenizer = AutoTokenizer.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True) +device = 2 if torch.cuda.is_available() else -1 +model = AutoModelForCausalLM.from_pretrained("deepseek/deepseek-coder-7b-base", trust_remote_code=True).to(device) +inputtext = """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] + + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""" +inputs = tokenizer(inputtext, return_tensors="pt").to(device) outputs = model.generate(**inputs, max_length=128) 
-print(tokenizer.decode(outputs[0])) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)[len(inputtext):]) +``` +This code will output the following result: +```python + for i in range(1, len(arr)): ``` #### Repository Level Code Completion ```python