PyTorch AMP 기본 구조

####################################################### PyTorch AMP를 활용한 mixed precision learning 예시####################################################### 모델 생성
model = Net().cuda()
# Optimizer 생성
optimizer = optim.SGD(model.parameters(), ...)

# AMP : loss scale을 위한 GradScaler 생성
scaler = GradScaler()

### 학습 시작for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

# AMP : Forward pass 진행# AMP : autocast를 통한 자동 FP32 -> FP16 변환 (가능한 연산에 한하여)with autocast():
            output = model(input)
            loss = loss_fn(output, target)

# AMP : scaled loss를 이용해 backward 진행 (gradient 모두 같은 scale factor로 scale됨)# AMP : backward pass는 autocast 영역 내에 진행될 필요 없음# AMP : forward pass에서 사용된 같은 data type으로 backward pass는 실행됨
        scaler.scale(loss).backward()

# AMP : scaler.step은 가장 먼저 unscale(grad를 scale factor만큼 나눠기)# AMP : weight update 실시, 단 만약 grad 중 infs or NaNs이 있으면 step 스킵됨
        scaler.step(optimizer)

# AMP : scale factor 업데이트
        scaler.update()

위 코드는 AMP example[2]에서 가져온 코드로 전형적인 AMP를 이용한 딥러닝 학습 코드이다. 코드 중 AMP와 관련된 부분은 주석 "AMP :"로 표시하였다. 일반적인 학습 과정 중,

  1. forward pass를 autocast 상태에서 진행하는 것과
  2. backward pass 시 gradient를 scale하고, weight update 시 gradient를 unscale하는 과정이 추가되었다. 그것을 제외하고는 학습 코드에 큰 변화가 없어 편리성이 뛰어나다.

정리하자면, automatic mixed precision learning 시 torch.autocast(torch.cuda.amp.autocast), torch.cuda.amp.GradScaler API만을 사용하면 된다.

torch.autocast (torch.cuda.amp.autocast)

autocast는 context manager로서, autocast가 선언된 코드 영역에서는 mixed precision 연산이 진행된다. 이 영역 내에서 연산들은 FP16(BF16) or FP32 중 autocast가 선택한 data type으로 연산이 되는데, FP16으로 변환되어 연산 목록은 밑에 정리해 두었다. 따로 타입 변환을 위한 함수를 호출할 필요 없이 영역에서 실행되는 것만으로 기준에 따라 데이터 타입이 변환된다.

autocast는 딥러닝 네트워크 학습 시 forward pass(loss를 계산하는 것까지)에서만 선언되어야 한다. Backward pass는 forward pass에서 선택된 data type으로 맞춰져서 실행된다.

autocast는 thread local이기에, 여러 thread에서 학습 실행 시, 모든 thread에서 각각 autocast를 선언해줘야 한다. Multi GPUs를 사용하거나 multiple nodes를 사용할 때 주의해져야 한다.

FP16(BF16)으로 자동 타입 변환되는 연산 목록

matmul, addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul, multi_dot, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, GRUCell, linear, LSTMCell, matmul, mm, mv, prelu, RNNCell

torch.cuda.amp.GradScalar

FP16 데이터 타입으로 gradient를 저장할 시, 작은 크기를 가지는 gradients는 zero가 될 가능성(underflow 현상)있다. 이 경우 weight update가 제대로 되지 않아 네트워크 학습이 수렴하지 않을 수 있다. 원래 gradients가 0.00001인데 underflow 현상에 의해 0이 되면 아무리 gradients가 발생하더라도 gradients는 0이기에 weight 값은 변하지 않는다.

이를 해결하는 방법이 loss를 어떤 수(scale factor)만큼 곱해 크게 만드는 loss scaling 기법이다. Backward pass에서 gradient를 계산할 때, 그 값을 scaling 하여 작은 gradient를 큰 수로 만든다. 이를 통해 gradients가 0이 되는 underflow를 방지할 수 있다. 다만 gradient가 실제 값보다 scale factor만큼 곱해진 값을 가지므로 weight update 시에는 scale factor만큼 나눠주는 unscale이 필수이다. (gradients는 FP16인데 비해, weights는 FP32이므로 weights update 시 scale factor로 나눠주어도 underflow 발생확률이 매우 작다.)

GradScaler 클래스는 이를 수행하기 위해 만들어진 클래스로 loss를 scale()하여 backward 계산을 실행한다. 이후 step() 함수를 이용해 weight update를 진행한다. 이때 loss가 scale factor만큼 곱해졌기 때문에 weight update 전에 unscale을 진행한다.

GradScalar 내의 scale factor는 미리 정해진 수이다. Scale factor는 gradient 계산 시에만 이용되고 weight update시에는 unscale되기에 적절한 알려진 수이기만 하면 된다. 다만 scale factor가 너무 클 경우 gradient를 FP16으로 표현할 수 있는 수보다 크게 scale할 수 있다(이를 overflow라고 함). overflow가 발생하면 gradient가 매우 이상한 값(inf, NaN 등)을 가지게 되므로 학습을 한번에 diverge하게 할 수 있다. 따라서 overflow가 발생하면 다음 두 가지 방법으로 학습을 안정화한다.

  1. Scaled gradients가 inf, NaN이 되면, step() 함수는 skip되어 해당 gradients는 weight update에 사용되지 않고 버린다.

  2. Scale factor가 크다고 판단되어 update() 함수를 통해 더 작은 수로 교체된다. 미리 정해진  backoff_factor 만큼 곱해져 scale factor의 크기는 감소된다.

다만 scale factor가 작아지기만 하면 작은 gradient가 underflow되는 현상이 발생할 수 있다. 따라서 growth_interval 만큼의 iteration동안 overflow가 발생하지 않는다면 growth_factor를 곱하여 scale factor의 크기를 키운다.

참고

PyTorch AMP - 2 : AMP 코드 구조 및 autocast & GradScalar 정리