파이토치 허브 메소드 정리, 사용 예제
Pytorch에서 github repository 등의 위치에 구현되어 있는 ResNet 등의 pre-train된 모델을
쉽게 불러올 수 있는 hub 기능에 대하여 구현된 함수들의 종류를 살펴보고
간단한 사용 예제들에 대하여 다루어 보겠습니다.
torch.hub의 더 상세한 사용법이 궁금하시다면 아래 링크의 공식 사이트를 참고해주세요.
(해당 포스팅도 해당 공식 글의 내용을 참조하여 작성되었습니다.)
torch.hub.list : 사용 가능 모델 목록 출력
원하는 깃허브 리포지토리 내의 이용 가능한 모델 이름들의 목록을 반환하는 함수입니다.
예를 들어, 0.10.0 버전의 pytorch/vision 리포지토리 내의 모델 목록을 출력하려면
아래와 같이 코드를 작성해주시면 되며, 출력 결과는 다음과 같습니다.
import torch
torch.hub.list('pytorch/vision:v0.10.0')
참고로, torch.hub의 메소드들(list, help 및 load)에서 재다운로드 및 repo 활용 옵션을
force_reload, skip_validation 및 trust_repo 인자의 값을 설정하여
조정할 수 있습니다.(윗 링크의 공식 글을 참고하세요.)
torch.hub.help : 모델을 불러올 때 지정할 수 있는 인자 출력
불러오기를 원하는 모델에 대하여 추가로 지정할 수 있는 인자의 종류를
출력하는 메소드도 torch.hub에서 지원하고 있습니다.
예를 들어, pytorch/vision 리포지토리 내에서 resnet50 모델의 설명을 출력하는 코드와
출력 결과는 아래와 같습니다.
# 리포지토리, 모델 이름 순으로 input 지정
print(torch.hub.help('pytorch/vision:v0.10.0', 'resnet50'))
여기서는 pretrained와 progress 인자를 지정할 수 있는 것으로 보입니다.
torch.hub.load : 모델 불러오기
torch.hub 내에서 가장 중요한 메소드로, 원하는 repository 내의 특정 모델을
바로 불러올 수 있게 해주는 메소드입니다.
예를 들어, pytorch/vision 리포지토리 내에서 resnet50을 불러오는 과정의 코드는
다음과 같으며, help 출력 결과 사용 가능하다고 출력되었던
pretrained 및 progress 인자도 지정해보도록 하겠습니다.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True, progress = True)
또 다른 예시로 허깅페이스 내 transformers 모듈에 구현된 bert에 대한
tokenizer 및 model도 다음과 같이 torch.hub의 기능을 통해서도 불러올 수 있습니다.
# !pip install tqdm boto3 requests regex sentencepiece sacremoses 실행 필요
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased')
model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')
torch.hub.download_url_to_file : 특정 url의 파일 다운받기
특정 url 내에 저장된 object를 다운받을 수 있게 해주는 기능도 torch.hub에서 지원합니다.
기본적인 사용법은 torch.hub.download_url_to_file(다운로드 할 url, 저장받을 경로)
형태로 코드를 작성해주시면 됩니다.
torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
또한, 다운받을 object가 파이토치 모델인 경우 해당 model의 state_dict를 변수 내에 바로
저장할 수 있는 형태로 다운받는 함수도 torch.hub.load_state_dict_from_url로 지원합니다.
state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
'Python > Pytorch' 카테고리의 다른 글
[Pytorch] 텐서를 넘파이 배열, 리스트로 변환하는 방법 정리 (1) | 2022.08.12 |
---|---|
[Pytorch] 파이썬 Contrastive Learning 구현 예제(feat. SimCLR) (2) | 2022.07.20 |
[Pytorch] 체크포인트(checkpoint) 설명, 저장 및 불러오기 예제(epoch별, step별, best) (0) | 2022.07.18 |