텐서플로우 케라스 layer weight freezing
파이썬의 텐서플로우 케라스에서 모델 내의 특정 layer의 가중치를
학습 시에 freeze하는 방법에 대하여 살펴보겠습니다.
먼저, 다음과 같이 Dense layer 3개로 구성된 간단한 모델 구조를 가정해보겠습니다.
(각 layer의 이름은 fc1, fc2, fc3로 설정하였습니다.)
from keras.models import Sequential
from keras.layers import Dense
model = Sequential()
layer1 = Dense(3, activation='relu', name = 'fc1', input_shape = (4,))
layer2 = Dense(2, activation='relu', name = 'fc2')
layer3 = Dense(1, activation='relu', name = 'fc3')
model.add(layer1)
model.add(layer2)
model.add(layer3)
여기서는 모델 선언 시에 특정 layer를 freeze 하는 방법과
layer의 위치 및 이름을 기준으로 freeze 하는 방법으로 나누어서 설명하겠습니다.
1. 모델 선언 시 특정 layer freeze
Sequential로 선언된 모델 내에 add를 수행하기 전에
해당 layer의 trainable 속성을 False로 선언해주면 됩니다.
fc1과 fc3 layer를 freeze하는 코드는 아래와 같습니다.
model = Sequential()
layer1 = Dense(3, activation='relu', name = 'fc1', input_shape = (4,))
layer2 = Dense(2, activation='relu', name = 'fc2')
layer3 = Dense(1, activation='relu', name = 'fc3')
# add 전에 trainable = False 지정
layer1.trainable = False
layer3.trainable = False
model.add(layer1)
model.add(layer2)
model.add(layer3)
2. layer의 위치를 기준으로 freeze
layer의 순서를 정확하게 파악하고 있다면 몇 번째 layer인지를 기준으로
freeze하는 것이 가능합니다.
0, 2번째에 해당하는 fc1, fc3 layer를 다음과 같이 freeze할 수 있습니다.
for i, layer in enumerate(model.layers):
if i in [0, 2]:
layer.trainable = False
3. layer의 이름을 기준으로 freeze
때로는 layer의 정확한 순서가 몇 번째인지 매번 세는 것이 불편할 수 있습니다.
2번 방법을 응용하면 layer 선언 시 지정한 name을 기준으로도 freeze 지정이 가능합니다.
fc1, fc3의 이름을 가지는 layer를 freeze한 예시는 다음과 같습니다.
for layer in model.layers:
if layer.name in ['fc1', 'fc3']:
layer.trainable = False
layer freeze 시 주의 사항, 예제
위의 방법들로 layer freeze를 진행할 경우, trainable = False 지정 후
model.compile을 다시 실행시켜 주어야 layer의 freeze 여부가 제대로 반영됩니다.
아래는 3번 방법으로 fc1, fc3 layer를 freeze한 예시입니다.
먼저, 초기 가중치를 출력하여 살펴보겠습니다.
model.weights
layer freeze 후 아래의 코드로 학습을 진행해 보겠습니다.
import numpy as np
model.compile(loss = 'mse', optimizer = 'adam', metrics = ['accuracy'])
X = np.random.randn(100, 4)
y = np.random.randint(2, size = (100, ))
model.fit(X, y)
다시 가중치를 출력해보면 fc1, fc3에 해당하는 가중치 부분은 위와 동일하고,
fc2에 해당하는 가중치만 변경된 것을 확인할 수 있습니다.
'Python > Tensorflow' 카테고리의 다른 글
[Tensorflow] 텐서플로우 모델 구조 시각화 방법 : tf.keras.utils.plot_model (0) | 2022.11.12 |
---|---|
[Tensorflow] TFDV 활용 파이썬 데이터 EDA 실습 예제 (0) | 2022.08.30 |
[Tensorflow] 파이썬 keras RNN/LSTM/GRU 구현 예제(IMDB 감성 분석) (0) | 2022.06.17 |