
Transformer BERT Fine-tuning: Named Entity Recognition

0. Introduction

  • fine-tune a pretrained BERT-family encoder (klue/roberta-base) on the KLUE NER dataset

1. Setup

  • import libraries
import easydict
import itertools
from tqdm import tqdm

from sklearn.metrics import f1_score
from seqeval.metrics import f1_score as ner_f1_score
from seqeval.scheme import IOB2

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification
  • config
cfg = easydict.EasyDict(
    device = 'cuda:0',
    model_name = 'klue/roberta-base',
    save_path = 'roberta-base-ner',
    batch_size = 16,
    num_epochs = 5,
    lr = 5e-5,
)

2. Data

  • dataset class code
  • converts the character-level tags to token-level labels (a short standalone alignment example follows the class)
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        self.outside_label = LABEL2ID['O']
        self.pad_token_id = self.tokenizer.pad_token_id
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        text = ''.join(item['tokens'])      # KLUE NER tokens are individual characters
        char_labels = item['ner_tags']      # one tag per character

        inputs = self.tokenizer(text)
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask

        # map character-level tags to tokens: each token takes the tag of its first
        # character; special tokens have no character span and stay 'O'
        labels = [self.outside_label] * len(input_ids)
        for i in range(len(input_ids)):
            span = inputs.token_to_chars(i)
            if span is not None:
                labels[i] = char_labels[span.start]
        
        return input_ids, attention_mask, labels
    
    
    def collate_fn(self, batch):
        input_ids, attention_mask, labels = zip(*batch)
        input_ids = pad_seqs(input_ids, self.pad_token_id)
        attention_mask = pad_seqs(attention_mask, 0)
        labels = pad_seqs(labels, -100)
        return input_ids, attention_mask, labels
    
    def get_dataloader(self, batch_size, shuffle):
        return torch.utils.data.DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn)

    
def pad_seqs(seqs, pad_val, max_length=256):
    _max_length = max([len(s) for s in seqs])
    max_length = _max_length if max_length is None else min(_max_length, max_length)
    
    padded_seqs = []
    for seq in seqs:
        seq = seq[:max_length]
        pads = [pad_val] * (max_length - len(seq))
        padded_seqs.append(seq + pads)
    
    return torch.tensor(padded_seqs)
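  • quick standalone check of the character-to-token alignment (the sentence and its hand-written tags are made up for illustration, not a real KLUE sample)
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('klue/roberta-base')

text = '김하성은 출전했다'
char_tags = ['B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O']  # one tag per character, space included

enc = tok(text)
for i, input_id in enumerate(enc.input_ids):
    span = enc.token_to_chars(i)                      # None for [CLS]/[SEP]
    tag = char_tags[span.start] if span is not None else 'O'
    print(tok.convert_ids_to_tokens(input_id), tag)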
  • data code
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)

data = load_dataset('klue', 'ner')
train_data, test_data = data['train'], data['validation']

LABELS = train_data.info.features['ner_tags'].feature.names
NUM_LABELS = len(LABELS)
LABEL2ID = {l:i for i,l in enumerate(LABELS)}
ID2LABEL = {i:l for i,l in enumerate(LABELS)}

train_loader = Dataset(train_data, tokenizer).get_dataloader(cfg.batch_size, shuffle=True)
test_loader = Dataset(test_data, tokenizer).get_dataloader(cfg.batch_size, shuffle=False)
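  • optional sanity check of the label set and one training example (the entity types in KLUE NER are PS, LC, OG, DT, TI, QT in a B-/I-/O scheme)
print(NUM_LABELS, LABELS)
sample = train_data[0]
print(sample['tokens'][:10])                            # character-level tokens
print([ID2LABEL[t] for t in sample['ner_tags'][:10]])   # matching character-level tags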

3. Train

  • train/evaluate function code
  • uses entity F1 and character F1, the metrics from the KLUE benchmark, for evaluation (a toy comparison of the two follows the code below)
def train_epoch(model, loader):
    device = next(model.parameters()).device
    model.train()

    pbar = tqdm(loader)
    for batch in pbar:
        batch = [b.to(device) for b in batch]
        input_ids, attention_mask, labels = batch
        
        outputs = model(input_ids, attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({'loss': loss.item()})

def predict_epoch(model, loader):
    device = next(model.parameters()).device
    model.eval()

    total_preds, total_labels = [], []
    for batch in tqdm(loader):
        batch = [b.to(device) for b in batch]
        input_ids, attention_mask, labels = batch
        with torch.no_grad():
            outputs = model(input_ids, attention_mask, labels=labels)
        
        preds = outputs.logits.argmax(dim=-1)
        total_preds += preds.cpu().tolist()
        total_labels += labels.cpu().tolist()

    return total_preds, total_labels


def remove_padding(preds, labels):
    removed_preds, removed_labels = [], []
    for p, l in zip(preds, labels):
        # truncate at the first padding label; keep the whole sequence if it has no padding
        idx = l.index(-100) if -100 in l else len(l)
        removed_preds.append(p[:idx])
        removed_labels.append(l[:idx])
    
    return removed_preds, removed_labels


def entity_f1_func(preds, targets):
    preds = [[ID2LABEL[p] for p in pred] for pred in preds]
    targets = [[ID2LABEL[t] for t in target] for target in targets]
    entity_macro_f1 = ner_f1_score(targets, preds, average="macro", mode="strict", scheme=IOB2)
    f1 = entity_macro_f1 * 100.0
    return round(f1, 2)


def char_f1_func(preds, targets):
    label_indices = list(range(len(LABELS)))
    preds = list(itertools.chain(*preds))
    targets = list(itertools.chain(*targets))
    f1 = f1_score(targets, preds, labels=label_indices, average='macro', zero_division=1) * 100.0
    return round(f1, 2)


def evaluate_epoch(model, loader):
    preds, labels = predict_epoch(model, loader)
    preds, labels = remove_padding(preds, labels)
    entity_f1 = entity_f1_func(preds, labels)
    char_f1 = char_f1_func(preds, labels)
    return entity_f1, char_f1
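  • a toy check to make the two metrics concrete (hand-made label IDs, not model output; assumes the KLUE tag names 'B-PS'/'I-PS'): entity F1 scores whole B-/I- spans under strict matching, character F1 scores each position independently
gold = [[LABEL2ID['B-PS'], LABEL2ID['I-PS'], LABEL2ID['O']]]   # one gold PS span of two characters
pred = [[LABEL2ID['B-PS'], LABEL2ID['O'],    LABEL2ID['O']]]   # prediction recovers only part of the span

print(entity_f1_func(pred, gold))  # 0.0 -- the broken span gets no credit under strict matching
print(char_f1_func(pred, gold))    # much higher -- each position is scored independently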
  • training code
  • save a checkpoint whenever entity F1 improves (a sketch for reloading the best checkpoint follows the training code)
  • results
    • entity f1: 87.68
    • char f1: 92.61
model = AutoModelForTokenClassification.from_pretrained(cfg.model_name, num_labels=NUM_LABELS, id2label=ID2LABEL, label2id=LABEL2ID)
_ = model.train().to(cfg.device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

best_score = 0.
for ep in range(1, cfg.num_epochs+1):
    train_epoch(model, train_loader)
    entity_f1, char_f1 = evaluate_epoch(model, test_loader)
    print(f'ep: {ep:02d} | entity f1: {entity_f1:.2f} | char f1: {char_f1:.2f}')

    if entity_f1 > best_score:
        model.save_pretrained(cfg.save_path)
        tokenizer.save_pretrained(cfg.save_path)
        best_score = entity_f1
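  • reload the best checkpoint for a final check: a minimal sketch, assuming the loaders and evaluate_epoch above are still in scope
best_model = AutoModelForTokenClassification.from_pretrained(cfg.save_path)
_ = best_model.eval().to(cfg.device)

entity_f1, char_f1 = evaluate_epoch(best_model, test_loader)
print(f'best checkpoint | entity f1: {entity_f1:.2f} | char f1: {char_f1:.2f}')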

4. Test

  • inference with the transformers pipeline, loading the saved checkpoint
from transformers import pipeline

nlp = pipeline('token-classification', model='roberta-base-ner', aggregation_strategy='simple')
text = '김하성은 16일 펫코파크에서 열린 캔자스시티 로얄스와 홈 맞대결에 7번 타자로 선발 출전해 4타수 1안타를 기록했다.'
nlp(text)
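  • with aggregation_strategy='simple' the pipeline merges subword tokens and returns one dict per entity; the values below only sketch the output shape (actual entities and scores depend on the trained checkpoint)
# [{'entity_group': 'PS', 'score': 0.99, 'word': '김하성', 'start': 0, 'end': 3},
#  {'entity_group': 'DT', 'score': 0.98, 'word': '16일', 'start': 5, 'end': 8},
#  ...]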