Python/Pytorch

[Pytorch] 파이토치 특정 layer freeze 방법

jimmy_AI 2022. 4. 22. 00:48
반응형

파이토치 일부 layer의 파라미터만 freeze하기

파이토치에서 학습을 진행할 때, 특정 layer를 freeze하는 방법을 말씀드려보겠습니다.

 

이해를 돕기 위해 아래와 같은 매우 간단한 신경망이 있다고 가정해보겠습니다.

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, 2)
        self.fc3 = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))

my_net = Net()

 

예를 들어, 위의 신경망에서 fc2에 해당하는 layer를 freeze하고 싶은 경우

선언된 신경망 my_net에 아래의 코드를 적용해주시면 됩니다.

for name, child in my_net.named_children():
    for param in child.parameters():
        if name == 'fc2': # 원하는 layer 이름 지정
            param.requires_grad = False # 해당 layer freeze

 

위 코드 실행 이후 아래의 코드를 통하여 각 layer의 파라미터 상태를 출력해보면

freeze를 시킨 layer에만 requires_grad = True가 없는 것을 확인할 수 있습니다.

for name, child in my_net.named_children():
    for param in child.parameters():
        print(name, param)