반응형
파이토치 일부 layer의 파라미터만 freeze하기
파이토치에서 학습을 진행할 때, 특정 layer를 freeze하는 방법을 말씀드려보겠습니다.
이해를 돕기 위해 아래와 같은 매우 간단한 신경망이 있다고 가정해보겠습니다.
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(8, 4)
self.fc2 = nn.Linear(4, 2)
self.fc3 = nn.Linear(2, 1)
def forward(self, x):
return self.fc3(self.fc2(self.fc1(x)))
my_net = Net()
예를 들어, 위의 신경망에서 fc2에 해당하는 layer를 freeze하고 싶은 경우
선언된 신경망 my_net에 아래의 코드를 적용해주시면 됩니다.
for name, child in my_net.named_children():
for param in child.parameters():
if name == 'fc2': # 원하는 layer 이름 지정
param.requires_grad = False # 해당 layer freeze
위 코드 실행 이후 아래의 코드를 통하여 각 layer의 파라미터 상태를 출력해보면
freeze를 시킨 layer에만 requires_grad = True가 없는 것을 확인할 수 있습니다.
for name, child in my_net.named_children():
for param in child.parameters():
print(name, param)
'Python > Pytorch' 카테고리의 다른 글
[Pytorch] GPU 여부 확인, 사용할 GPU 번호 지정, 모델 및 텐서에 GPU 할당 방법 (4) | 2022.04.24 |
---|---|
[Pytorch] 파이토치 모델 저장, 불러오기 방법 (0) | 2022.03.08 |
[Pytorch] torch.view, torch.reshape의 사용법과 차이 비교 (1) | 2022.02.27 |