
Transformer T5 Fine-tuning: Question Answering

0. Introduction

  • How to fine-tune a pretrained T5 model on the KorQuAD dataset

1. Setup

  • Python libraries
import numpy as np
from tqdm import tqdm

import torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  • Key variables
DEVICE = 'cuda:0'
DATA_NAME = 'KETI-AIR/korquad'
MODEL_NAME = 'google/mt5-base'
SAVE_PATH = 'mt5-base-korquad'
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
NUM_TRAINING_STEPS = 10000
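  • The variables above assume a CUDA GPU at cuda:0. As a small hedge (not part of the original setup), the device can fall back to CPU when no GPU is available:
# Optional fallback, an assumption rather than part of the original setup
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'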

2. Data

  • Dataset code
  • The context is split into paragraphs of max_length tokens using return_overflowing_tokens.
  • One of the resulting paragraphs is selected at random.
  • If the answer span is not contained in the selected paragraph, the model is trained to output '알 수 없음' (unknown).
  • The encoder input is formatted as 본문: {context} 질문: {question} (a sample batch is inspected in the sketch after the data-loading code below).
  • The padded positions of the labels are set to -100 so they do not contribute to the cross-entropy loss.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, context_max_length=256, answer_max_length=100, doc_stride=128):
        self.data = data
        self.tokenizer = tokenizer
        self.context_max_length = context_max_length
        self.answer_max_length = answer_max_length
        self.doc_stride = doc_stride
        self.pad_token_id = self.tokenizer.pad_token_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        context, question = item['context'], item['question']
        answer_text, answer_start = item['answers']['text'][0], item['answers']['answer_start'][0]

        context_inputs = self.tokenizer(context, max_length=self.context_max_length, stride=self.doc_stride, truncation=True, return_overflowing_tokens=True, add_special_tokens=False)
        num_paragraphs = len(context_inputs['input_ids'])
        paragraph_idx = np.random.randint(0, num_paragraphs)

        paragraph = self.tokenizer.decode(context_inputs.input_ids[paragraph_idx])
        input_text = f'본문: {paragraph} 질문: {question}'

        if context_inputs.char_to_token(paragraph_idx, answer_start) is None:
            answer_text = '알 수 없음'

        input_ids = self.tokenizer(input_text).input_ids
        labels = self.tokenizer(answer_text).input_ids
        return input_ids, labels

    def collate_fn(self, batch):
        input_ids, labels = zip(*batch)
        input_ids = pad_seqs(input_ids, self.pad_token_id, None)
        labels = pad_seqs(labels, -100, self.answer_max_length)
        return input_ids, labels

    def get_dataloader(self, batch_size, shuffle):
        return torch.utils.data.DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn)

def pad_seqs(seqs, pad_val, max_length=None):
    _max_length = max([len(s) for s in seqs])
    max_length = _max_length if max_length is None else min(_max_length, max_length)

    padded_seqs = []
    for seq in seqs:
        seq = seq[:max_length]
        pads = [pad_val] * (max_length - len(seq))
        padded_seqs.append(seq + pads)

    return torch.tensor(padded_seqs)
  • Only 1,000 examples from the test split are used to keep evaluation time short.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

data = load_dataset(DATA_NAME, 'v1.0')
train_data, test_data = data['train'], data['dev']
test_data = test_data.train_test_split(test_size=1000)['test']

train_dataset = Dataset(train_data, tokenizer)
train_loader = train_dataset.get_dataloader(BATCH_SIZE, True)

test_dataset = Dataset(test_data, tokenizer)
test_loader = test_dataset.get_dataloader(BATCH_SIZE, False)

input_ids, labels = next(iter(train_loader))
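  • As a quick check of the encoder input format and the -100 label padding described above, the sketch below decodes the first example of the batch (an illustrative inspection only, not part of the original training code).
# Encoder input follows the '본문: ... 질문: ...' template
print(tokenizer.decode(input_ids[0], skip_special_tokens=True))

# Label padding uses -100, which cross entropy ignores; swap it for the pad token id before decoding
decodable_labels = torch.where(labels[0] == -100, tokenizer.pad_token_id, labels[0])
print(tokenizer.decode(decodable_labels, skip_special_tokens=True))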

3. Train

  • Evaluation code
  • Performance is measured with exact match (EM) and character-level F1 (a small usage example of metric_fn follows the evaluation code below).
def metric_fn(pred, target):
    # Exact match: 1 only when the prediction equals the target string exactly.
    em = int(pred == target)

    # Character-level F1: overlap is the set of characters shared by pred and target,
    # compared against the length of each string.
    common = set(pred) & set(target)
    if len(common) == 0:
        f1 = 0
    else:
        precision = len(common) / len(pred)
        recall = len(common) / len(target)
        f1 = 2 * (precision * recall) / (precision + recall)

    return em, f1

def evaluate(model, test_loader):
    model.eval()
    preds, targets = [], []
    for input_ids, labels in tqdm(test_loader, disable=True):
        input_ids, labels = input_ids.to(DEVICE), labels.to(DEVICE)
        outputs = model.generate(input_ids, max_length=100)

        _preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        clean_labels = torch.where(labels==-100, 0, labels)
        _targets = tokenizer.batch_decode(clean_labels, skip_special_tokens=True)

        preds += _preds
        targets += _targets

    em, f1 = 0, 0
    for p, t in zip(preds, targets):
        _em, _f1 = metric_fn(p, t)
        em += _em
        f1 += _f1

    em /= len(preds)
    f1 /= len(preds)

    return em, f1
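  • A quick illustration of metric_fn on a made-up prediction/target pair (the strings below are hypothetical, not from KorQuAD): partial character overlap earns partial F1 credit.
em, f1 = metric_fn('서울', '서울특별시')
# em = 0 (not an exact match)
# shared characters {'서', '울'}: precision = 2/2, recall = 2/5, so f1 ≈ 0.571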
  • Training code
  • Evaluation is run every 1,000 steps.
  • A checkpoint is saved whenever the F1 score improves.
  • Training results
    • em: 0.817
    • f1: 0.805
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
_ = model.train().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

best_f1 = 0.
train_iter = iter(train_loader)
pbar = tqdm(range(1, NUM_TRAINING_STEPS+1))
for st in pbar:
    try:
        input_ids, labels = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        input_ids, labels = next(train_iter)

    input_ids, labels = input_ids.to(DEVICE), labels.to(DEVICE)
    outputs = model(input_ids, labels=labels)
    loss = outputs.loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    pbar.set_postfix({'loss': loss.item()})

    if st % 1000 == 0:
        em, f1 = evaluate(model, test_loader)
        model.train()  # evaluate() switches to eval mode; switch back before resuming training
        print(f'st {st:05d} | em {em:.3f} | f1 {f1:.3f}')

        if f1 > best_f1:
            tokenizer.save_pretrained(SAVE_PATH)
            model.save_pretrained(SAVE_PATH)
            best_f1 = f1

4. Push to Huggingface Hub

  • Upload the trained model to the Hugging Face Hub.
import os
import huggingface_hub

huggingface_hub.login(os.environ['HF_HUB_TOKEN'])

tokenizer = AutoTokenizer.from_pretrained(SAVE_PATH)
model = AutoModelForSeq2SeqLM.from_pretrained(SAVE_PATH)

tokenizer.push_to_hub(SAVE_PATH)
model.push_to_hub(SAVE_PATH)

5. Test

  • QA inference with the text2text-generation pipeline (an equivalent manual generate call is sketched after the pipeline example below).
from transformers import pipeline

def run_qa(nlp, context, question):
    input_text = f'본문: {context} 질문: {question}'
    outputs = nlp(input_text)
    answer = outputs[0]['generated_text']
    return answer

nlp = pipeline(task='text2text-generation', model='yongsun-yoon/mt5-base-korquad')

context = '오늘 전국이 대체로 맑고 기온이 최고 34도까지 올라가 여름처럼 더운 날씨가 예상된다.'
question = '오늘 최고 기온 몇도야?'
run_qa(nlp, context, question) # 34도
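  • The same inference can also be run without the pipeline by calling generate directly, mirroring the evaluation code above (a sketch under the same setup, not an additional step from the original post).
tokenizer = AutoTokenizer.from_pretrained('yongsun-yoon/mt5-base-korquad')
model = AutoModelForSeq2SeqLM.from_pretrained('yongsun-yoon/mt5-base-korquad')

input_ids = tokenizer(f'본문: {context} 질문: {question}', return_tensors='pt').input_ids
outputs = model.generate(input_ids, max_length=100)
answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]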