Skip to content
Snippets Groups Projects
Commit ce7b4d8d authored by Benoit Favre
Browse files

add generation of random folds

parent a70841cb
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,7 @@ set -e -u -o pipefail
# output location
out="$dir/data/"`date '+%Y%m%d'`
mkdir -p "$out"
mkdir -p "$out/folds"
# CORD-19 metadata
curl https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/latest/metadata.csv > "$out/cord19-metadata_stage1.csv"
......@@ -27,6 +27,10 @@ python "$dir/bibliovid_scrapper.py" "$out/bibliovid_stage1.json" > "$out/bibliov
python "$dir/bibliovid_add_abstract.py" "$out/bibliovid_stage2.json" "$out/bibliovid_stage3.json"
python "$dir/bibliovid_normalize.py" "$out/bibliovid_stage3.json" > "$out/bibliovid.json"
# generate folds
python "$dir/split_json_random.py" "$out/folds/bibliovid" 5 .1 .1 < "$out/bibliovid.json"
python "$dir/split_json_random.py" "$out/folds/litcovid" 5 .1 .1 < "$out/litcovid.json"
# cleanup
rm "$out/"*_stage*
"""Write random train/valid/test splits of a JSON list read from stdin.

Command line:
    <script> <output-stem> <n-folds> <test-fraction> <valid-fraction> < items.json

For every fold index ``n`` in ``0 .. n-folds - 1`` the item list is shuffled
and three JSON files are written:

    <output-stem>-<n>.test
    <output-stem>-<n>.valid
    <output-stem>-<n>.train

The two size arguments are fractions of the total item count (e.g. ``.1``
for 10%), not percentages: each is multiplied by the number of items and
truncated to an integer.
"""
import json
import random
import sys


def write_folds(items, output_stem, num_folds, num_test, num_valid):
    """Shuffle *items* once per fold and write the three splits to disk.

    Args:
        items: list of JSON-serializable entries; shuffled in place.
        output_stem: path prefix of the generated files.
        num_folds: number of independent shuffles to write.
        num_test: number of entries placed in each ``.test`` file.
        num_valid: number of entries placed in each ``.valid`` file.

    Every fold is an independent reshuffle of the full list, so the test
    (and validation) subsets of different folds may overlap.
    """
    held_out = num_test + num_valid
    for n in range(num_folds):
        random.shuffle(items)
        splits = (
            ('test', items[:num_test]),
            ('valid', items[num_test:held_out]),
            # Everything not held out for test/valid is training data.
            ('train', items[held_out:]),
        )
        for suffix, subset in splits:
            with open('%s-%d.%s' % (output_stem, n, suffix), 'w') as fp:
                json.dump(subset, fp)


def main(argv=None):
    """Parse the command line, read the items from stdin, write the folds.

    Args:
        argv: full argument vector including the program name; defaults to
            ``sys.argv``.

    Exits with status 1 (after printing a usage line on stderr) when the
    number of arguments is wrong.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) != 5:
        # Diagnostics belong on stderr so they never mix with piped data.
        sys.stderr.write(
            'usage: %s <output-stem> <n-folds> <test-fraction> <valid-fraction>\n'
            % argv[0]
        )
        sys.exit(1)

    output_stem = argv[1]
    num_folds = int(argv[2])
    items = json.load(sys.stdin)
    # Sizes are fractions of the whole collection, truncated toward zero.
    num_test = int(float(argv[3]) * len(items))
    num_valid = int(float(argv[4]) * len(items))
    write_folds(items, output_stem, num_folds, num_test, num_valid)


if __name__ == '__main__':
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment