Python/NLP Code

Transformers 라이브러리에서 학습한 모델 저장, 불러오기 방법

jimmy_AI 2022. 3. 10. 20:36
반응형

Transformers model save, load

Hugging Face에서 제공하는 Transformers 라이브러리의 모델들을

학습 뒤 저장하는 방법과, 저장된 모델을 불러오는 방법에 대해서 살펴보겠습니다.

 

 

모델 저장 방법 : save_pretrained(디렉토리)

예를 들어, Transformers의 BertForMaskedLM, TFAutoModelWithLMHead 모델을 불러와서

(BertForMaskedLM는 파이토치 기반, TFAutoModelWithLMHead는 텐서플로우 기반)

fine-tune 과정을 수행한 이후 학습된 모델을 파일로 저장하고 싶은 상황을 가정해보겠습니다.

from transformers import BertForMaskedLM, TFAutoModelWithLMHead

torch_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
tf_model = TFAutoModelWithLMHead.from_pretrained("t5-small")

###
# fine-tune 과정 수행
###

모델을 저장하는 방법은 save_pretrained 함수 내에 원하는 디렉토리를 input으로 적어주시면

저장이 완료됩니다.

 

이 때, 파이토치 기반 모델은 pt 확장자를, 텐서플로우 기반 모델은 h5 확장자로 지정해줍니다.

# 모델 저장 변수 이름.save_pretrained(원하는 디렉토리) 형태
torch_model.save_pretrained('model.pt') # 파이토치 기반 모델
tf_model.save_pretrained('model.h5') # 텐서플로우 기반 모델

 

 

모델 불러오기 방법 : from_pretrained(디렉토리)

저장된 model을 load하는 방법은 pre-train model을 가져오는 경우와 마찬가지로,

모델 클래스 이름.from_pretrained() 메소드로 가져오면 됩니다.

 

이 때, input으로 모델이 저장된 파일의 디렉토리를 적어주셔야 합니다.

# 모델 클래스 이름.from_pretrained(저장된 디렉토리) 형태
torch_model = BertForMaskedLM.from_pretrained('model.pt')
tf_model = TFAutoModelWithLMHead.from_pretrained('model.h5')
반응형

파이토치/텐서플로우 기반 save, load 방법 이용

위 방법들이 Transformers 라이브러리에서 구현된 기능을 통한 저장/불러오기 방법이었다면,

여기서 구현된 모델이 기본적으로 Pytorch 혹은 Tensorflow로 구현된 모델이라는 점을 통하여

파이토치 및 텐서프로우의 모델 저장 및 불러오기 과정을 그대로 사용할 수도 있습니다.

 

Pytorch 기반의 save 및 load 방법은 아래 글을 참고해보세요.

 

[Pytorch] 파이토치 모델 저장, 불러오기 방법

torch model save, load 예제 이번 글에서는 파이토치에서 학습된 모델을 저장하고, 저장된 모델을 다시 불러오는 방법을 파라미터만 저장하는 방법과 모델 전체를 save하는 방법으로 나누어서 설명해

jimmy-ai.tistory.com

Tensorflow 기반의 save/load 매뉴얼에 대해서는 공식 document의 링크를 남기겠습니다.

 

모델 저장과 복원  |  TensorFlow Core

TensorFlow.js의 새로운 온라인 과정에서 웹 ML을 통해 0에서 영웅으로 거듭나십시오. 지금 등록하세요 모델 저장과 복원 모델 진행 상황은 훈련 중 및 훈련 후에 저장할 수 있습니다. 즉, 모델이 중

www.tensorflow.org