Commit 919ed606 authored by Benoit Favre

add llama2

parent df1ff7a8
Showing with 72769 additions and 6 deletions
System for generating DEFT 2023 outputs from LLMs
=================================================
* Update 2023-09-26: add code to finetune and perform inference with LLaMa2
The DEFT'23 shared task consists in answering pharma exam MCQs. This system converts the questions and possible answers to prompts and uses LLMs to generate answers.
The approach is described in our [paper](http://talnarchives.atala.org/CORIA-TALN/CORIA-TALN-2023/479307.pdf). It ranked 1st in the shared task.
This repository contains scripts to generate prompts, run off-the-shelf models and finetune the LLaMa models. It also contains the LoRA weights for the finetuned models.
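As an illustration, here is a minimal sketch of how an MCQ instance is turned into a prompt body by `deft.linearize_instance` (the instance below is made up; real entries come from the DEFT'23 JSON corpus):

    instance = {
        'question': 'Parmi les propositions suivantes, laquelle est exacte ?',
        'answers': {'a': 'proposition 1', 'b': 'proposition 2'},
        'correct_answers': ['a'],
    }
    # linearize_instance(instance) yields:
    # Parmi les propositions suivantes, laquelle est exacte ?
    # (a) proposition 1.
    # (b) proposition 2.
    # Réponse(s) :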
@@ -29,11 +29,18 @@ en/tk-instruct-11b-def 0.1442
(int8)
llama_7B                0.0576
llama2-7b               0.0833
llama2-7b-chat          0.0801
llama_7B+alpaca_fr      0.1185
llama_7B+alpaca         0.1217
llama_7B+alpaca-native  0.1153
llama_7B+deft           0.1378
llama2-7b-deft          0.1410
llama2-7b-deft-noprompt 0.2179
llama_13B               0.0769
llama2-13b              0.1442
llama2-13b-chat         0.1474
llama2-13b-deft         0.2788
llama_13B+alpaca        0.1474
llama_13B+vicuna        0.1538
llama_13B+deft          0.1730
@@ -41,7 +48,11 @@ llama_30B 0.1442
llama_30B+alpaca 0.1923
llama_30B+deft   0.2467
llama_65B        0.1730
llama2-70b       0.2051
llama2-70b-chat  0.2211
llama_65B+deft   0.3044
llama2-70b-deft  0.4455
llama2-70b-comp  0.4679

(fp16)
llama_30B        0.1891
@@ -74,3 +85,5 @@ PubMedBERT 33.98 14.14 34.00 13.98 35.66 15.59 33.87 14.79 35.44 14.79
CamemBERT-base   36.24 16.55 34.19 14.46 34.78 15.43 34.66 14.79 34.61 14.95
XLM-RoBERTa-base 37.92 17.20 31.26 11.89 35.84 16.07 32.47 14.63 33.00 14.95
BART-base        31.93 15.91 34.98 18.64 33.80 17.68 29.65 12.86 34.65 18.32
@@ -14,12 +14,15 @@ lm_templates_en = [
letters = 'abcdefghijklmnopqrstuvwxyz'

def linearize_instance(instance, include_correct_answers=False, include_full_answers=False, add_left_parenthesis=False, bare=False, **kwargs):
    result = instance['question'] + '\n' + '\n'.join('(%s) %s.' % (k, v) for k, v in instance['answers'].items())
    if bare:
        return result
    elif include_correct_answers:
        if include_full_answers:
            result += '\nRéponse(s) : ' + '; '.join('(%s) %s' % (a, instance['answers'][a]) for a in instance['correct_answers']) + '.\n'
        else:
            result += '\nRéponse(s) : ' + ' '.join('(%s)' % a for a in instance['correct_answers'])
    else:
        result += '\nRéponse(s) :' + (' (' if add_left_parenthesis else '')
    return result
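# Sketch of the new include_full_answers format (hypothetical instance), which spells
# out the full answer text for finetuning targets:
#   linearize_instance({'question': 'Q ?',
#                       'answers': {'a': 'A1', 'b': 'A2'},
#                       'correct_answers': ['a']},
#                      include_correct_answers=True, include_full_answers=True)
# ends with the line: Réponse(s) : (a) A1.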
@@ -34,8 +37,10 @@ def get_prompt(prompt, instance, few_shots=[], **kwargs):
    shots = [linearize_instance(shot, include_correct_answers=True, **kwargs) for shot in few_shots]
    return prompt % ('\n\n'.join(shots + [linearize_instance(instance, **kwargs)]),)

def extract_answer(answer, num_answers=5, stop_at_line_break=False, **kwargs):
    answer = re.sub('Ceci est une question de QCM.*', '', answer).strip().lower()
    if stop_at_line_break:
        answer = re.split(r'\n[ \t]*\n', answer)[0]
    selected = re.findall(r'^[a-%s]\)|\([a-%s]\)' % (letters[num_answers - 1], letters[num_answers - 1]), answer)
    if len(selected) == 0:
        selected = re.findall(r'(\b[a-%s]\b)' % letters[num_answers - 1], answer)
@@ -69,10 +74,10 @@ def run_inference(generator, corpus_path, template, **kwargs):
    results = []
    for instance in dev_corpus:
        prompt = get_prompt(template, instance, **kwargs)
        print('PROMPT: [%s]' % prompt)
        generated = generator(prompt)
        print('GENERATED: [%s]' % generated)
        answer = extract_answer(generated, len(instance['answers']), **kwargs)
        print(answer, instance['correct_answers'])
        if set(answer) == set(instance['correct_answers']):
            num_exact_correct += 1
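Note that run_inference only assumes `generator` maps a prompt string to a generated string. A minimal sketch of such a callable (the model name and generation settings below are illustrative, not values from this repository):

from transformers import pipeline

# Wrap a text-generation pipeline as the generator callable expected by run_inference.
pipe = pipeline('text-generation', model='meta-llama/Llama-2-7b-hf', device_map='auto')

def generator(prompt):
    # return_full_text=False strips the prompt from the pipeline output
    out = pipe(prompt, max_new_tokens=64, return_full_text=False)
    return out[0]['generated_text']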
import os, json
import torch
from datasets import Dataset
import pandas
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
HfArgumentParser,
TrainingArguments,
pipeline,
logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
def load_data(path: str):
    import deft
    with open(path) as fp:
        corpus = json.loads(fp.read())
    template = deft.lm_templates[0]
    corpus = [{"text": deft.get_prompt(template, instance, include_correct_answers=True, include_full_answers=True)} for instance in corpus]
    dataset = Dataset.from_pandas(pandas.DataFrame(data=corpus))
    return dataset
def finetune_lora(model_name: str,
        train_dataset_name: str,
        eval_dataset_name: str,
        new_model: str,
        run_name: str,
        lora_r: int = 4,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        use_4bit: bool = True,
        bnb_4bit_compute_dtype: str = "float16",
        bnb_4bit_quant_type: str = "nf4",
        use_nested_quant: bool = False,
        output_dir: str = "./results",
        num_train_epochs: int = 1,
        fp16: bool = False,
        bf16: bool = True,
        batch_size: int = 4,
        micro_batch_size: int = 4,
        gradient_checkpointing: bool = True,
        max_grad_norm: float = 0.3,
        learning_rate: float = 3e-4,
        weight_decay: float = 0.001,
        optim: str = "paged_adamw_32bit",
        lr_scheduler_type: str = "cosine",
        max_steps: int = -1,
        warmup_ratio: float = 0.05,
        group_by_length: bool = True,
        save_steps: int = 0,
        logging_steps: int = 1,
        max_seq_length: int = 256,
        packing: bool = False,
        device_map: str = '{"":0}',
        train_on_completions_only: bool = False):
    device_map = json.loads(device_map)
    train_dataset = load_data(train_dataset_name)
    #eval_dataset = load_data(eval_dataset_name)

    # Load tokenizer and model with QLoRA configuration
    compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant,
    )

    # Check GPU compatibility with bfloat16
    if compute_dtype == torch.float16 and use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16: accelerate training with bf16=True")
            print("=" * 80)

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # Load LLaMA tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training
    # Load LoRA configuration
    peft_config = LoraConfig(
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=["q_proj", "v_proj"],
        r=lora_r,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Set training parameters (effective batch size = micro_batch_size * accumulation steps)
    training_arguments = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=micro_batch_size,
        gradient_accumulation_steps=batch_size // micro_batch_size,
        gradient_checkpointing=gradient_checkpointing,
        optim=optim,
        save_steps=save_steps,
        logging_steps=logging_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        fp16=fp16,
        bf16=bf16,
        max_grad_norm=max_grad_norm,
        max_steps=max_steps,
        warmup_ratio=warmup_ratio,
        group_by_length=group_by_length,
        lr_scheduler_type=lr_scheduler_type,
        run_name=run_name,
        report_to="wandb",
    )
    if train_on_completions_only:
        # Mask the prompt part of each example so the loss is only computed on
        # tokens after the response marker.
        response_template = "Réponse(s) :"
        tokens = tokenizer.tokenize(response_template, add_special_tokens=False)
        token_ids = tokenizer.encode(response_template, add_special_tokens=False)
        print(list(zip(tokens, token_ids)))
        # Skip the first token: the SentencePiece tokenizer encodes the leading word
        # differently at the start of a sequence than in the middle of a prompt, so
        # matching on the remaining token ids is more reliable.
        collator = DataCollatorForCompletionOnlyLM(token_ids[1:], tokenizer=tokenizer)
    else:
        collator = None

    # Set supervised fine-tuning parameters
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        #eval_dataset=eval_dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=packing,
        data_collator=collator,
    )

    # Train model
    trainer.train()

    # Save the LoRA adapter weights
    trainer.model.save_pretrained(new_model)
def load_for_inference(model_name: str,
        lora_model: str,
        use_4bit: bool = True,
        bnb_4bit_compute_dtype: str = "float16",
        bnb_4bit_quant_type: str = "nf4",
        use_nested_quant: bool = False,
        device_map: str = '{"":0}'):
    compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        #low_cpu_mem_usage=True,
        quantization_config=bnb_config,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map=json.loads(device_map),
    )
    # Attach the finetuned LoRA adapter to the quantized base model
    model = PeftModel.from_pretrained(base_model, lora_model)
    #model = model.merge_and_unload()

    # Load the matching tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    return model, tokenizer
if __name__ == '__main__':
    import fire
    #print(load_data("../json/train.json")['text'][0])
    fire.Fire(finetune_lora)
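Since the entry point is exposed through python-fire, finetuning can be launched from the command line, and load_for_inference returns a ready model/tokenizer pair. A minimal usage sketch (model name, adapter path and prompt are assumptions, not values from this repository):

# Finetune, via fire's mapping of function arguments to CLI flags:
#   python finetune_llama2.py --model_name meta-llama/Llama-2-7b-hf \
#       --train_dataset_name ../json/train.json --eval_dataset_name ../json/dev.json \
#       --new_model llama2-7b-deft --run_name llama2-7b-deft

# Then generate with the finetuned adapter:
model, tokenizer = load_for_inference('meta-llama/Llama-2-7b-hf', 'llama2-7b-deft')
prompt = "Question ?\n(a) proposition 1.\n(b) proposition 2.\nRéponse(s) :"
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))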