mirror of
https://github.com/deepseek-ai/DeepSeek-Coder.git
synced 2025-07-10 19:28:55 -04:00
support for Scala
This commit is contained in:
parent
7240b22b56
commit
bc50a43521
@ -1,4 +1,4 @@
|
|||||||
MODEL_NAME_OR_PATH="/scratch/shared_dir/xinyu/deepseek-1.3b"
|
MODEL_NAME_OR_PATH="/scratch/shared_dir/xinyu/deepseek-1.3b"
|
||||||
DATASET_ROOT="data/"
|
DATASET_ROOT="data/"
|
||||||
LANGUAGE="rust"
|
LANGUAGE="scala"
|
||||||
CUDA_VISIBLE_DEVICES=1,2,3 python -m accelerate.commands.launch --config_file test_config.yaml eval_pal.py --logdir ${MODEL_NAME_OR_PATH} --language ${LANGUAGE} --dataroot ${DATASET_ROOT}
|
CUDA_VISIBLE_DEVICES=1,2,3 python -m accelerate.commands.launch --config_file test_config.yaml eval_pal.py --logdir ${MODEL_NAME_OR_PATH} --language ${LANGUAGE} --dataroot ${DATASET_ROOT}
|
||||||
|
@ -174,6 +174,8 @@ def process_humaneval_test(sample, problems, example_test=False, is_mbpp=False,
|
|||||||
if code[:5] != "<?php":
|
if code[:5] != "<?php":
|
||||||
code = "<?php\n" + code
|
code = "<?php\n" + code
|
||||||
test_string = code + "\n" + test + "?>"
|
test_string = code + "\n" + test + "?>"
|
||||||
|
elif language == "scala":
|
||||||
|
test_string = code + "\n" + test
|
||||||
return test_string
|
return test_string
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,6 +34,14 @@ def check_correctness(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def unsafe_execute(tmp_dir):
|
def unsafe_execute(tmp_dir):
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import random
|
||||||
|
import subprocess
|
||||||
|
from contextlib import redirect_stderr, redirect_stdout, suppress
|
||||||
|
from .execution import time_limit, swallow_io, create_tempdir, reliability_guard, TimeoutException
|
||||||
|
|
||||||
random_id = random.randint(1, 100000)
|
random_id = random.randint(1, 100000)
|
||||||
if "python" in language_type.lower():
|
if "python" in language_type.lower():
|
||||||
with create_tempdir():
|
with create_tempdir():
|
||||||
@ -546,6 +554,48 @@ def check_correctness(
|
|||||||
os.chdir(origin_path)
|
os.chdir(origin_path)
|
||||||
shutil.rmtree(tmp_dir)
|
shutil.rmtree(tmp_dir)
|
||||||
|
|
||||||
|
elif "scala" in language_type.lower():
|
||||||
|
tmp_dir_scala = os.path.join(tempfile.gettempdir(), f"scala-eval-{random.randint(1, 100000)}")
|
||||||
|
os.makedirs(tmp_dir_scala, exist_ok=True)
|
||||||
|
|
||||||
|
file_path = os.path.join(tmp_dir_scala, "Problem.scala")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(sample["test_code"])
|
||||||
|
|
||||||
|
compile_result = subprocess.run(
|
||||||
|
["scalac", file_path],
|
||||||
|
cwd=tmp_dir_scala,
|
||||||
|
timeout=30.0,
|
||||||
|
capture_output=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if compile_result.returncode != 0:
|
||||||
|
error_output = compile_result.stderr.decode("utf-8", "ignore")
|
||||||
|
result.append(f"failed: compilation error: {error_output}")
|
||||||
|
else:
|
||||||
|
run_result = subprocess.run(
|
||||||
|
["scala", "-cp", ".", "Problem"],
|
||||||
|
cwd=tmp_dir_scala,
|
||||||
|
timeout=timeout,
|
||||||
|
capture_output=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if run_result.returncode == 0:
|
||||||
|
result.append("passed")
|
||||||
|
else:
|
||||||
|
error_output = run_result.stderr.decode("utf-8", "ignore")
|
||||||
|
result.append(f"failed: {error_output}")
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
result.append("timed out")
|
||||||
|
except Exception as e:
|
||||||
|
result.append(f"failed: {e}")
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp_dir_scala):
|
||||||
|
shutil.rmtree(tmp_dir_scala)
|
||||||
|
|
||||||
manager = multiprocessing.Manager()
|
manager = multiprocessing.Manager()
|
||||||
result = manager.list()
|
result = manager.list()
|
||||||
|
|
||||||
|
Binary file not shown.
@ -35,6 +35,10 @@ languge_settings = {
|
|||||||
'sh': {
|
'sh': {
|
||||||
'full_name': "Bash",
|
'full_name': "Bash",
|
||||||
'indent': 0
|
'indent': 0
|
||||||
|
},
|
||||||
|
'scala': {
|
||||||
|
'full_name': "Scala",
|
||||||
|
'indent': 4,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,6 +126,9 @@ def cleanup_code(
|
|||||||
code = _truncate_code_at_stopwords(code, stop_words)
|
code = _truncate_code_at_stopwords(code, stop_words)
|
||||||
elif language_type.lower() == "ts":
|
elif language_type.lower() == "ts":
|
||||||
code = _truncate_code_at_stopwords(code, stop_words + ["\nexport", "\nimport", "\nexport default", "\nimport default", "\nconsole.log"])
|
code = _truncate_code_at_stopwords(code, stop_words + ["\nexport", "\nimport", "\nexport default", "\nimport default", "\nconsole.log"])
|
||||||
|
elif language_type.lower() == "scala":
|
||||||
|
stop_words = stop_words + ["\nobject ", "\nclass ", "\n/**"]
|
||||||
|
code = _truncate_code_at_stopwords(code, stop_words)
|
||||||
else:
|
else:
|
||||||
code = _truncate_code_at_stopwords(code, stop_words)
|
code = _truncate_code_at_stopwords(code, stop_words)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user