Python/Pytorch

[Pytorch] 파이토치 허브(torch.hub) 사용법

jimmy_AI 2022. 7. 28. 17:39
반응형

파이토치 허브 메소드 정리, 사용 예제

Pytorch에서 github repository 등의 위치에 구현되어 있는 ResNet 등의 pre-train된 모델을

쉽게 불러올 수 있는 hub 기능에 대하여 구현된 함수들의 종류를 살펴보고

간단한 사용 예제들에 대하여 다루어 보겠습니다.

 

torch.hub의 더 상세한 사용법이 궁금하시다면 아래 링크의 공식 사이트를 참고해주세요.

(해당 포스팅도 해당 공식 글의 내용을 참조하여 작성되었습니다.)

 

torch.hub — PyTorch 1.12 documentation

torch.hub Pytorch Hub is a pre-trained model repository designed to facilitate research reproducibility. Publishing models Pytorch Hub supports publishing pre-trained models(model definitions and pre-trained weights) to a github repository by adding a simp

pytorch.org

 

 

torch.hub.list : 사용 가능 모델 목록 출력

원하는 깃허브 리포지토리 내의 이용 가능한 모델 이름들의 목록을 반환하는 함수입니다.

 

예를 들어, 0.10.0 버전의 pytorch/vision 리포지토리 내의 모델 목록을 출력하려면

아래와 같이 코드를 작성해주시면 되며, 출력 결과는 다음과 같습니다.

import torch

torch.hub.list('pytorch/vision:v0.10.0')

참고로, torch.hub의 메소드들(list, help 및 load)에서 재다운로드 및 repo 활용 옵션을

force_reload, skip_validation 및 trust_repo 인자의 값을 설정하여

조정할 수 있습니다.(윗 링크의 공식 글을 참고하세요.)

 

 

torch.hub.help : 모델을 불러올 때 지정할 수 있는 인자 출력

불러오기를 원하는 모델에 대하여 추가로 지정할 수 있는 인자의 종류를

출력하는 메소드도 torch.hub에서 지원하고 있습니다.

 

예를 들어, pytorch/vision 리포지토리 내에서 resnet50 모델의 설명을 출력하는 코드와

출력 결과는 아래와 같습니다.

# 리포지토리, 모델 이름 순으로 input 지정
print(torch.hub.help('pytorch/vision:v0.10.0', 'resnet50'))

여기서는 pretrained와 progress 인자를 지정할 수 있는 것으로 보입니다.

 

반응형

 

torch.hub.load : 모델 불러오기

torch.hub 내에서 가장 중요한 메소드로, 원하는 repository 내의 특정 모델을

바로 불러올 수 있게 해주는 메소드입니다.

 

예를 들어, pytorch/vision 리포지토리 내에서 resnet50을 불러오는 과정의 코드는

다음과 같으며, help 출력 결과 사용 가능하다고 출력되었던

pretrained 및 progress 인자도 지정해보도록 하겠습니다.

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True, progress = True)

 

또 다른 예시로 허깅페이스 내 transformers 모듈에 구현된 bert에 대한

tokenizer 및 model도 다음과 같이 torch.hub의 기능을 통해서도 불러올 수 있습니다.

# !pip install tqdm boto3 requests regex sentencepiece sacremoses 실행 필요
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased')
model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')

 

 

torch.hub.download_url_to_file : 특정 url의 파일 다운받기

특정 url 내에 저장된 object를 다운받을 수 있게 해주는 기능도 torch.hub에서 지원합니다.

 

기본적인 사용법은 torch.hub.download_url_to_file(다운로드 할 url, 저장받을 경로)

형태로 코드를 작성해주시면 됩니다.

torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

 

또한, 다운받을 object가 파이토치 모델인 경우 해당 model의 state_dict를 변수 내에 바로

저장할 수 있는 형태로 다운받는 함수도 torch.hub.load_state_dict_from_url로 지원합니다.

state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')