System for generating DEFT 2023 outputs from LLMs
=================================================

* Update 2023-09-26: added code to finetune and perform inference with LLaMa2 (performance on par with ChatGPT)

The DEFT'23 shared task consists of answering pharmacy exam MCQs. This system converts the questions and candidate answers into prompts and uses LLMs to generate answers.
The approach is described in our [paper](http://talnarchives.atala.org/CORIA-TALN/CORIA-TALN-2023/479307.pdf). It ranked 1st at the shared task.
This repository contains scripts to generate prompts, run off-the-shelf models and finetune the LLaMa models. It also contains the LoRA weights for the finetuned models. 
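For illustration, an invented question is linearized into the following prompt format (see linearize_instance in deft.py below):

```
Parmi les propositions suivantes, laquelle est exacte ?
(a) Première proposition.
(b) Deuxième proposition.
(c) Troisième proposition.
Réponse(s) :
```

Few-shot examples additionally include the expected completion, e.g. `Réponse(s) : (a) (c)`, and the generated text is parsed back into a set of option letters.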

This repository uses git LFS for large files.
Run `git lfs install` before cloning to retrieve the binary files.
Use `git lfs clone ...` to clone with the binary files.

Install:
```
pip install -r requirements.txt               # for llama1
pip install -r requirements.llama2-freeze.txt # for llama2
```

Note that bitsandbytes may need to be recompiled to support your CUDA version.
Note that LLaMa2 was finetuned with Python 3.10.10 and CUDA 11.6 on a single A100 80GB GPU.

See RESULTS for the exact-match results on the dev set.
See runs for how to generate runs.
See trains for the LLaMa2 finetuning runs.

Note that external APIs require API keys. Rename api_keys.template.py to api_keys.py and set the keys you need inside.

Please cite the following paper:
RESULTS (+13 −0)
@@ -29,11 +29,18 @@ en/tk-instruct-11b-def 0.1442

(int8)
llama_7B                   0.0576    
llama2-7b                  0.0833
llama2-7b-chat             0.0801
llama_7B+alpaca_fr         0.1185                                                                          
llama_7B+alpaca            0.1217
llama_7B+alpaca-native     0.1153
llama_7B+deft              0.1378    
llama2-7b-deft             0.1410
llama2-7b-deft-noprompt    0.2179
llama_13B                  0.0769    
llama2-13b                 0.1442
llama2-13b-chat            0.1474
llama2-13b-deft            0.2788
llama_13B+alpaca           0.1474
llama_13B+vicuna           0.1538
llama_13B+deft             0.1730    
@@ -41,7 +48,11 @@ llama_30B 0.1442
llama_30B+alpaca           0.1923
llama_30B+deft             0.2467    
llama_65B                  0.1730    
llama2-70b                 0.2051
llama2-70b-chat            0.2211
llama_65B+deft             0.3044      
llama2-70b-deft            0.4455
llama2-70b-comp            0.4679

(fp16)
llama_30B                  0.1891    
@@ -74,3 +85,5 @@ PubMedBERT 33.98 14.14 34.00 13.98 35.66 15.59 33.87 14.79 35.44 14.79
CamemBERT-base   36.24 16.55 34.19 14.46 34.78 15.43 34.66 14.79   34.61 14.95
XLM-RoBERTa-base 37.92 17.20 31.26 11.89 35.84 16.07 32.47 14.63   33.00 14.95
BART-base        31.93 15.91 34.98 18.64 33.80 17.68 29.65 12.86   34.65 18.32

data/dev-en.json (new file, +5591 −0; preview collapsed due to size)

deft.py (+11 −6)
@@ -14,11 +14,14 @@ lm_templates_en = [

letters = 'abcdefghijklmnopqrstuvwxyz'

def linearize_instance(instance, include_correct_answers=False, include_full_answers=False, add_left_parenthesis=False, bare=False, **kwargs):
    result = instance['question'] + '\n' + '\n'.join('(%s) %s.' % (k, v) for k, v in instance['answers'].items())
    if bare:
        return result
    elif include_correct_answers:
        if include_full_answers:
            result += '\nRéponse(s) : ' + '; '.join('(%s) %s' % (a, instance['answers'][a]) for a in instance['correct_answers']) + '.\n'
        else:
            result += '\nRéponse(s) : ' + ' '.join('(%s)' % a for a in instance['correct_answers'])
    else:
        result += '\nRéponse(s) :' + (' (' if add_left_parenthesis else '')
@@ -34,8 +37,10 @@ def get_prompt(prompt, instance, few_shots=[], **kwargs):
    shots = [linearize_instance(shot, include_correct_answers=True, **kwargs) for shot in few_shots]
    return prompt % ('\n\n'.join(shots + [linearize_instance(instance, **kwargs)]),)

def extract_answer(answer, num_answers=5, stop_at_line_break=False, **kwargs):
    answer = re.sub('Ceci est une question de QCM.*', '', answer).strip().lower()
    if stop_at_line_break:
        answer = re.split(r'\n[ \t]*\n', answer)[0]
    selected = re.findall(r'^[a-%s]\)|\([a-%s]\)' % (letters[num_answers - 1], letters[num_answers - 1]), answer)
    if len(selected) == 0:
        selected = re.findall(r'(\b[a-%s]\b)' % letters[num_answers - 1], answer)
@@ -69,10 +74,10 @@ def run_inference(generator, corpus_path, template, **kwargs):
    results = []
    for instance in dev_corpus:
        prompt = get_prompt(template, instance, **kwargs)
        print('PROMPT: [%s]' % prompt)
        generated = generator(prompt)
        print('GENERATED: [%s]' % generated)
        answer = extract_answer(generated, len(instance['answers']), **kwargs)
        print(answer, instance['correct_answers'])
        if set(answer) == set(instance['correct_answers']):
            num_exact_correct += 1
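
To make the answer parsing concrete, here is a minimal runnable sketch (illustration only, not part of the diff) of the two regex passes in extract_answer, applied to an invented generation:

```
import re

letters = 'abcdefghijklmnopqrstuvwxyz'

# Invented model output for a 3-option question.
answer = "(a) et (c) sont exactes"
num_answers = 3

# First pass: explicit option markers such as "a)" or "(a)".
pattern = r'^[a-%s]\)|\([a-%s]\)' % (letters[num_answers - 1], letters[num_answers - 1])
selected = re.findall(pattern, answer)
print(selected)  # ['(a)', '(c)']

# Fallback, used only when the first pass matches nothing: bare letters, e.g. "a et c".
if len(selected) == 0:
    selected = re.findall(r'(\b[a-%s]\b)' % letters[num_answers - 1], answer)
```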

finetune_llama2.py (new file, +193 −0)
import os, json
import torch
from datasets import Dataset
import pandas
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

def load_data(path: str):
    import deft
    with open(path) as fp:
        corpus = json.loads(fp.read())
    template = deft.lm_templates[0]
    corpus = [{"text": deft.get_prompt(template, instance, include_correct_answers=True, include_full_answers=True)} for instance in corpus]
    dataset = Dataset.from_pandas(pandas.DataFrame(data=corpus))
    return dataset

def finetune_lora(model_name: str, 
        train_dataset_name: str, 
        eval_dataset_name: str, 
        new_model: str, 
        run_name: str,
        lora_r:int = 4, 
        lora_alpha: int = 16, 
        lora_dropout: float = 0.05, 
        use_4bit: bool = True, 
        bnb_4bit_compute_dtype: str = "float16", 
        bnb_4bit_quant_type: str = "nf4", 
        use_nested_quant: bool = False, 
        output_dir: str = "./results", 
        num_train_epochs: int = 1, 
        fp16: bool = False, 
        bf16: bool = True, 
        batch_size: int = 4,
        micro_batch_size: int = 4,
        gradient_checkpointing: bool = True, 
        max_grad_norm: float = 0.3, 
        learning_rate: float = 3e-4, 
        weight_decay: float = 0.001, 
        optim: str = "paged_adamw_32bit", 
        lr_scheduler_type: str = "cosine", 
        max_steps: int = -1, 
        warmup_ratio: float = 0.05, 
        group_by_length: bool = True, 
        save_steps: int = 0, 
        logging_steps: int = 1, 
        max_seq_length: int = 256, 
        packing: bool = False, 
        device_map: str = '{"":0}',
        train_on_completions_only: bool = False):

    device_map = json.loads(device_map)
    train_dataset = load_data(train_dataset_name)
    #eval_dataset = load_data(eval_dataset_name)

    # Load tokenizer and model with QLoRA configuration
    compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant,
    )
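    # QLoRA: the base model is loaded frozen in 4-bit (NF4 here) and only the
    # LoRA adapters configured below are trained; matrix multiplications run
    # in compute_dtype (float16 by default).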

    # Check GPU compatibility with bfloat16
    if compute_dtype == torch.float16 and use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16: accelerate training with bf16=True")
            print("=" * 80)

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map
    )
    model.config.use_cache = False   # caching is incompatible with gradient checkpointing
    model.config.pretraining_tp = 1  # disable tensor-parallel slicing of linear layers (LLaMa2-specific)

    # Load LLaMA tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

    # Load LoRA configuration
    peft_config = LoraConfig(
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=["q_proj", "v_proj"],
        r=lora_r,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Set training parameters
    training_arguments = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=micro_batch_size,
        gradient_accumulation_steps=batch_size // micro_batch_size,  # effective batch size = batch_size
        gradient_checkpointing=gradient_checkpointing,
        optim=optim,
        save_steps=save_steps,
        logging_steps=logging_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        fp16=fp16,
        bf16=bf16,
        max_grad_norm=max_grad_norm,
        max_steps=max_steps,
        warmup_ratio=warmup_ratio,
        group_by_length=group_by_length,
        lr_scheduler_type=lr_scheduler_type,
        run_name=run_name,
        report_to="wandb",
    )

    if train_on_completions_only:
        response_template = "Réponse(s) :"
        tokens = tokenizer.tokenize(response_template, add_special_tokens=False)
        token_ids = tokenizer.encode(response_template, add_special_tokens=False)
        print(list(zip(tokens, token_ids)))
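        # Match on token_ids[1:]: SentencePiece encodes the first token of
        # "Réponse(s) :" differently depending on the preceding context, so
        # dropping it lets the collator locate the response template inside
        # full prompts (a documented caveat of DataCollatorForCompletionOnlyLM).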
        collator = DataCollatorForCompletionOnlyLM(token_ids[1:], tokenizer=tokenizer)
    else:
        collator = None

    # Set supervised fine-tuning parameters
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        #eval_dataset=eval_dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=packing,
        data_collator=collator,
    )

    # Train model
    trainer.train()

    # Save trained model
    trainer.model.save_pretrained(new_model)

def load_for_inference(model_name: str, 
        lora_model: str, 
        use_4bit: bool = True, 
        bnb_4bit_compute_dtype: str = "float16", 
        bnb_4bit_quant_type: str = "nf4", 
        use_nested_quant: bool = False, 
        device_map='{"":0}'):

    compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        #low_cpu_mem_usage=True,
        quantization_config=bnb_config,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map=json.loads(device_map),
    )
    model = PeftModel.from_pretrained(base_model, lora_model)
    #model = model.merge_and_unload()

    # Reload tokenizer to save it
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    return model, tokenizer

if __name__ == '__main__':
    import fire
    #print(load_data("../json/train.json")['text'][0])
    fire.Fire(finetune_lora)
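
For reference, a hypothetical usage sketch (not part of the file); the paths and run name below mirror the invocations in the trains script and are placeholders:

```
# Programmatic equivalent of the fire CLI exposed above.
from finetune_llama2 import finetune_lora

finetune_lora(
    model_name="llama2-weights/convert/llama-2-7b-hf/",
    train_dataset_name="data/train.json",
    eval_dataset_name="data/dev.json",  # accepted but unused: evaluation is commented out above
    new_model="./models/llama-2-7b-deft/",
    run_name="llama-2-7b-deft_001",
    train_on_completions_only=True,  # compute the loss only on text after "Réponse(s) :"
)
```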
README.md (PEFT adapter card)
---
library_name: peft
---
## Training procedure


The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: False
- bnb_4bit_compute_dtype: float16
### Framework versions


- PEFT 0.5.0
adapter_config.json
{
  "auto_mapping": null,
  "base_model_name_or_path": "/storage/raid1/corpora/llama2-weights/convert/llama-2-13b-hf/",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 16,
  "lora_dropout": 0.05,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 4,
  "revision": null,
  "target_modules": [
    "q_proj",
    "v_proj"
  ],
  "task_type": "CAUSAL_LM"
}
README.md (PEFT adapter card)
---
library_name: peft
---
## Training procedure


The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: False
- bnb_4bit_compute_dtype: float16
### Framework versions


- PEFT 0.5.0
adapter_config.json
{
  "auto_mapping": null,
  "base_model_name_or_path": "/storage/raid1/corpora/llama2-weights/convert/llama-2-13b-hf/",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 16,
  "lora_dropout": 0.05,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 4,
  "revision": null,
  "target_modules": [
    "q_proj",
    "v_proj"
  ],
  "task_type": "CAUSAL_LM"
}
README.md (PEFT adapter card)
---
library_name: peft
---
## Training procedure


The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: False
- bnb_4bit_compute_dtype: float16
### Framework versions


- PEFT 0.5.0
adapter_config.json
{
  "auto_mapping": null,
  "base_model_name_or_path": "/storage/raid1/corpora/llama2-weights/convert/llama-2-70b-hf/",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 16,
  "lora_dropout": 0.05,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 4,
  "revision": null,
  "target_modules": [
    "q_proj",
    "v_proj"
  ],
  "task_type": "CAUSAL_LM"
}
README.md (PEFT adapter card)
---
library_name: peft
---
## Training procedure


The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: False
- bnb_4bit_compute_dtype: float16
### Framework versions


- PEFT 0.5.0
adapter_config.json
{
  "auto_mapping": null,
  "base_model_name_or_path": "/storage/raid1/corpora/llama2-weights/convert/llama-2-70b-hf/",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 16,
  "lora_dropout": 0.05,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 4,
  "revision": null,
  "target_modules": [
    "q_proj",
    "v_proj"
  ],
  "task_type": "CAUSAL_LM"
}
README.md (PEFT adapter card)
---
library_name: peft
---
## Training procedure


The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: False
- bnb_4bit_compute_dtype: float16
### Framework versions


- PEFT 0.5.0
adapter_config.json
{
  "auto_mapping": null,
  "base_model_name_or_path": "/storage/raid1/corpora/llama2-weights/convert/llama-2-7b-hf/",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 16,
  "lora_dropout": 0.05,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 4,
  "revision": null,
  "target_modules": [
    "q_proj",
    "v_proj"
  ],
  "task_type": "CAUSAL_LM"
}
README.md (PEFT adapter card) (+21 −0)
---
library_name: peft
---
## Training procedure


The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: False
- bnb_4bit_compute_dtype: float16
### Framework versions


- PEFT 0.5.0
adapter_config.json
{
  "auto_mapping": null,
  "base_model_name_or_path": "/storage/raid1/corpora/llama2-weights/convert/llama-2-7b-hf/",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 16,
  "lora_dropout": 0.05,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 4,
  "revision": null,
  "target_modules": [
    "q_proj",
    "v_proj"
  ],
  "task_type": "CAUSAL_LM"
}
requirements.llama2-freeze.txt (+75 −0)
accelerate==0.23.0
aiohttp==3.8.5
aiosignal==1.3.1
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.1.0
bitsandbytes==0.41.1
certifi==2023.7.22
charset-normalizer==3.2.0
click==8.1.7
cmake==3.27.5
datasets==2.14.5
dill==0.3.7
docker-pycreds==0.4.0
filelock==3.12.4
fire==0.5.0
frozenlist==1.4.0
fsspec==2023.6.0
gitdb==4.0.10
GitPython==3.1.37
huggingface-hub==0.17.2
idna==3.4
Jinja2==3.1.2
Levenshtein==0.21.1
lit==16.0.6
MarkupSafe==2.1.3
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
networkx==3.1
numpy==1.26.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
packaging==23.1
pandas==2.1.1
pathtools==0.1.2
peft==0.5.0
protobuf==3.20.0
psutil==5.9.5
pyarrow==13.0.0
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
rapidfuzz==3.3.0
regex==2023.8.8
requests==2.31.0
safetensors==0.3.3
scipy==1.11.2
sentry-sdk==1.31.0
setproctitle==1.3.2
six==1.16.0
smmap==5.0.1
sympy==1.12
termcolor==2.3.0
tokenizers==0.13.3
torch==2.0.1
tqdm==4.66.1
transformers==4.33.2
triton==2.0.0
trl==0.7.1
typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.0.5
wandb==0.15.11
xxhash==3.3.0
yarl==1.9.2
requirements file (+11 −0)
protobuf==3.20
accelerate>=0.21.0
bitsandbytes>=0.40.2
torch>=1.13.1
transformers>=4.31.0
fire
levenshtein
peft>=0.4.0
trl>=0.4.7
scipy
wandb

run_llama2.py (new file, +38 −0)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def main(result_path: str, corpus_path: str, model: str = 'llama-2-7b-hf', template_id: str = '0'):
    checkpoint = 'llama2-weights/convert/' + model

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    quant_config = BitsAndBytesConfig(
        #load_in_8bit=True,
        #llm_int8_threshold=6.0,
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        #llm_int8_enable_fp32_cpu_offload=True,
    )
    device_map = {
        "": 0
    }
    llm = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=device_map, torch_dtype=torch.float16, load_in_8bit=True)  # alternatively: quantization_config=quant_config

    def generate(input_string):
        inputs = tokenizer(input_string, return_tensors="pt")
        outputs = llm.generate(input_ids=inputs.input_ids.to('cuda'), attention_mask=inputs.attention_mask, max_new_tokens=32, pad_token_id=tokenizer.eos_token_id)

        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated[len(input_string):]

    import deft
    results = deft.run_inference(generate, corpus_path, deft.template_from_id(template_id))
    deft.write_results(results, result_path)

if __name__ == '__main__':
    import fire
    fire.Fire(main)
run_llama2_finetuned.py (+38 −0)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

import finetune_llama2

def main(result_path: str, corpus_path: str, base_model: str, lora_model: str, template_id: str = '0'):
    checkpoint = 'llama2-weights/convert/' + base_model

    llm, tokenizer = finetune_llama2.load_for_inference(checkpoint, lora_model)
    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
    )

    def generate(input_string):
        inputs = tokenizer(input_string, return_tensors="pt")
        outputs = llm.generate(input_ids=inputs.input_ids.to('cuda'), 
            attention_mask=inputs.attention_mask, 
            max_new_tokens=128, 
            pad_token_id=tokenizer.eos_token_id,
            #generation_config=generation_config,
            temperature=0,  # do_sample defaults to False, so decoding is greedy and temperature is ignored
            #do_sample=True,
        )

        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated[len(input_string):]

    import deft
    results = deft.run_inference(generate, corpus_path, deft.template_from_id(template_id), stop_at_line_break=True)
    deft.write_results(results, result_path)

if __name__ == '__main__':
    import fire
    fire.Fire(main)
runs (+30 −16)
python run_flan-ul2.py output/flan-ul2_prompt0.txt data/dev.json | tee logs/flan-ul2_prompt0.txt
python run_flan-t5-xxl.py output/flan-t5-xxl_prompt0.txt data/dev.json | tee logs/flan-t5-xxl_prompt0.txt
python run_bloomz.py output/bloomz-7b1-mt_prompt0.txt data/dev.json bloomz-7b1-mt | tee logs/bloomz-7b1-mt_prompt0.txt
python run_bloomz.py output/bloomz-560m_prompt0.txt data/dev.json bloomz-560m | tee logs/bloomz-560m_prompt0.txt
python run_tkinstruct.py output/tk-instruct-3b-def_prompt0.txt data/dev.json tk-instruct-3b-def | tee logs/tk-instruct-3b-def_prompt0.txt
python run_tkinstruct.py output/tk-instruct-11b-def_prompt0.txt data/dev.json tk-instruct-11b-def | tee logs/tk-instruct-11b-def_prompt0.txt
python run_opt-iml.py output/opt-iml-30b_prompt0.txt data/dev.json opt-iml-30b | tee logs/opt-iml-30b_prompt0.txt
python run_galactica.py output/galactica-30b_prompt0.txt data/dev.json galactica-30b | tee logs/galactica-30b_prompt0.txt

python run_api.py output/code-cushman-001_prompt0.txt data/dev.json openai/code-cushman-001 | tee logs/code-cushman-001_prompt0.txt
python run_api.py output/code-davinci-002_prompt0.txt data/dev.json openai/code-davinci-002 | tee logs/code-davinci-002_prompt0.txt
python run_api.py output/gpt-4-0314_prompt0.txt data/dev.json openai/gpt-4-0314 | tee logs/gpt-4-0314_prompt0.txt

python run_api.py output/j1-jumbo_prompt0.txt data/dev.json ai21/j1-jumbo | tee logs/j1-jumbo_prompt0.txt

python run_bloomz.py output/en_bloomz-560m_prompt0.txt data/dev-en.json bloomz-560m en/0 | tee logs/en_bloomz-560m_prompt0.txt
python run_bloomz.py output/en_bloomz-3b_prompt0.txt data/dev-en.json bloomz-3b en/0 | tee logs/en_bloomz-3b_prompt0.txt

python run_tkinstruct.py output/en_tk-instruct-11b-def_prompt0.txt data/dev-en.json tk-instruct-11b-def en/0 | tee logs/en_tk-instruct-11b-def_prompt0.txt
python run_open_assistant.py output/oasst-sft-1-pythia-12b_prompt0.txt data/dev.json | tee logs/oasst-sft-1-pythia-12b_prompt0.txt
python run_pmc_llama.py output/pmc_llama_prompt0.txt data/dev.json | tee logs/pmc_llama_prompt0.txt

# llama2
python run_llama2.py output/llama2-7b_prompt0.txt data/dev.json llama-2-7b-hf | tee logs/llama2-7b_prompt0.txt
python run_llama2.py output/llama2-7b-chat_prompt0.txt data/dev.json llama-2-7b-chat-hf | tee logs/llama2-7b-chat_prompt0.txt
python run_llama2.py output/llama2-13b_prompt0.txt data/dev.json llama-2-13b-hf | tee logs/llama2-13b_prompt0.txt
python run_llama2.py output/llama2-13b-chat_prompt0.txt data/dev.json llama-2-13b-chat-hf | tee logs/llama2-13b-chat_prompt0.txt
python run_llama2.py output/llama2-70b_prompt0.txt data/dev.json llama-2-70b-hf | tee logs/llama2-70b_prompt0.txt
python run_llama2.py output/llama2-70b-chat_prompt0.txt data/dev.json llama-2-70b-chat-hf | tee logs/llama2-70b-chat_prompt0.txt
python run_llama2_finetuned.py output/llama2-7b-deft_prompt0.txt data/dev.json llama-2-7b-hf models/llama-2-7b-deft | tee logs/llama2-7b-deft_prompt0.txt
python run_llama2_finetuned.py output/llama2-7b-deft-comp_prompt0.txt data/dev.json llama-2-7b-hf models/llama-2-7b-deft-comp | tee logs/llama2-7b-deft-comp_prompt0.txt
python run_llama2_finetuned.py output/llama2-13b-deft_prompt0.txt data/dev.json llama-2-13b-hf models/llama-2-13b-deft | tee logs/llama2-13b-deft_prompt0.txt
python run_llama2_finetuned.py output/llama2-13b-deft-comp_prompt0.txt data/dev.json llama-2-13b-hf models/llama-2-13b-deft-comp | tee logs/llama2-13b-deft-comp_prompt0.txt
python run_llama2_finetuned.py output/llama2-70b-deft_prompt0.txt data/dev.json llama-2-70b-hf models/llama-2-70b-deft | tee logs/llama2-70b-deft_prompt0.txt
python run_llama2_finetuned.py output/llama2-70b-deft-comp_prompt0.txt data/dev.json llama-2-70b-hf models/llama-2-70b-deft-comp | tee logs/llama2-70b-deft-comp_prompt0.txt

runs.test (new file, +4 −0)
python run_llama.py --model_path=llama_models/deft_llama-65b-hf_lora_98075de5-9200-4d66-ab35-61ca2a380692/ --output_path=output_test/test_llama-65b-lora_prompt0.txt --corpus_path=data/test.json | tee logs/test_llama-65b-lora_prompt0.txt
python run_api.py output_test/test_gpt-4-0314_prompt0.txt data/test.json openai/gpt-4-0314 | tee logs/test_gpt-4-0314_prompt0.txt
python run_api.py output_test/test_gpt-3.5-turbo-0301_prompt0.txt data/test.json openai/gpt-3.5-turbo-0301 | tee logs/test_gpt-3.5-turbo-0301_prompt0.txt

trains (new file, +8 −0)
#!/bin/bash
source env.sh
python finetune.py llama2-weights/convert/llama-2-7b-hf/ data/train.json data/dev.json ./models/llama-2-7b-deft/ llama-2-7b-deft_001
python finetune.py llama2-weights/convert/llama-2-7b-hf/ data/train.json data/dev.json ./models/llama-2-7b-deft-comp/ llama-2-7b-deft_002 --train_on_completions_only=True
python finetune.py llama2-weights/convert/llama-2-13b-hf/ data/train.json data/dev.json ./models/llama-2-13b-deft/ llama-2-13b-deft_003
python finetune.py llama2-weights/convert/llama-2-13b-hf/ data/train.json data/dev.json ./models/llama-2-13b-deft-comp/ llama-2-13b-deft_004 --train_on_completions_only=True
python finetune.py llama2-weights/convert/llama-2-70b-hf/ data/train.json data/dev.json ./models/llama-2-70b-deft/ llama-2-70b-deft_005
python finetune.py llama2-weights/convert/llama-2-70b-hf/ data/train.json data/dev.json ./models/llama-2-70b-deft-comp/ llama-2-70b-deft_006 --train_on_completions_only=True