BERTScore Knowledge Distillation

1. Introduction

  • BERTScore measures the similarity of two sentences using a pretrained language model. It is mainly used to evaluate text generation models, e.g. for translation and summarization [1].
  • The larger the language model, the stronger the correlation between BERTScore and human evaluation tends to be.
  • However, a large model is hard to run in real time inside an application.
  • To address this, knowledge distillation was used to train a small model to reproduce the BERTScore of a large model; a short usage sketch of the result follows below.
  • Resulting model: yongsun-yoon/minilmv2-bertscore-distilled

Correlation between BERTScore and human evaluation for different models [2]
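As a quick illustration (a minimal sketch, not part of the original notebook; the example sentences are made up), the distilled model can be plugged into bert_score like any other backbone:

from bert_score import BERTScorer

# Use the distilled 6-layer student as the BERTScore backbone.
scorer = BERTScorer(model_type='yongsun-yoon/minilmv2-bertscore-distilled', num_layers=6)

# Score a candidate against a reference; F combines greedy-matching precision and recall.
P, R, F = scorer.score(['The cat sat on the mat.'], ['A cat is sitting on the mat.'])
print(F.item())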

2. Setup

import math
import wandb
import easydict
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn.functional as F

import huggingface_hub
from bert_score import BERTScorer
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel

cfg = easydict.EasyDict(
    device = 'cuda:0',
    student_name = 'nreimers/MiniLMv2-L6-H384-distilled-from-RoBERTa-Large',
    teacher_name = 'microsoft/deberta-large-mnli',
    teacher_layer_idx = 18,
    lr = 5e-5,
    batch_size = 8,
    num_epochs = 5
)

3. Data

  • The GLUE MNLI dataset was used in order to get sentence pairs that have some degree of similarity.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, text1_key, text2_key):
        self.data = data
        self.text1_key = text1_key
        self.text2_key = text2_key

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text1 = item[self.text1_key]
        text2 = item[self.text2_key]
        return text1, text2

    def collate_fn(self, batch):
        texts1, texts2 = zip(*batch)
        return texts1, texts2

    def get_dataloader(self, batch_size, shuffle):
        # use the custom collate_fn so each batch is a (texts1, texts2) pair of string tuples
        return torch.utils.data.DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn)

data = load_dataset('glue', 'mnli')

train_data = data['train']
train_data = train_data.train_test_split(train_size=80000)['train']
train_dataset = Dataset(train_data, 'premise', 'hypothesis')
train_loader = train_dataset.get_dataloader(cfg.batch_size, True)

test_data = data['validation_mismatched'].train_test_split(test_size=4000)['test']
test_dataset = Dataset(test_data, 'premise', 'hypothesis')
test_loader = test_dataset.get_dataloader(cfg.batch_size, False)
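A quick sanity check (a sketch, not in the original code): one batch from the loader is simply a tuple of premise strings and a tuple of hypothesis strings.

texts1, texts2 = next(iter(train_loader))
print(len(texts1), len(texts2))  # both equal cfg.batch_size
print(texts1[0])                 # a premise sentence
print(texts2[0])                 # the paired hypothesis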

4. Model

teacher_tokenizer = AutoTokenizer.from_pretrained(cfg.teacher_name)
teacher_model = AutoModel.from_pretrained(cfg.teacher_name)
_ = teacher_model.eval().requires_grad_(False).to(cfg.device)

student_tokenizer = AutoTokenizer.from_pretrained(cfg.student_name)
student_model = AutoModel.from_pretrained(cfg.student_name)
_ = student_model.train().to(cfg.device)
optimizer = torch.optim.Adam(student_model.parameters(), lr=cfg.lr)

5. Train

  • Cross-attention scores between the two sentences are computed, and the student is trained to imitate the teacher's attention distribution.
  • KL divergence, which measures the difference between two distributions, is used as the loss function (a sketch of the objective follows this list).
  • When the teacher and student tokenizers differ, token-level comparison is impossible, so tokens are pooled into word-level representations.
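Written out, with $H_1^{(m)}$ and $H_2^{(m)}$ denoting the word embeddings of the two sentences from model $m$ and $d_m$ its hidden size, the objective implemented below is roughly (notation is mine, not from the original post):

$$
A^{(m)}_{1\to 2} = \frac{H_1^{(m)} \big(H_2^{(m)}\big)^{\top}}{\sqrt{d_m}}, \qquad
\mathcal{L} = \tfrac{1}{2}\Big[\mathrm{KL}\big(\sigma(A^{(T)}_{1\to 2}) \,\big\|\, \sigma(A^{(S)}_{1\to 2})\big) + \mathrm{KL}\big(\sigma(A^{(T)}_{2\to 1}) \,\big\|\, \sigma(A^{(S)}_{2\to 1})\big)\Big]
$$

where $\sigma$ is a row-wise softmax at temperature 1, and $T$, $S$ index the teacher and student.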
def get_word_embeds(model, tokenizer, texts, layer_idx=-1, max_length=384):
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(model.device)
    outputs = model(**inputs, output_hidden_states=True)

    num_texts = inputs.input_ids.size(0)
    token_embeds = outputs.hidden_states[layer_idx]

    batch_word_embeds = []
    for i in range(num_texts):
        text_word_embeds = []

        # walk over word indices; word_to_tokens gives each word's subword token span
        j = 0
        while True:
            token_span = inputs.word_to_tokens(i, j)
            if token_span is None: break

            # mean-pool the subword embeddings of word j into a single word embedding
            word_embed = token_embeds[i][token_span.start:token_span.end].mean(dim=0)
            text_word_embeds.append(word_embed)
            j += 1

        text_word_embeds = torch.stack(text_word_embeds, dim=0).unsqueeze(0) # (1, seq_length, hidden_dim)
        batch_word_embeds.append(text_word_embeds) 

    return batch_word_embeds
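To see why the word-level pooling matters, here is a small illustrative check (a sketch with a made-up sentence): the teacher and student tokenizers split words into different subword pieces, but after pooling each model should return one vector per word, which is what makes their attention maps comparable.

example = ['Knowledge distillation compresses large language models.']
t_emb = get_word_embeds(teacher_model, teacher_tokenizer, example, layer_idx=cfg.teacher_layer_idx)
s_emb = get_word_embeds(student_model, student_tokenizer, example, layer_idx=-1)
print(t_emb[0].shape, s_emb[0].shape)  # same word count, different hidden sizes (1024 vs 384)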


def kl_div_loss(s, t, temperature):
    # s: student scores (becomes the log-prob input), t: teacher scores (becomes the target distribution)
    if len(s.size()) != 2:
        s = s.view(-1, s.size(-1))
        t = t.view(-1, t.size(-1))

    s = F.log_softmax(s / temperature, dim=-1)
    t = F.softmax(t / temperature, dim=-1)
    # F.kl_div(log_probs, probs) computes KL(t || s), i.e. the student is pulled toward the teacher
    return F.kl_div(s, t, reduction='batchmean') * (temperature ** 2)


def transpose_for_scores(h, num_heads):
    batch_size, seq_length, dim = h.size()
    head_size = dim // num_heads
    h = h.view(batch_size, seq_length, num_heads, head_size)
    return h.permute(0, 2, 1, 3) # (batch, num_heads, seq_length, head_size)


def attention(h1, h2, num_heads, attention_mask=None):
    # assert h1.size() == h2.size()
    head_size = h1.size(-1) // num_heads
    h1 = transpose_for_scores(h1, num_heads) # (batch, num_heads, seq_length, head_size)
    h2 = transpose_for_scores(h2, num_heads) # (batch, num_heads, seq_length, head_size)

    attn = torch.matmul(h1, h2.transpose(-1, -2)) # (batch_size, num_heads, seq_length, seq_length)
    attn = attn / math.sqrt(head_size)
    if attention_mask is not None:
        attention_mask = attention_mask[:, None, None, :]
        attention_mask = (1 - attention_mask) * -10000.0
        attn = attn + attention_mask

    return attn


def train_epoch(
    teacher_model, teacher_tokenizer, 
    student_model, student_tokenizer,
    train_loader,
    teacher_layer_idx,
):

    student_model.train()
    pbar = tqdm(train_loader)
    for texts1, texts2 in pbar:
        teacher_embeds1 = get_word_embeds(teacher_model, teacher_tokenizer, texts1, layer_idx=teacher_layer_idx)
        teacher_embeds2 = get_word_embeds(teacher_model, teacher_tokenizer, texts2, layer_idx=teacher_layer_idx)

        student_embeds1 = get_word_embeds(student_model, student_tokenizer, texts1, layer_idx=-1)
        student_embeds2 = get_word_embeds(student_model, student_tokenizer, texts2, layer_idx=-1)

        teacher_scores1 = [attention(e1, e2, 1) for e1, e2 in zip(teacher_embeds1, teacher_embeds2)]
        student_scores1 = [attention(e1, e2, 1) for e1, e2 in zip(student_embeds1, student_embeds2)]
        loss1 = torch.stack([kl_div_loss(ss, ts, temperature=1.) for ss, ts in zip(student_scores1, teacher_scores1)]).mean()

        teacher_scores2 = [attention(e2, e1, 1) for e1, e2 in zip(teacher_embeds1, teacher_embeds2)]
        student_scores2 = [attention(e2, e1, 1) for e1, e2 in zip(student_embeds1, student_embeds2)]
        loss2 = torch.stack([kl_div_loss(ss, ts, temperature=1.) for ss, ts in zip(student_scores2, teacher_scores2)]).mean()

        loss = (loss1 + loss2) * 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        log = {'loss': loss.item(), 'loss1': loss1.item(), 'loss2': loss2.item()}
        wandb.log(log)
        pbar.set_postfix(log)


def test_epoch(
    teacher_model, teacher_tokenizer, 
    student_model, student_tokenizer,
    test_loader,
    teacher_layer_idx,
):
    student_model.eval()
    test_loss, num_data = 0, 0
    for texts1, texts2 in test_loader:
        with torch.no_grad():
            teacher_embeds1 = get_word_embeds(teacher_model, teacher_tokenizer, texts1, layer_idx=teacher_layer_idx)
            teacher_embeds2 = get_word_embeds(teacher_model, teacher_tokenizer, texts2, layer_idx=teacher_layer_idx)

            student_embeds1 = get_word_embeds(student_model, student_tokenizer, texts1, layer_idx=-1)
            student_embeds2 = get_word_embeds(student_model, student_tokenizer, texts2, layer_idx=-1)

        teacher_scores1 = [attention(e1, e2, 1) for e1, e2 in zip(teacher_embeds1, teacher_embeds2)]
        student_scores1 = [attention(e1, e2, 1) for e1, e2 in zip(student_embeds1, student_embeds2)]
        loss1 = torch.stack([kl_div_loss(ss, ts, temperature=1.) for ss, ts in zip(student_scores1, teacher_scores1)]).mean()

        teacher_scores2 = [attention(e2, e1, 1) for e1, e2 in zip(teacher_embeds1, teacher_embeds2)]
        student_scores2 = [attention(e2, e1, 1) for e1, e2 in zip(student_embeds1, student_embeds2)]
        loss2 = torch.stack([kl_div_loss(ss, ts, temperature=1.) for ss, ts in zip(student_scores2, teacher_scores2)]).mean()

        loss = (loss1 + loss2) * 0.5
        batch_size = len(texts1)
        test_loss += loss.item() * batch_size
        num_data += batch_size

    test_loss /= num_data
    return test_loss



wandb.init(project='bert-score-distillation')

best_loss = 1e10
for ep in range(cfg.num_epochs):
    train_epoch(teacher_model, teacher_tokenizer, student_model, student_tokenizer, train_loader, cfg.teacher_layer_idx)
    test_loss = test_epoch(teacher_model, teacher_tokenizer, student_model, student_tokenizer, test_loader, cfg.teacher_layer_idx)

    print(f'ep {ep:02d} | loss {test_loss:.3f}')
    wandb.log({'test_loss': test_loss})
    if test_loss < best_loss:
        student_model.save_pretrained('checkpoint')
        student_tokenizer.save_pretrained('checkpoint')
        best_loss = test_loss

6. Evaluate

  • After training, the correlation between the student's and the teacher's BERTScore improved from 0.806 to 0.936.
def calculate_score(scorer, loader):
    scores = []
    for texts1, texts2 in tqdm(loader):
        P, R, F = scorer.score(texts1, texts2)
        scores += F.tolist()
    return scores


teacher_scorer = BERTScorer(model_type=cfg.teacher_name, num_layers=cfg.teacher_layer_idx)
student_scorer = BERTScorer(model_type=cfg.student_name, num_layers=6)
distilled_student_scorer = BERTScorer(model_type='checkpoint', num_layers=6)

teacher_scores = calculate_score(teacher_scorer, test_loader)
student_scores = calculate_score(student_scorer, test_loader)
distilled_scores = calculate_score(distilled_student_scorer, test_loader)

scores = pd.DataFrame({'teacher': teacher_scores, 'student': student_scores, 'distilled': distilled_scores})
scores.corr().round(3)

  • The scatterplot also shows that, after distillation, the student follows the teacher's BERTScore much more closely.
sns.scatterplot(x = 'teacher', y='student', data=scores)
sns.scatterplot(x = 'teacher', y='distilled', data=scores)
plt.title('scatterplot with teacher model')
plt.legend(['not distilled', 'distilled'])
plt.show()
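Finally, the huggingface_hub import above suggests the checkpoint was then published; a minimal sketch of how that upload could look (assuming a valid access token; the repo name is the released model from the introduction):

huggingface_hub.login()  # prompts for a Hugging Face access token

student_model.push_to_hub('yongsun-yoon/minilmv2-bertscore-distilled')
student_tokenizer.push_to_hub('yongsun-yoon/minilmv2-bertscore-distilled')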

Reference

[1] Zhang, T., Kishore, V., Wu, F., Weinberger, K. Q., & Artzi, Y. (2019). Bertscore: Evaluating text generation with bert. arXiv preprint arXiv:1904.09675.

[2] https://github.com/Tiiiger/bert_score