Python/Pytorch

[Pytorch] 파이토치 모델 저장, 불러오기 방법

jimmy_AI 2022. 3. 8. 20:22
반응형

torch model save, load 예제

이번 글에서는 파이토치에서 학습된 모델을 저장하고, 저장된 모델을 다시 불러오는 방법을

파라미터만 저장하는 방법과 모델 전체를 save하는 방법으로 나누어서 설명해보겠습니다.

 

참고로, 이 글은 파이토치 공식 문서를 기반으로 작성되었습니다.

 

 

model save 방법 1 : 파라미터만 저장

state.dict() 메소드를 불러와 모델의 파라미터만 골라서 저장이 가능합니다.

이 방법은 모델의 클래스 종류와 argument를 아는 경우 용량을 절약할 수 있어 권장드립니다.

 

저장될 파일은 pt 확장자로 지정해주시면 됩니다.

# torch.save(모델이 저장된 변수 이름.state_dict(), 모델이 저장될 디렉토리)
torch.save(model.state_dict(), 'model.pt')

 

 

model save 방법 2 : 모델 전체를 저장

모델 전체의 모든 정보를 그대로 저장하는 방법입니다.

주로 커스터마이징이 많이되어 파라미터만 load할 경우 그대로 재현이 어렵다면

이 방법으로 학습된 모델을 저장하는 것을 권장드립니다.

 

여기서도 마찬가지로 파일의 확장자는 pt로 지정해주시면 됩니다.

# torch.save(모델이 저장된 변수 이름, 모델이 저장될 디렉토리)
torch.save(model, 'model.pt')
반응형

model load 방법 1 : 파라미터만 저장된 경우 불러오기

파라미터만 저장된 경우 모델을 불러오는 방법입니다.

이 경우, 기반이된 모델의 클래스 종류와 선언 시의 argument 조건을 알고있어야 합니다.

 

만일, CNN 클래스를 기반으로 작성된 모델이었던 경우 아래 예시처럼 작성해주시면 됩니다.

model = CNN() # 모델 선언 시의 파라미터 조건이 있던 경우 같이 추가
model.load_state_dict(torch.load('model.pt')) # load 함수 내에 저장 디렉토리 작성

 

 

model load 방법 2 : 모델 전체를 저장한 경우 불러오기

모델 전체를 저장했던 경우에는 load 시 클래스 종류, argument 정보가

필요하지 않다는 장점이 있습니다.

 

이 경우에는 load 함수를 통하여 바로 모델을 불러오는 것이 가능합니다.

model = torch.load('model.pt') # input으로 저장된 디렉토리만 지정하면 완료