Trainer.fit() → sanity check (LightningModule.validation_step()을 몇 배치만 미리 실행) → LightningModule.training_step()
모듈 임포트
import torch
import torchmetrics
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
GPU or CPU 사용
# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
USE_CUDA = torch.cuda.is_available()  # True only when a CUDA device is available
device = torch.device("cuda") if USE_CUDA else torch.device("cpu")
print("다음 기기로 학습합니다:", device)
모델 base line
class LitModel(LightningModule):
    """Baseline classifier wrapped as a LightningModule.

    Every step feeds the batch inputs through ``self.model`` and scores the
    output against the targets with cross-entropy.  Losses are logged as
    ``train_loss``, ``val_loss`` and ``test_loss``; validation and test
    additionally log ``val_acc`` / ``test_acc`` computed by ``self.accuracy``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): both ``...`` arguments are placeholders carried over
        # unchanged; supply the real layers and metric settings before use.
        self.model = nn.Sequential(...)
        self.accuracy = torchmetrics.Accuracy(...)

    def forward(self, x):
        """Return the output of ``self.model`` for input ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss of a training batch.

        Args:
            batch: An ``(inputs, targets)`` pair.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss tensor for this batch.
        """
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; logs ``val_loss`` and ``val_acc``."""
        return self._shared_eval_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; logs ``test_loss`` and ``test_acc``."""
        return self._shared_eval_step(batch, "test")

    def _shared_eval_step(self, batch, stage):
        """Compute loss and accuracy for one batch and log them under ``stage``.

        Args:
            batch: An ``(inputs, targets)`` pair.
            stage: Prefix for the logged metric names (``"val"`` or ``"test"``).

        Returns:
            The cross-entropy loss tensor for this batch.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # cross_entropy reads class scores along dim 1, so argmax over that
        # dim gives the predicted class index for each sample.
        batch_acc = self.accuracy(logits.argmax(dim=1), y)
        # Without these log calls no metric is reported by validate/test and
        # checkpoint callbacks have nothing to monitor.
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", batch_acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
학습, 검증, 테스트 데이터셋 불러오기
# MNIST training and test splits, stored under ./data (downloaded if missing)
# and transformed with ToTensor().
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Hold out a fixed number of training samples for validation; the training
# size is derived from the dataset length so the two always sum correctly.
VAL_SIZE = 5000
train_dataset, val_dataset = random_split(
    training_data, [len(training_data) - VAL_SIZE, VAL_SIZE]
)

batch_size = 28
epochs = 10

# Only the training loader is shuffled, so each epoch sees a new batch order;
# validation and test keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
모델 트레이닝
model = LitModel()
# Use the GPU when CUDA was detected above; the previous ``gpus=0`` pinned
# training to the CPU regardless of the detected device.
trainer = Trainer(
    max_epochs=epochs,
    accelerator="gpu" if USE_CUDA else "cpu",
    devices=1,
)
# Fit on the training loader, validating on the validation loader.
trainer.fit(model, train_dataloader, val_dataloader)
모델 테스트
# Evaluate on the test set using the checkpoint that the preceding
# trainer.fit() call recorded as best; requested explicitly instead of
# relying on the implicit default.
trainer.test(dataloaders=test_dataloader, ckpt_path="best")