Python/Tensorflow

[Tensorflow] 텐서플로우 모델 구조 시각화 방법 : tf.keras.utils.plot_model

jimmy_AI 2022. 11. 12. 23:44
반응형

Visualize Tensorflow/Keras Model Structures

텐서플로우에서 구현된 모델의 구조를 plot_model 메소드를 통하여

쉽게 시각화할 수 있는 방법에 대하여 살펴보도록 하겠습니다.

 

먼저, 예시로 아래와 같이 간단한 CNN 구조의 모델이 있다고 가정해보도록 하겠습니다.

from tensorflow.keras import models, layers

# CNN 구조 모델 예시
model = models.Sequential()

model.add(layers.Conv2D(5, 3, strides = 1, padding = 'same', activation = 'relu', input_shape = (28, 28, 1)))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))

model.add(layers.Conv2D(10, 3, strides = 1, padding = 'same', activation = 'relu'))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))

model.add(layers.Flatten())
model.add(layers.Dense(64, activation = 'relu'))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(10, activation = 'softmax'))

 

모델의 구조를 시각화하여 이미지 파일로 저장한 예시 코드는 다음과 같습니다.

import tensorflow as tf

tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True)

 

반응형

 

저장된 이미지 파일을 열어보면 아래와 같이 시각화된 것을 확인해볼 수 있습니다.

 

dtype, layer_name 시각화 여부, 수평 방향 지정, dpi 등 다양한 옵션을 추가로

지정할 수 있는데 지원하는 상세한 옵션의 종류와 지정 방법은

아래의 텐서플로우 공식 문서를 참고해주시면 됩니다.

 

tf.keras.utils.plot_model  |  TensorFlow v2.10.0

Converts a Keras model to dot format and save to a file.

www.tensorflow.org

 

참고로, 해당 plot_model 메소드는 텐서플로우 2 버전에서는 모두 지원하고 있지만

비교적 구 버전에서는 일부 arguments 종류를 지원하지 않을 수도 있으니 유의해주세요.