0. 들어가며
- Hydra는 python code의 configuration을 쉽게 관리하기 위한 라이브러리다.
 - Lightning Fabric은 최소한의 코드 변경으로 Pytorch 모델의 효율적인 학습을 도와주는 라이브러리다.
- Fabric accelerates your PyTorch training or inference code with minimal changes required.
 
 - Hydra와 Lightning Fabric을 사용하여 딥러닝 학습 코드 template을 만들었다.
 - 예시를 위해 한국어 문장 분류 데이터인 KLUE YNAT 데이터를 사용했다.
 
1. 설치
- hydra와 lightning을 설치한다.
 
pip install hydra-core
pip install lightning
2. Config 정의
- hydra는 config 정의를 위해 yaml 파일을 사용한다.
 - 데이터, 모델, 학습에 사용하는 주요 변수들을 작성한다.
 - 이 변수들은 command line에서 실행할 때 override 할 수 있다. [url]
 - fabric arguments [url]
 
config.yaml
# data
data_path: klue
data_name: ynat
max_length: 128
batch_size: 16
labels: ['IT과학', '경제', '사회', '생활문화', '세계', '스포츠', '정치']
# model
model_name: klue/roberta-base
lr: 5e-5
# train
num_epochs: 3
ckpt_path: roberta-base-ynat
# fabric
fabric:
    accelerator: gpu
    strategy: ddp
    devices: [1, 2]
    precision: 32
3. 데이터
train_loader와test_loader를 정의하는 함수를 작성한다.- Sequence classification에서 사용하는 일반적인 방법을 따랐다.
 
main.py
def _collate_fn(batch, tokenizer, max_length):
    texts = [b['title'] for b in batch]
    labels = [b['label'] for b in batch]
    inputs = tokenizer(texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
    labels = torch.tensor(labels)
    return inputs, labels
def prepare_data(cfg, tokenizer):
    dataset = load_dataset(cfg.data_path, cfg.data_name)
    train_dataset, test_dataset = dataset['train'], dataset['validation']
    collate_fn = lambda x: _collate_fn(x, tokenizer, cfg.max_length)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)
    return train_loader, test_loader
3. 모델
model과optimizer를 정의하는 함수를 작성한다.
main.py
def prepare_model(cfg):
    num_labels = len(cfg.labels)
    id2label = {i:l for i,l in enumerate(cfg.labels)}
    model = AutoModelForSequenceClassification.from_pretrained(cfg.model_name, num_labels=num_labels, id2label=id2label)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    return model, optimizer
4. 학습 및 평가
- 학습과 평가를 위한 함수를 작성한다.
 - 예시에서는 epoch을 단위로 작성하였다. 필요에 따라 step 단위로 변경할 수 있다.
 fabric.is_global_zero: 여러 gpu 중 가장 첫 번째 gpu인지 여부를 확인할 수 있다. 모델 평가, 저장, 로깅의 조건으로 사용가능하다.
main.py
def train_epoch(fabric, model, optimizer, loader):
    model.train()
    pbar = tqdm(loader, disable=not fabric.is_global_zero)
    for inputs, labels in pbar:
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        fabric.backward(loss)
        optimizer.step()
        preds = outputs.logits.argmax(dim=-1)
        acc = (preds == labels).float().mean()
        pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})
def evaluate_epoch(model, loader):
    model.eval()
    all_preds, all_labels = [], []
    for inputs, labels in tqdm(loader):
        with torch.no_grad():
            outputs = model(**inputs)
        preds = outputs.logits.argmax(dim=-1)
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())
    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    acc = (all_preds == all_labels).float().mean().item()
    return acc
5. main
- 모든 프로세스를 실행하는 main 함수를 작성한다.
 fabric = L.Fabric(**cfg.fabric): Fabric instance를 생성한다.fabric.launch(): distributed process를 실행한다.fabric.setup(model, optimizer): device에 맞게 model과 optimizer를 준비한다.fabric.setup_dataloaders(train_loader, test_loader): device에 맞게 data loader를 준비한다.- 이 코드를 사용하지 않을 경우,
tensor.to(fabric.device)를 통해 명시적으로 device를 설정해야 한다. 
- 이 코드를 사용하지 않을 경우,
 fabric.is_global_zero인 경우에만 로깅, 평가, 저장을 진행한다.fabric.barrier(): 이 함수가 실행될 때 까지 모든 device를 기다린다. 모델 평가를 zero rank device에서만 진행하기 때문에 나머지 process에서 이를 기다리는 처리가 필요하다.
@hydra.main(version_base='1.3', config_path='.', config_name='config.yaml')
def main(cfg):
    print(cfg)
    fabric = L.Fabric(**cfg.fabric)
    fabric.launch()
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    train_loader, test_loader = prepare_data(cfg, tokenizer)
    model, optimizer = prepare_model(cfg)
    wrapped_model, optimizer = fabric.setup(model, optimizer)
    train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)
    best_acc = 0.
    for ep in range(cfg.num_epochs):
        fabric.barrier()
        train_epoch(fabric, model, optimizer, train_loader)
        if fabric.is_global_zero:
            acc = evaluate_epoch(model, test_loader)
            print(f'ep {ep} | acc {acc:.3f}')
            if acc > best_acc:
                tokenizer.save_pretrained(cfg.ckpt_path)
                model.save_pretrained(cfg.ckpt_path)
                best_acc = acc
if __name__ == '__main__':
    main()
6. 전체 코드
main.py
import hydra
import torch
import lightning as L
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def _collate_fn(batch, tokenizer, max_length):
    texts = [b['title'] for b in batch]
    labels = [b['label'] for b in batch]
    inputs = tokenizer(texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
    labels = torch.tensor(labels)
    return inputs, labels
def prepare_data(cfg, tokenizer):
    dataset = load_dataset(cfg.data_path, cfg.data_name)
    train_dataset, test_dataset = dataset['train'], dataset['validation']
    train_dataset = train_dataset.train_test_split(train_size=3000)['train']
    collate_fn = lambda x: _collate_fn(x, tokenizer, cfg.max_length)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)
    return train_loader, test_loader
def prepare_model(cfg):
    num_labels = len(cfg.labels)
    id2label = {i:l for i,l in enumerate(cfg.labels)}
    model = AutoModelForSequenceClassification.from_pretrained(cfg.model_name, num_labels=num_labels, id2label=id2label)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    return model, optimizer
def train_epoch(fabric, model, optimizer, loader):
    model.train()
    pbar = tqdm(loader, disable=not fabric.is_global_zero)
    for inputs, labels in pbar:
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        fabric.backward(loss)
        optimizer.step()
        preds = outputs.logits.argmax(dim=-1)
        acc = (preds == labels).float().mean()
        pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})
def evaluate_epoch(model, loader):
    model.eval()
    all_preds, all_labels = [], []
    for inputs, labels in tqdm(loader):
        with torch.no_grad():
            outputs = model(**inputs)
        preds = outputs.logits.argmax(dim=-1)
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())
    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    acc = (all_preds == all_labels).float().mean().item()
    return acc
@hydra.main(version_base='1.3', config_path='.', config_name='config.yaml')
def main(cfg):
    print(cfg)
    fabric = L.Fabric(**cfg.fabric)
    fabric.launch()
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    train_loader, test_loader = prepare_data(cfg, tokenizer)
    model, optimizer = prepare_model(cfg)
    wrapped_model, optimizer = fabric.setup(model, optimizer)
    train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)
    best_acc = 0.
    for ep in range(cfg.num_epochs):
        fabric.barrier()
        train_epoch(fabric, model, optimizer, train_loader)
        if fabric.is_global_zero:
            acc = evaluate_epoch(model, test_loader)
            print(f'ep {ep} | acc {acc:.3f}')
            if acc > best_acc:
                tokenizer.save_pretrained(cfg.ckpt_path)
                model.save_pretrained(cfg.ckpt_path)
                best_acc = acc
if __name__ == '__main__':
    main()'ai' 카테고리의 다른 글
| BERTScore Knowledge Distillation (0) | 2023.07.06 | 
|---|---|
| Maximal Marginal Relevance를 사용한 뉴스 요약 (0) | 2023.07.04 | 
| 꼬맨틀 풀이 프로그램 개발 (0) | 2023.05.17 | 
| Transformer BERT Fine-tuning: Named Entity Recognition (0) | 2023.05.16 | 
| Transformer T5 Fine-tuning: Question Answering (0) | 2023.05.16 |