반응형
안녕하세요. 이번 시간에는 파이토치에서 추론에 사용되는 eval() 모드와
gradient를 생략하기 위해서 사용되는 torch.no_grad()의 차이에 대해서
간략하게 비교 예제로 차이를 이해해보도록 하겠습니다.
모델 예시
다음과 같은 아주 간단한 모델과 텐서가 있다고 가정해보겠습니다.
Dropout이 있는 상황을 주목해주시면 좋습니다.
import torch
import torch.nn as nn
# 모델 정의
model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Dropout(p=0.5)
)
# 예시 input
input_tensor = torch.randn(1, 10)
코드 출력 비교 예시
model.train()
output_train_1 = model(input_tensor)
output_train_2 = model(input_tensor)
print("Training mode 출력 1:", output_train_1)
print("Training mode 출력 2:", output_train_2)
Training mode 출력 1: tensor([[0.0000, 1.0664, 0.0000, 0.2099, 0.0000]], grad_fn=<MulBackward0>)
Training mode 출력 2: tensor([[0.0000, 1.0664, 0.0000, 0.0000, 1.3870]], grad_fn=<MulBackward0>)
train 모드로 실행하면 Dropout이 랜덤하게 적용되어 출력도 다를 수 있고,
gradient 추적이 되어 backward 연산이 가능합니다.
model.eval()
output_eval_1 = model(input_tensor)
output_eval_2 = model(input_tensor)
print("Eval mode 출력 1:", output_eval_1)
print("Eval mode 출력 2:", output_eval_2)
Eval mode 출력 1: tensor([[0.0000, 0.5332, 0.0000, 0.1050, 0.6935]], grad_fn=<ReluBackward0>)
Eval mode 출력 2: tensor([[0.0000, 0.5332, 0.0000, 0.1050, 0.6935]], grad_fn=<ReluBackward0>)
반면, eval 모드로 실행하면 Dropout은 적용되지 않아 출력이 일정합니다.
그러나, gradient 추적은 여전히 가능하고, backward 연산도 가능한 상태입니다.
with torch.no_grad():
output_no_grad_1 = model(input_tensor)
output_no_grad_2 = model(input_tensor)
print("no grad mode 출력 1:", output_no_grad_1)
print("no grad mode 출력 2:", output_no_grad_2)
no grad mode 출력 1: tensor([[0.0000, 0.5332, 0.0000, 0.1050, 0.6935]])
no grad mode 출력 2: tensor([[0.0000, 0.5332, 0.0000, 0.1050, 0.6935]])
torch.no_grad() 모드로 실행한 경우에도 Dropout은 적용되지 않습니다.
여기서 중요한 점은 gradient 추적이 중단되고, 따라서 이렇게 연산을 진행한 경우에는
backward 연산이 막히는 것이 가장 큰 특징입니다.
이 외에도 BatchNorm 연산 시, eval 모드에서는 고정된 통계값으로만 정규화를 진행하지만,
torch.no_grad() 시에는 train 모드와 유사하게 배치마다 정규화를 진행한다는 특징도 다릅니다.
이 글이 파이토치 모델 구현 과정에 도움이 되셨다면 좋겠습니다.
잘 봐주셔서 감사드립니다.
'Python > Pytorch' 카테고리의 다른 글
| [Pytorch] onnx -> pth 파일 변환 방법 정리 (0) | 2025.09.22 |
|---|---|
| [Pytorch] 파이토치 ReLU 함수 종류 총정리(ReLU, LeakyReLU, PReLU, RReLU, ReLU6) (0) | 2025.05.18 |
| [Pytorch] checkpoint vs torchscript vs onnx 모델 속도 비교 (0) | 2023.09.14 |