반응형
torch model save, load 예제
이번 글에서는 파이토치에서 학습된 모델을 저장하고, 저장된 모델을 다시 불러오는 방법을
파라미터만 저장하는 방법과 모델 전체를 save하는 방법으로 나누어서 설명해보겠습니다.
참고로, 이 글은 파이토치 공식 문서를 기반으로 작성되었습니다.
model save 방법 1 : 파라미터만 저장
state.dict() 메소드를 불러와 모델의 파라미터만 골라서 저장이 가능합니다.
이 방법은 모델의 클래스 종류와 argument를 아는 경우 용량을 절약할 수 있어 권장드립니다.
저장될 파일은 pt 확장자로 지정해주시면 됩니다.
# torch.save(모델이 저장된 변수 이름.state_dict(), 모델이 저장될 디렉토리)
torch.save(model.state_dict(), 'model.pt')
model save 방법 2 : 모델 전체를 저장
모델 전체의 모든 정보를 그대로 저장하는 방법입니다.
주로 커스터마이징이 많이되어 파라미터만 load할 경우 그대로 재현이 어렵다면
이 방법으로 학습된 모델을 저장하는 것을 권장드립니다.
여기서도 마찬가지로 파일의 확장자는 pt로 지정해주시면 됩니다.
# torch.save(모델이 저장된 변수 이름, 모델이 저장될 디렉토리)
torch.save(model, 'model.pt')
반응형
model load 방법 1 : 파라미터만 저장된 경우 불러오기
파라미터만 저장된 경우 모델을 불러오는 방법입니다.
이 경우, 기반이된 모델의 클래스 종류와 선언 시의 argument 조건을 알고있어야 합니다.
만일, CNN 클래스를 기반으로 작성된 모델이었던 경우 아래 예시처럼 작성해주시면 됩니다.
model = CNN() # 모델 선언 시의 파라미터 조건이 있던 경우 같이 추가
model.load_state_dict(torch.load('model.pt')) # load 함수 내에 저장 디렉토리 작성
model load 방법 2 : 모델 전체를 저장한 경우 불러오기
모델 전체를 저장했던 경우에는 load 시 클래스 종류, argument 정보가
필요하지 않다는 장점이 있습니다.
이 경우에는 load 함수를 통하여 바로 모델을 불러오는 것이 가능합니다.
model = torch.load('model.pt') # input으로 저장된 디렉토리만 지정하면 완료
'Python > Pytorch' 카테고리의 다른 글
[Pytorch] 파이토치 특정 layer freeze 방법 (0) | 2022.04.22 |
---|---|
[Pytorch] torch.view, torch.reshape의 사용법과 차이 비교 (1) | 2022.02.27 |
[Pytorch] 파이토치 설치 방법 정리 (1) | 2022.02.09 |