ADDED: humaneval-x rust integration
This commit is contained in:
parent
394df4187d
commit
f66205b5f6
|
@ -0,0 +1,239 @@
|
|||
import gzip
import json
import os
import sys
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import *

import fire
import numpy as np
import regex
from tqdm.auto import tqdm

from codegeex.benchmark.execution import check_correctness
from codegeex.benchmark.metric import estimate_pass_at_k
from codegeex.benchmark.utils import read_dataset, IMPORT_HELPER
|
||||
|
||||
# Maps the lowercase language key (parsed from a task_id prefix such as
# "CPP/0") back to the canonical capitalized prefix, used when rebuilding
# task ids in translation mode.
LANGUAGE_NAME = {
    "cpp"   : "CPP",
    "go"    : "Go",
    "java"  : "Java",
    "js"    : "JavaScript",
    "python": "Python",
}
|
||||
|
||||
|
||||
def process_humaneval_test(sample, problems, example_test=False):
    """Assemble the complete, runnable test program for one generated sample.

    :param sample: dict with "task_id", "prompt" and "generation".
    :param problems: mapping task_id -> problem record containing "test" and,
        depending on language, "example_test", "import", "test_setup" or
        "declaration".
    :param example_test: prefer the problem's non-empty "example_test" over
        the full "test" suite when available.
    :return: the assembled source string, or None when the task's language is
        unsupported (both call sites skip samples whose test_code is None).
    """
    task_id = sample["task_id"]
    # The language is encoded as the task_id prefix, e.g. "Rust/0" -> "rust".
    language = task_id.split("/")[0].lower()

    prompt = sample["prompt"]
    if example_test and "example_test" in problems[task_id] and problems[task_id]["example_test"] != "":
        test = problems[task_id]["example_test"]
    else:
        test = problems[task_id]["test"]
    code = sample["generation"]

    # Pre-process for different languages
    if language == "python":
        # Keep only the indented completion body: the first non-indented,
        # non-empty line marks the end of the generated function.
        code_ = []
        for line in code.split("\n"):
            if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
                break
            code_.append(line)
        code = "\n".join(code_)
        test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
        test_string = test_setup + prompt + code + "\n" + test + "\n"
    elif language == "cpp":
        # Prepend only the helper includes the prompt does not already contain.
        test_set_up = ""
        for s in IMPORT_HELPER["cpp"]:
            if s not in prompt:
                test_set_up += s + "\n"
        test_string = test_set_up + "\n" + prompt + code + "\n" + test
    elif language in ("java", "js", "javascript"):
        # Java and JavaScript need no extra setup (branches were identical).
        test_string = prompt + code + "\n" + test
    elif language == "go":
        import_string = problems[task_id]["import"]
        prompt = prompt.replace(import_string, "")
        # NOTE(review): this re-derives `test` without the != "" guard applied
        # above, so an empty "example_test" is used verbatim here — confirm
        # that is intended for Go.
        if example_test and "example_test" in problems[task_id]:
            test = problems[task_id]["example_test"]
        else:
            test = problems[task_id]["test"]
        test_setup = problems[task_id]["test_setup"]
        # Import only the helper packages the generation actually references
        # (pkg name followed by ".") and that test_setup does not already pull in.
        other_pkgs = []
        for pkg in IMPORT_HELPER["go"]:
            if pkg not in test_setup:
                p = pkg.split("/")[-1]
                if p + "." in code:
                    other_pkgs.append(f"\"{pkg}\"")
        if other_pkgs:
            import_other_pkgs = "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")"
            test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
        else:
            test_string = test_setup + "\n" + prompt + code + "\n" + test
    elif language == "rust":
        # Rust binaries need an explicit (empty) fn main plus the problem's
        # type/struct declarations before the prompt.
        main = "\nfn main(){ \n } \n"
        declaration = problems[task_id]["declaration"]
        test_string = main + declaration + prompt + code + test
    else:
        # Unsupported language: previously fell through and raised NameError
        # (test_string never assigned); return None so callers skip the sample.
        return None

    return test_string
|
||||
|
||||
|
||||
def stream_jsonl_all(filename: str) -> Iterable[Dict]:
    """Read every JSON record from a (possibly gzip-compressed) .jsonl file.

    :param filename: path to the input file; a ".gz" suffix selects
        transparent gzip decompression.
    :return: list of parsed records, blank lines skipped.
    """
    results = []
    if filename.endswith(".gz"):
        fp = gzip.open(open(filename, "rb"), "rt")
    else:
        fp = open(filename, "r")
    # `with` guarantees the handle is closed even if json.loads raises
    # (the original leaked the file descriptor on a parse error).
    with fp:
        for line in fp:
            # Skip whitespace-only lines instead of failing on them.
            if line.strip():
                results.append(json.loads(line))

    return results
|
||||
|
||||
|
||||
def evaluate_functional_correctness(
        input_file: str = None,
        tmp_dir: str = "./",
        n_workers: int = 32,
        timeout: float = 500.0,
        problem_file: str = "../data/humaneval_python.jsonl.gz",
        out_dir: str = None,
        k: List[int] = [1, 10, 100],
        test_groundtruth: bool = False,
        example_test: bool = False,
):
    """Run the HumanEval-X test suites over generated samples and report pass@k.

    :param input_file: .jsonl(.gz) file of generated samples ("task_id",
        "generation", optionally "completion_id").
    :param tmp_dir: root directory for per-language evaluation scratch dirs.
    :param n_workers: thread-pool size for parallel test execution.
    :param timeout: per-sample execution timeout (seconds), forwarded to
        check_correctness.
    :param problem_file: HumanEval-X problem dataset.
    :param out_dir: directory for the per-sample results file; defaults to
        next to input_file.
    :param k: the k values for pass@k (only computed when every task has
        at least k samples).
    :param test_groundtruth: evaluate the canonical solutions instead of
        the generations (sanity check of the harness itself).
    :param example_test: use the problems' example tests where available.
    """
    if example_test:
        print("Example test...")

    problems = read_dataset(problem_file,
                            dataset_type="humaneval")
    sample_jsonl = stream_jsonl_all(input_file)

    # Results file name mirrors the input file, with a mode-specific suffix.
    if example_test:
        suffix = "_example_test.jsonl"
    else:
        suffix = "_results.jsonl"
    if out_dir is not None:
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        out_file = os.path.join(out_dir, input_file.split('/')[-1].replace(".jsonl", suffix))
    else:
        out_file = os.path.join(input_file.replace(".jsonl", suffix))

    # Inputs living inside the dataset tree are the canonical solutions.
    if "/codegeex/benchmark/humaneval-x/" in input_file:
        test_groundtruth = True

    # Translation runs are named "<src>-to-<tgt>"; detect from the file name.
    if "-to-" in input_file:
        translation_mode = True
    else:
        translation_mode = False

    with ThreadPoolExecutor(max_workers=n_workers) as executor:

        futures = []
        # NOTE(review): Counter is used here but the visible import block only
        # brings in defaultdict from collections — confirm Counter is imported.
        completion_id = Counter()
        n_samples = 0
        results = defaultdict(list)

        if test_groundtruth:
            print("Testing ground truth...")
            for sample in tqdm(problems.values()):
                task_id = sample["task_id"]
                lang = task_id.split("/")[0].lower()
                if lang == "javascript":
                    lang = "js"
                tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
                # Evaluate the canonical solution as if it were a generation.
                sample["generation"] = sample["canonical_solution"]
                sample["test_code"] = process_humaneval_test(sample, problems, example_test)
                if sample["test_code"] is None:
                    continue
                args = (task_id, sample, lang, timeout, tmp_dir_, completion_id[task_id])
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1
        else:
            print("Reading samples...")
            for sample in tqdm(sample_jsonl):
                task_id = sample["task_id"]
                lang = task_id.split("/")[0].lower()
                if translation_mode:
                    # Re-key the sample to the *target* language parsed from
                    # the "-to-<lang>-" fragment of the input file name.
                    task_id = sample["task_id"].split("/")[-1]
                    lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
                    for l in LANGUAGE_NAME:
                        if l in lang:
                            lang = l
                            break
                    task_id = f"{LANGUAGE_NAME[lang]}/{task_id}"
                if lang == "javascript":
                    lang = "js"
                tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
                sample["task_id"] = task_id
                sample["test_code"] = process_humaneval_test(sample, problems, example_test)
                if sample["test_code"] is None:
                    continue
                # Prefer an explicit completion_id from the sample; otherwise
                # number completions per task in arrival order.
                if "completion_id" in sample:
                    completion_id_ = sample["completion_id"]
                else:
                    completion_id_ = completion_id[task_id]
                args = (task_id, sample, lang, timeout, tmp_dir_, completion_id_)
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1

        print(completion_id)
        # pass@k is only meaningful when every problem received samples.
        if len(completion_id) == len(problems):
            evaluate_pass_at_k = True
        else:
            evaluate_pass_at_k = False

        print("Running test suites...")
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            results[result["task_id"]].append((result["completion_id"], result))

    # Calculate pass@k.
    total, correct = [], []
    for result in results.values():
        passed = [r[1]["passed"] for r in result]
        total.append(len(passed))
        correct.append(sum(passed))
    total = np.array(total)
    correct = np.array(correct)
    if evaluate_pass_at_k:
        ks = k
        # Only report pass@k for k values every task has enough samples for.
        pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
                     for k in ks if (total >= k).all()}
        print(pass_at_k)
    else:
        print("Total:", np.sum(total))
        print("Correct:", np.sum(correct))

    # Dump every per-sample result record, gzip-compressed if requested
    # via the output file extension.
    print("Writing to: ", out_file)
    if out_file.endswith(".gz"):
        fp = gzip.GzipFile(fileobj=open(out_file, "wb"), mode="wb")
        for res in results.values():
            for r in res:
                fp.write((json.dumps(r[1]) + "\n").encode("utf-8"))
    else:
        fp = open(out_file, 'w')
        for res in results.values():
            for r in res:
                fp.write(json.dumps(r[1]) + "\n")
    fp.close()

    print("Evaluation finished.")
|
||||
|
||||
|
||||
def main():
    # Expose evaluate_functional_correctness as a command-line tool via
    # python-fire: every keyword argument becomes a CLI flag.
    fire.Fire(evaluate_functional_correctness)


if __name__ == "__main__":
    sys.exit(main())
|
|
@ -8,8 +8,38 @@ import signal
|
|||
import random
|
||||
import subprocess
|
||||
import tempfile
|
||||
import gzip
|
||||
import json
|
||||
from typing import *
|
||||
|
||||
def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> None:
    """Write a list of dicts to a JSON-lines file.

    :param data_list: dicts to store, one JSON object per output line.
    :param filename: output path; a ``.jsonl`` suffix is appended when missing.
    :param compress: when True, write a gzip archive (``.gz`` appended to the
        final file name).
    """
    jsonl_suffix = '.jsonl'
    gzip_suffix = '.gz'
    # Normalise the target path so it always carries the .jsonl suffix.
    if not filename.endswith(jsonl_suffix):
        filename = filename + jsonl_suffix

    if compress:
        filename = filename + gzip_suffix
        with gzip.open(filename, 'w') as compressed:
            for record in data_list:
                compressed.write((json.dumps(record) + '\n').encode('utf-8'))
    else:
        with open(filename, 'w') as out:
            for record in data_list:
                out.write(json.dumps(record) + '\n')
|
||||
|
||||
|
||||
def check_correctness(
|
||||
task_id: str,
|
||||
|
@ -52,8 +82,8 @@ def check_correctness(
|
|||
# does not perform destructive actions on their host or network.
|
||||
# Once you have read this disclaimer and taken appropriate precautions,
|
||||
# uncomment the following line and proceed at your own risk:
|
||||
# exec(sample["test_code"], exec_globals)
|
||||
result.append("passed")
|
||||
exec(sample["test_code"], exec_globals)
|
||||
result.append("passed")
|
||||
except TimeoutException:
|
||||
result.append("timed out")
|
||||
except AssertionError as e:
|
||||
|
@ -92,7 +122,7 @@ def check_correctness(
|
|||
# does not perform destructive actions on their host or network.
|
||||
# Once you have read this disclaimer and taken appropriate precautions,
|
||||
# uncomment the following line and proceed at your own risk:
|
||||
# exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
|
||||
exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
|
||||
|
||||
if exec_result.returncode == 0:
|
||||
result.append("passed")
|
||||
|
@ -137,7 +167,7 @@ def check_correctness(
|
|||
# does not perform destructive actions on their host or network.
|
||||
# Once you have read this disclaimer and taken appropriate precautions,
|
||||
# uncomment the following line and proceed at your own risk:
|
||||
# exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
|
||||
exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
|
||||
|
||||
if exec_result.stderr.decode():
|
||||
err = exec_result.stderr.decode()
|
||||
|
@ -190,7 +220,7 @@ def check_correctness(
|
|||
# does not perform destructive actions on their host or network.
|
||||
# Once you have read this disclaimer and taken appropriate precautions,
|
||||
# uncomment the following line and proceed at your own risk:
|
||||
# exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
|
||||
exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
|
||||
|
||||
if exec_result.returncode == 0:
|
||||
result.append("passed")
|
||||
|
@ -210,6 +240,71 @@ def check_correctness(
|
|||
result.append("timed out")
|
||||
|
||||
shutil.rmtree(tmp_dir)
|
||||
elif "rust" in language_type.lower():
|
||||
import os
|
||||
|
||||
WD: str = os.path.dirname(os.path.abspath(__file__))
|
||||
RUST_DIR: str = os.path.join(WD, "rust")
|
||||
RUST_SRC: str = os.path.join(RUST_DIR, "src")
|
||||
RUST_BIN: str = os.path.join(RUST_SRC, "bin")
|
||||
RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
|
||||
RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
|
||||
RUST_EXT: str = ".rs"
|
||||
|
||||
# Create mandatory tmp directories
|
||||
os.makedirs(RUST_TMP_DIR, exist_ok=True)
|
||||
os.makedirs(RUST_LOGS, exist_ok=True)
|
||||
os.makedirs(RUST_SRC, exist_ok=True)
|
||||
os.makedirs(RUST_BIN, exist_ok=True)
|
||||
|
||||
with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f:
|
||||
#temporal file name
|
||||
file_prefix = sample["task_id"].lower().replace("/", "_")
|
||||
file_name:str = file_prefix +RUST_EXT
|
||||
|
||||
os.rename(f.name, os.path.join(RUST_BIN, file_name))
|
||||
|
||||
# Sample to pure Rust function
|
||||
rust_code: str = sample["test_code"]
|
||||
|
||||
# dump the rust source code in the target temporal file
|
||||
f.write(rust_code.encode('utf-8'))
|
||||
|
||||
# Proceed towards Rust binaries compilation. Therefore move to Rust module root dir.
|
||||
os.chdir(RUST_DIR)
|
||||
|
||||
# Two possible outcomes
|
||||
# Pass OR Fail compilation
|
||||
log_filename: str = file_prefix + ".jsonl"
|
||||
log_path: str = os.path.join(RUST_LOGS, log_filename)
|
||||
cargo_check: str = "cargo check --bin " + file_prefix + " --message-format json >> " + log_path
|
||||
# Compilation build status
|
||||
returned_val_compilation: int
|
||||
|
||||
# Overwrite file content
|
||||
if os.path.exists(log_path):
|
||||
if(file_size := os.path.getsize(log_path)) >= 0:
|
||||
os.remove(log_path)
|
||||
returned_val_compilation = os.system(cargo_check)
|
||||
|
||||
else:
|
||||
returned_val_compilation = os.system(cargo_check)
|
||||
|
||||
# 0 means success
|
||||
if returned_val_compilation == 0:
|
||||
|
||||
#Execution pipeline
|
||||
cargo_test: str = "cargo test --bin " +file_prefix+ " --message-format json >> " + log_path
|
||||
returned_val_execution = os.system(cargo_test)
|
||||
|
||||
if returned_val_execution == 0:
|
||||
result.append("passed")
|
||||
else:
|
||||
result.append(f"failed: execution error")
|
||||
|
||||
else:
|
||||
result.append(f"failed: compilation error")
|
||||
|
||||
|
||||
elif "java" in language_type.lower():
|
||||
assert tmp_dir is not None, "Java should be evaluated in a temporary dir."
|
||||
|
@ -264,7 +359,7 @@ def check_correctness(
|
|||
result.append(res)
|
||||
|
||||
shutil.rmtree(tmp_dir)
|
||||
|
||||
|
||||
manager = multiprocessing.Manager()
|
||||
result = manager.list()
|
||||
|
||||
|
|
|
@ -74,6 +74,10 @@ def process_humaneval_test(sample, problems, example_test=False):
|
|||
test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
|
||||
else:
|
||||
test_string = test_setup + "\n" + prompt + code + "\n" + test
|
||||
elif language == "rust":
|
||||
main = "\nfn main(){ \n } \n"
|
||||
declaration = problems[task_id]["declaration"]
|
||||
test_string = main + declaration + prompt + code + test
|
||||
|
||||
return test_string
|
||||
|
||||
|
@ -96,7 +100,7 @@ def evaluate_functional_correctness(
|
|||
input_file: str = None,
|
||||
tmp_dir: str = "./",
|
||||
n_workers: int = 32,
|
||||
timeout: float = 5.0,
|
||||
timeout: float = 500.0,
|
||||
problem_file: str = "../data/humaneval_python.jsonl.gz",
|
||||
out_dir: str = None,
|
||||
k: List[int] = [1, 10, 100],
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,121 @@
|
|||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "0.7.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fuchsia-cprng"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.139"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79"
|
||||
|
||||
[[package]]
|
||||
name = "md5"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "552840b97013b1a26992c11eac34bdd778e464601a4c2054b5f0bff7c6761293"
|
||||
dependencies = [
|
||||
"fuchsia-cprng",
|
||||
"libc",
|
||||
"rand_core 0.3.1",
|
||||
"rdrand",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b"
|
||||
dependencies = [
|
||||
"rand_core 0.4.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc"
|
||||
|
||||
[[package]]
|
||||
name = "rdrand"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2"
|
||||
dependencies = [
|
||||
"rand_core 0.3.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
|
||||
|
||||
[[package]]
|
||||
name = "rust"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"md5",
|
||||
"rand",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||
dependencies = [
|
||||
"winapi-i686-pc-windows-gnu",
|
||||
"winapi-x86_64-pc-windows-gnu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-i686-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
|
@ -0,0 +1,12 @@
|
|||
[package]
|
||||
name = "rust"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
rand = "0.4"
|
||||
regex = "1"
|
||||
md5 = "0.7.0"
|
||||
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,71 @@
|
|||
"""Debug driver for evaluate_functional_correctness on HumanEval-X samples."""
import argparse
import os
from pathlib import Path

from codegeex.benchmark.evaluate_humaneval_x import evaluate_functional_correctness

parser = argparse.ArgumentParser("Debugging evaluate humaneval_x")
# Path to the .jsonl(.gz) file that contains the generated codes.
parser.add_argument("-s", "--samples", type=str)
# Target programming language, one of ["python", "java", "cpp", "js", "go", "rust"].
parser.add_argument("-l", "--language", default="python", type=str)
# Number of parallel workers.
parser.add_argument("-w", "--workers", default=64, type=int)
# Timeout in seconds per test-suite run (was parsed but ignored before:
# the call hard-coded 300.0, so 300.0 is kept as the default).
parser.add_argument("-t", "--timeout", default=300.0, type=float)

args = parser.parse_args()

INPUT_FILE: str = args.samples
LANGUAGE: str = args.language
N_WORKERS: int = args.workers
TIMEOUT: float = args.timeout

# Resolve the repository layout relative to this script.
SCRIPT_PATH = Path(os.path.abspath(__file__))
SCRIPT_DIR: str = os.path.dirname(SCRIPT_PATH)
MAIN_DIR: str = os.path.dirname(SCRIPT_DIR)

# Problem dataset for the chosen language.
DATA_DIR = os.path.join(MAIN_DIR, "codegeex/benchmark/humaneval-x/" + LANGUAGE + "/data/humaneval_" + LANGUAGE + ".jsonl.gz")

# FIX: the second component must not start with "/" — os.path.join discards
# everything before an absolute component, so the old value pointed at the
# filesystem root instead of inside MAIN_DIR. Also removed the hard-coded
# debugging overrides that clobbered every CLI flag.
TMP_DIR = os.path.join(MAIN_DIR, "codegeex/benchmark/humaneval-x/")

evaluate_functional_correctness(input_file=INPUT_FILE,
                                n_workers=N_WORKERS,
                                tmp_dir=TMP_DIR,
                                problem_file=DATA_DIR,
                                timeout=TIMEOUT)
|
||||
|
||||
|
Loading…
Reference in New Issue