0. 들어가며
- Hydra는 python code의 configuration을 쉽게 관리하기 위한 라이브러리다.
- Lightning Fabric은 최소한의 코드 변경으로 Pytorch 모델의 효율적인 학습을 도와주는 라이브러리다.
- Fabric accelerates your PyTorch training or inference code with minimal changes required.
- Hydra와 Lightning Fabric을 사용하여 딥러닝 학습 코드 template을 만들었다.
- 예시를 위해 한국어 문장 분류 데이터인 KLUE YNAT 데이터를 사용했다.
1. 설치
- hydra와 lightning을 설치한다.
pip install hydra-core
pip install lightning
2. Config 정의
- hydra는 config 정의를 위해 yaml 파일을 사용한다.
- 데이터, 모델, 학습에 사용하는 주요 변수들을 작성한다.
- 이 변수들은 command line에서 실행할 때 override 할 수 있다. [url]
- fabric arguments [url]
config.yaml
# data
data_path: klue
data_name: ynat
max_length: 128
batch_size: 16
labels: ['IT과학', '경제', '사회', '생활문화', '세계', '스포츠', '정치']
# model
model_name: klue/roberta-base
lr: 5e-5
# train
num_epochs: 3
ckpt_path: roberta-base-ynat
# fabric
fabric:
accelerator: gpu
strategy: ddp
devices: [1, 2]
precision: 32
3. 데이터
train_loader와test_loader를 정의하는 함수를 작성한다.- Sequence classification에서 사용하는 일반적인 방법을 따랐다.
main.py
def _collate_fn(batch, tokenizer, max_length):
texts = [b['title'] for b in batch]
labels = [b['label'] for b in batch]
inputs = tokenizer(texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
labels = torch.tensor(labels)
return inputs, labels
def prepare_data(cfg, tokenizer):
dataset = load_dataset(cfg.data_path, cfg.data_name)
train_dataset, test_dataset = dataset['train'], dataset['validation']
collate_fn = lambda x: _collate_fn(x, tokenizer, cfg.max_length)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)
return train_loader, test_loader
3. 모델
model과optimizer를 정의하는 함수를 작성한다.
main.py
def prepare_model(cfg):
num_labels = len(cfg.labels)
id2label = {i:l for i,l in enumerate(cfg.labels)}
model = AutoModelForSequenceClassification.from_pretrained(cfg.model_name, num_labels=num_labels, id2label=id2label)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
return model, optimizer
4. 학습 및 평가
- 학습과 평가를 위한 함수를 작성한다.
- 예시에서는 epoch을 단위로 작성하였다. 필요에 따라 step 단위로 변경할 수 있다.
fabric.is_global_zero: 여러 gpu 중 가장 첫 번째 gpu인지 여부를 확인할 수 있다. 모델 평가, 저장, 로깅의 조건으로 사용가능하다.
main.py
def train_epoch(fabric, model, optimizer, loader):
model.train()
pbar = tqdm(loader, disable=not fabric.is_global_zero)
for inputs, labels in pbar:
optimizer.zero_grad()
outputs = model(**inputs, labels=labels)
loss = outputs.loss
fabric.backward(loss)
optimizer.step()
preds = outputs.logits.argmax(dim=-1)
acc = (preds == labels).float().mean()
pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})
def evaluate_epoch(model, loader):
model.eval()
all_preds, all_labels = [], []
for inputs, labels in tqdm(loader):
with torch.no_grad():
outputs = model(**inputs)
preds = outputs.logits.argmax(dim=-1)
all_preds.append(preds.cpu())
all_labels.append(labels.cpu())
all_preds = torch.cat(all_preds, dim=0)
all_labels = torch.cat(all_labels, dim=0)
acc = (all_preds == all_labels).float().mean().item()
return acc
5. main
- 모든 프로세스를 실행하는 main 함수를 작성한다.
fabric = L.Fabric(**cfg.fabric): Fabric instance를 생성한다.fabric.launch(): distributed process를 실행한다.fabric.setup(model, optimizer): device에 맞게 model과 optimizer를 준비한다.fabric.setup_dataloaders(train_loader, test_loader): device에 맞게 data loader를 준비한다.- 이 코드를 사용하지 않을 경우,
tensor.to(fabric.device)를 통해 명시적으로 device를 설정해야 한다.
- 이 코드를 사용하지 않을 경우,
fabric.is_global_zero인 경우에만 로깅, 평가, 저장을 진행한다.fabric.barrier(): 이 함수가 실행될 때 까지 모든 device를 기다린다. 모델 평가를 zero rank device에서만 진행하기 때문에 나머지 process에서 이를 기다리는 처리가 필요하다.
@hydra.main(version_base='1.3', config_path='.', config_name='config.yaml')
def main(cfg):
print(cfg)
fabric = L.Fabric(**cfg.fabric)
fabric.launch()
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
train_loader, test_loader = prepare_data(cfg, tokenizer)
model, optimizer = prepare_model(cfg)
wrapped_model, optimizer = fabric.setup(model, optimizer)
train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)
best_acc = 0.
for ep in range(cfg.num_epochs):
fabric.barrier()
train_epoch(fabric, model, optimizer, train_loader)
if fabric.is_global_zero:
acc = evaluate_epoch(model, test_loader)
print(f'ep {ep} | acc {acc:.3f}')
if acc > best_acc:
tokenizer.save_pretrained(cfg.ckpt_path)
model.save_pretrained(cfg.ckpt_path)
best_acc = acc
if __name__ == '__main__':
main()
6. 전체 코드
main.py
import hydra
import torch
import lightning as L
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def _collate_fn(batch, tokenizer, max_length):
texts = [b['title'] for b in batch]
labels = [b['label'] for b in batch]
inputs = tokenizer(texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
labels = torch.tensor(labels)
return inputs, labels
def prepare_data(cfg, tokenizer):
dataset = load_dataset(cfg.data_path, cfg.data_name)
train_dataset, test_dataset = dataset['train'], dataset['validation']
train_dataset = train_dataset.train_test_split(train_size=3000)['train']
collate_fn = lambda x: _collate_fn(x, tokenizer, cfg.max_length)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)
return train_loader, test_loader
def prepare_model(cfg):
num_labels = len(cfg.labels)
id2label = {i:l for i,l in enumerate(cfg.labels)}
model = AutoModelForSequenceClassification.from_pretrained(cfg.model_name, num_labels=num_labels, id2label=id2label)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
return model, optimizer
def train_epoch(fabric, model, optimizer, loader):
model.train()
pbar = tqdm(loader, disable=not fabric.is_global_zero)
for inputs, labels in pbar:
optimizer.zero_grad()
outputs = model(**inputs, labels=labels)
loss = outputs.loss
fabric.backward(loss)
optimizer.step()
preds = outputs.logits.argmax(dim=-1)
acc = (preds == labels).float().mean()
pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})
def evaluate_epoch(model, loader):
model.eval()
all_preds, all_labels = [], []
for inputs, labels in tqdm(loader):
with torch.no_grad():
outputs = model(**inputs)
preds = outputs.logits.argmax(dim=-1)
all_preds.append(preds.cpu())
all_labels.append(labels.cpu())
all_preds = torch.cat(all_preds, dim=0)
all_labels = torch.cat(all_labels, dim=0)
acc = (all_preds == all_labels).float().mean().item()
return acc
@hydra.main(version_base='1.3', config_path='.', config_name='config.yaml')
def main(cfg):
print(cfg)
fabric = L.Fabric(**cfg.fabric)
fabric.launch()
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
train_loader, test_loader = prepare_data(cfg, tokenizer)
model, optimizer = prepare_model(cfg)
wrapped_model, optimizer = fabric.setup(model, optimizer)
train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)
best_acc = 0.
for ep in range(cfg.num_epochs):
fabric.barrier()
train_epoch(fabric, model, optimizer, train_loader)
if fabric.is_global_zero:
acc = evaluate_epoch(model, test_loader)
print(f'ep {ep} | acc {acc:.3f}')
if acc > best_acc:
tokenizer.save_pretrained(cfg.ckpt_path)
model.save_pretrained(cfg.ckpt_path)
best_acc = acc
if __name__ == '__main__':
main()'ai' 카테고리의 다른 글
| BERTScore Knowledge Distillation (0) | 2023.07.06 |
|---|---|
| Maximal Marginal Relevance를 사용한 뉴스 요약 (0) | 2023.07.04 |
| 꼬맨틀 풀이 프로그램 개발 (0) | 2023.05.17 |
| Transformer BERT Fine-tuning: Named Entity Recognition (0) | 2023.05.16 |
| Transformer T5 Fine-tuning: Question Answering (0) | 2023.05.16 |