진행 과정

trainer.fit() → sanity check (validation_step 몇 step 먼저 실행) → LightningModule.training_step() → LightningModule.validation_step()

장점

  1. data.to('cuda'), model.to('cuda')를 직접 다 안해줘도 된다.
  2. 유지보수가 쉽다.
  3. checkpoint에 저장이 되어 모델을 불러오기 쉽다.
  4. CustomLoop를 작성해 Trainer도 수정이 가능하다.

구성

모듈 다운로드

import torch
import torchmetrics
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint

GPU or CPU 사용

# Select the compute device: prefer CUDA when a GPU is present, else CPU.
USE_CUDA = torch.cuda.is_available()  # True iff a CUDA-capable GPU is usable
device = torch.device("cuda") if USE_CUDA else torch.device("cpu")
print("다음 기기로 학습합니다:", device)

모델 base line

class LitModel(LightningModule):
    """Baseline LightningModule skeleton for a classification task.

    NOTE: the layers (`nn.Sequential(...)`) and the metric
    (`torchmetrics.Accuracy(...)`) are placeholders to be filled in
    for a concrete model; instantiating this class as-is will fail.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(...)  # TODO: define the actual layers
        self.accuracy = torchmetrics.Accuracy(...)  # TODO: configure the metric

    def forward(self, x):
        """Run a forward pass through the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log cross-entropy loss for one training batch."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log cross-entropy loss for one validation batch."""
        # FIX: these bodies were tab-indented in the original, which raises
        # an indentation error in a space-indented file.
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log cross-entropy loss for one test batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with a fixed 1e-3 learning rate."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

학습, 검증, 테스트 데이터셋 불러오기

# Download MNIST once and build the train / validation / test loaders.
_mnist_kwargs = dict(root="data", download=True, transform=ToTensor())
training_data = datasets.MNIST(train=True, **_mnist_kwargs)
test_data = datasets.MNIST(train=False, **_mnist_kwargs)

# Hold out 5,000 of the 60,000 training images for validation.
train_dataset, val_dataset = random_split(training_data, [55000, 5000])

batch_size = 28
epochs = 10

train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

모델 트레이닝

# Train the model. The old `gpus=` argument is deprecated/removed in
# pytorch_lightning 2.x, and `gpus=0` contradicted the CUDA check above by
# forcing CPU; `accelerator="auto"` lets Lightning use the GPU when present.
model = LitModel()
trainer = Trainer(max_epochs=epochs, accelerator="auto", devices="auto")
trainer.fit(model, train_dataloader, val_dataloader)

모델 테스트

# Evaluate on the held-out test set. When no model is passed, Lightning
# restores the best checkpoint from the preceding `fit` run; passing
# `ckpt_path="best"` makes that choice explicit (and silences the
# implicit-default warning emitted by recent versions).
trainer.test(dataloaders=test_dataloader, ckpt_path="best")