support for Scala

This commit is contained in:
sherwin6180 2025-06-26 00:27:13 -04:00
parent 7240b22b56
commit bc50a43521
5 changed files with 60 additions and 1 deletions

View File

@ -1,4 +1,4 @@
MODEL_NAME_OR_PATH="/scratch/shared_dir/xinyu/deepseek-1.3b"
DATASET_ROOT="data/"
LANGUAGE="rust"
LANGUAGE="scala"
CUDA_VISIBLE_DEVICES=1,2,3 python -m accelerate.commands.launch --config_file test_config.yaml eval_pal.py --logdir ${MODEL_NAME_OR_PATH} --language ${LANGUAGE} --dataroot ${DATASET_ROOT}

View File

@ -174,6 +174,8 @@ def process_humaneval_test(sample, problems, example_test=False, is_mbpp=False,
if code[:5] != "<?php":
code = "<?php\n" + code
test_string = code + "\n" + test + "?>"
elif language == "scala":
test_string = code + "\n" + test
return test_string

View File

@ -34,6 +34,14 @@ def check_correctness(
"""
def unsafe_execute(tmp_dir):
import os
import shutil
import tempfile
import random
import subprocess
from contextlib import redirect_stderr, redirect_stdout, suppress
from .execution import time_limit, swallow_io, create_tempdir, reliability_guard, TimeoutException
random_id = random.randint(1, 100000)
if "python" in language_type.lower():
with create_tempdir():
@ -546,6 +554,48 @@ def check_correctness(
os.chdir(origin_path)
shutil.rmtree(tmp_dir)
elif "scala" in language_type.lower():
tmp_dir_scala = os.path.join(tempfile.gettempdir(), f"scala-eval-{random.randint(1, 100000)}")
os.makedirs(tmp_dir_scala, exist_ok=True)
file_path = os.path.join(tmp_dir_scala, "Problem.scala")
try:
with open(file_path, "w", encoding="utf-8") as f:
f.write(sample["test_code"])
compile_result = subprocess.run(
["scalac", file_path],
cwd=tmp_dir_scala,
timeout=30.0,
capture_output=True
)
if compile_result.returncode != 0:
error_output = compile_result.stderr.decode("utf-8", "ignore")
result.append(f"failed: compilation error: {error_output}")
else:
run_result = subprocess.run(
["scala", "-cp", ".", "Problem"],
cwd=tmp_dir_scala,
timeout=timeout,
capture_output=True
)
if run_result.returncode == 0:
result.append("passed")
else:
error_output = run_result.stderr.decode("utf-8", "ignore")
result.append(f"failed: {error_output}")
except subprocess.TimeoutExpired:
result.append("timed out")
except Exception as e:
result.append(f"failed: {e}")
finally:
if os.path.exists(tmp_dir_scala):
shutil.rmtree(tmp_dir_scala)
manager = multiprocessing.Manager()
result = manager.list()

View File

@ -35,6 +35,10 @@ languge_settings = {
'sh': {
'full_name': "Bash",
'indent': 0
},
'scala': {
'full_name': "Scala",
'indent': 4,
}
}
@ -122,6 +126,9 @@ def cleanup_code(
code = _truncate_code_at_stopwords(code, stop_words)
elif language_type.lower() == "ts":
code = _truncate_code_at_stopwords(code, stop_words + ["\nexport", "\nimport", "\nexport default", "\nimport default", "\nconsole.log"])
elif language_type.lower() == "scala":
stop_words = stop_words + ["\nobject ", "\nclass ", "\n/**"]
code = _truncate_code_at_stopwords(code, stop_words)
else:
code = _truncate_code_at_stopwords(code, stop_words)