본문 바로가기

ai

Hydra + Lightning Fabric으로 딥러닝 학습 template 만들기

0. 들어가며

  • Hydra는 python code의 configuration을 쉽게 관리하기 위한 라이브러리다.
  • Lightning Fabric은 최소한의 코드 변경으로 Pytorch 모델의 효율적인 학습을 도와주는 라이브러리다.
    • Fabric accelerates your PyTorch training or inference code with minimal changes required.
  • Hydra와 Lightning Fabric을 사용하여 딥러닝 학습 코드 template을 만들었다.
  • 예시를 위해 한국어 문장 분류 데이터인 KLUE YNAT 데이터를 사용했다.

1. 설치

  • hydra와 lightning을 설치한다.
pip install hydra-core
pip install lightning

2. Config 정의

  • hydra는 config 정의를 위해 yaml 파일을 사용한다.
  • 데이터, 모델, 학습에 사용하는 주요 변수들을 작성한다.
  • 이 변수들은 command line에서 실행할 때 override 할 수 있다. [url]
  • fabric arguments [url]

config.yaml

# data
data_path: klue
data_name: ynat
max_length: 128
batch_size: 16
labels: ['IT과학', '경제', '사회', '생활문화', '세계', '스포츠', '정치']

# model
model_name: klue/roberta-base
lr: 5e-5

# train
num_epochs: 3
ckpt_path: roberta-base-ynat

# fabric
fabric:
    accelerator: gpu
    strategy: ddp
    devices: [1, 2]
    precision: 32

3. 데이터

  • train_loadertest_loader를 정의하는 함수를 작성한다.
  • Sequence classification에서 사용하는 일반적인 방법을 따랐다.

main.py

def _collate_fn(batch, tokenizer, max_length):
    texts = [b['title'] for b in batch]
    labels = [b['label'] for b in batch]
    inputs = tokenizer(texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
    labels = torch.tensor(labels)
    return inputs, labels

def prepare_data(cfg, tokenizer):
    dataset = load_dataset(cfg.data_path, cfg.data_name)
    train_dataset, test_dataset = dataset['train'], dataset['validation']
    collate_fn = lambda x: _collate_fn(x, tokenizer, cfg.max_length)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)
    return train_loader, test_loader

3. 모델

  • modeloptimizer를 정의하는 함수를 작성한다.

main.py

def prepare_model(cfg):
    num_labels = len(cfg.labels)
    id2label = {i:l for i,l in enumerate(cfg.labels)}
    model = AutoModelForSequenceClassification.from_pretrained(cfg.model_name, num_labels=num_labels, id2label=id2label)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    return model, optimizer

4. 학습 및 평가

  • 학습과 평가를 위한 함수를 작성한다.
  • 예시에서는 epoch을 단위로 작성하였다. 필요에 따라 step 단위로 변경할 수 있다.
  • fabric.is_global_zero: 여러 gpu 중 가장 첫 번째 gpu인지 여부를 확인할 수 있다. 모델 평가, 저장, 로깅의 조건으로 사용가능하다.

main.py

def train_epoch(fabric, model, optimizer, loader):
    model.train()

    pbar = tqdm(loader, disable=not fabric.is_global_zero)
    for inputs, labels in pbar:
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        fabric.backward(loss)
        optimizer.step()

        preds = outputs.logits.argmax(dim=-1)
        acc = (preds == labels).float().mean()
        pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})

def evaluate_epoch(model, loader):
    model.eval()

    all_preds, all_labels = [], []
    for inputs, labels in tqdm(loader):
        with torch.no_grad():
            outputs = model(**inputs)
        preds = outputs.logits.argmax(dim=-1)
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    acc = (all_preds == all_labels).float().mean().item()
    return acc

5. main

  • 모든 프로세스를 실행하는 main 함수를 작성한다.
  • fabric = L.Fabric(**cfg.fabric): Fabric instance를 생성한다.
  • fabric.launch(): distributed process를 실행한다.
  • fabric.setup(model, optimizer): device에 맞게 model과 optimizer를 준비한다.
  • fabric.setup_dataloaders(train_loader, test_loader): device에 맞게 data loader를 준비한다.
    • 이 코드를 사용하지 않을 경우,tensor.to(fabric.device)를 통해 명시적으로 device를 설정해야 한다.
  • fabric.is_global_zero인 경우에만 로깅, 평가, 저장을 진행한다.
  • fabric.barrier(): 이 함수가 실행될 때 까지 모든 device를 기다린다. 모델 평가를 zero rank device에서만 진행하기 때문에 나머지 process에서 이를 기다리는 처리가 필요하다.
@hydra.main(version_base='1.3', config_path='.', config_name='config.yaml')
def main(cfg):
    print(cfg)

    fabric = L.Fabric(**cfg.fabric)
    fabric.launch()

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    train_loader, test_loader = prepare_data(cfg, tokenizer)
    model, optimizer = prepare_model(cfg)

    wrapped_model, optimizer = fabric.setup(model, optimizer)
    train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)

    best_acc = 0.
    for ep in range(cfg.num_epochs):
        fabric.barrier()
        train_epoch(fabric, model, optimizer, train_loader)

        if fabric.is_global_zero:
            acc = evaluate_epoch(model, test_loader)
            print(f'ep {ep} | acc {acc:.3f}')
            if acc > best_acc:
                tokenizer.save_pretrained(cfg.ckpt_path)
                model.save_pretrained(cfg.ckpt_path)
                best_acc = acc


if __name__ == '__main__':
    main()

6. 전체 코드

main.py

import hydra
import torch
import lightning as L
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

def _collate_fn(batch, tokenizer, max_length):
    texts = [b['title'] for b in batch]
    labels = [b['label'] for b in batch]
    inputs = tokenizer(texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
    labels = torch.tensor(labels)
    return inputs, labels

def prepare_data(cfg, tokenizer):
    dataset = load_dataset(cfg.data_path, cfg.data_name)
    train_dataset, test_dataset = dataset['train'], dataset['validation']
    train_dataset = train_dataset.train_test_split(train_size=3000)['train']

    collate_fn = lambda x: _collate_fn(x, tokenizer, cfg.max_length)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)
    return train_loader, test_loader

def prepare_model(cfg):
    num_labels = len(cfg.labels)
    id2label = {i:l for i,l in enumerate(cfg.labels)}
    model = AutoModelForSequenceClassification.from_pretrained(cfg.model_name, num_labels=num_labels, id2label=id2label)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    return model, optimizer


def train_epoch(fabric, model, optimizer, loader):
    model.train()

    pbar = tqdm(loader, disable=not fabric.is_global_zero)
    for inputs, labels in pbar:
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        fabric.backward(loss)
        optimizer.step()

        preds = outputs.logits.argmax(dim=-1)
        acc = (preds == labels).float().mean()
        pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})

def evaluate_epoch(model, loader):
    model.eval()

    all_preds, all_labels = [], []
    for inputs, labels in tqdm(loader):
        with torch.no_grad():
            outputs = model(**inputs)
        preds = outputs.logits.argmax(dim=-1)
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    acc = (all_preds == all_labels).float().mean().item()
    return acc

@hydra.main(version_base='1.3', config_path='.', config_name='config.yaml')
def main(cfg):
    print(cfg)

    fabric = L.Fabric(**cfg.fabric)
    fabric.launch()

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    train_loader, test_loader = prepare_data(cfg, tokenizer)
    model, optimizer = prepare_model(cfg)

    wrapped_model, optimizer = fabric.setup(model, optimizer)
    train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)

    best_acc = 0.
    for ep in range(cfg.num_epochs):
        fabric.barrier()
        train_epoch(fabric, model, optimizer, train_loader)

        if fabric.is_global_zero:
            acc = evaluate_epoch(model, test_loader)
            print(f'ep {ep} | acc {acc:.3f}')
            if acc > best_acc:
                tokenizer.save_pretrained(cfg.ckpt_path)
                model.save_pretrained(cfg.ckpt_path)
                best_acc = acc


if __name__ == '__main__':
    main()