Python/Debugging

transformers 모듈 model.generate() 과정 shape 관련 오류 해결

jimmy_AI 2022. 5. 24. 23:23
반응형

transformers generate 함수 RuntimeError, ValueError 디버깅

transformers 라이브러리의 generate 함수 사용 중 shape 미스매칭으로 인하여

발생할 수 있는 두 가지 오류에 대하여 정리해보도록 하겠습니다.

 

1. 추가 dimension을 가지는 경우

RuntimeError: The size of tensor a (100) must match the size of tensor b (10) at non-singleton dimension 1

 

2. 단일 sequence일 때, 추가 dimension이 필요한 경우

ValueError: not enough values to unpack (expected 2, got 1)

 

 

정상 실행 코드 예시

예를 들어, 100 토큰 길이의 sequence 10개를 한 batch로 삼아 generate를 하는 과정이

정상 실행되기 위한 token의 shape는 아래와 같습니다.

import torch
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

## 토큰 추출 코드 생략 가정 ##

print(token_ex.shape) # torch.Size([10, 100])

generated_tokens = model.generate(token_ex) # 정상 실행

[batch size, sequence length] 형태의 shape를 가지고 있어야

model.generate 함수가 정상적으로 실행된다는 점에 유의해주세요.

반응형

shape가 추가 dimension을 가지는 경우

이 경우는 [1, 10, 100]처럼 [10, 100]에서 추가 dimension을 가질 때 발생하는 오류입니다.

 

RuntimeError가 발생하며, 출력되는 에러 메시지는 아래와 같습니다.

print(token_ex.shape) # torch.Size([1, 10, 100])

generated_tokens = model.generate(token_ex)
# RuntimeError: The size of tensor a (100) must match the size of tensor b (10) at non-singleton dimension 1

이 경우, 토큰 input을 넣어주기 전에 squeeze 함수로 추가 dimension을 제거하시면 됩니다.

print(token_ex.shape) # torch.Size([1, 10, 100])

token_ex = token_ex.squeeze(0) # 0번 위치의 axis 제거
print(token_ex.shape) # torch.Size([10, 100])

generated_tokens = model.generate(token_ex) # 정상 실행

 

 

단일 sequence일 때, 추가 dimension이 필요한 경우

batch size = 1일 때 발생 가능한 오류로, 이 경우에도 토큰 텐서의 shape으로는

[1, sequence length]이 필요한데, [sequence length]의 1차원 텐서가 토큰 input으로

들어가면 해당 오류가 발생합니다.

 

ValueError가 발생하며, 오류 메시지의 예시는 아래와 같습니다.

print(token_ex.shape) # torch.Size([100])

generated_tokens = model.generate(token_ex)
# ValueError: not enough values to unpack (expected 2, got 1)

여기서는 unsqueeze 함수를 통하여 새로운 dimension을 추가해주시면 오류가 해결됩니다.

print(token_ex.shape) # torch.Size([100])

token_ex = token_ex.unsqueeze(0) # 0번 위치에 axis 추가
print(token_ex.shape) # torch.Size([1, 100])

generated_tokens = model.generate(token_ex) # 정상 실행