Python/Debugging

RuntimeError: expected scalar type Long but found Float / RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Long' 해결

jimmy_AI 2022. 7. 23. 22:07
반응형

파이토치 자료형 관련 오류 디버깅

Pytorch 사용 중 자료형 타입과 관련하여 발생할 수 있는 에러 종류 2가지에 대하여

원인 및 해결법을 간단히 정리해보도록 하겠습니다.

 

 

1. RuntimeError: expected scalar type Long but found Float

해당 오류는 정수형 타입 중 하나인 long 타입으로 구현되어야 하는 함수에게

float와 같이 다른 자료형의 텐서가 주어질 때 발생할 수 있는 오류입니다.

 

예를 들어, 다음과 같이 nn.CrossEntropyLoss 함수를 적용하는 간단한 상황을 살펴보겠습니다.

import torch
import torch.nn as nn

X = torch.tensor([[1, 2, 0], [1, 0, 1]]).float() # 데이터 부분 : float형
y = torch.tensor([1, 2]).float() ### 라벨 부분 : float형 ###

loss = nn.CrossEntropyLoss()
loss(X, y) # 라벨 부분 y는 long 자료형이어야 함
# RuntimeError: expected scalar type Long but found Float

해당 함수에서는 라벨 텐서 y의 자료형이 long인 상태에 대하여 지원하고 있는데

float 자료형으로 주어지고 있기 때문에 해당 오류가 발생하고 있습니다.

 

이 경우, long 자료형으로 라벨 텐서 y의 타입을 캐스팅하여 아래와 같이 다시 함수를 적용하면

정상적으로 작동되는 것을 볼 수 있습니다.

X = torch.tensor([[1, 2, 0], [1, 0, 1]]).float() # 데이터 부분 : float형
y = torch.tensor([1, 2]).long() ### 라벨 부분 : long형 ###

loss = nn.CrossEntropyLoss()
loss(X, y) # tensor(0.6348) -> 정상 작동

 

반응형

 

2. RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Long'

위에서 다룬 오류와 마찬가지로 input 텐서의 자료형이 함수에서 지원되지 않는 타입인 경우

다음과 같은 유형과 같은 오류가 발생할 수도 있습니다.

 

앞에서 다루었던 상황에서 이번에는 데이터 텐서 X의 자료형을 long으로 선언해보겠습니다.

X = torch.tensor([[1, 2, 0], [1, 0, 1]]).long() ### 데이터 부분 : long형 ###
y = torch.tensor([1, 2]).long() # 라벨 부분 : long형

loss = nn.CrossEntropyLoss()
loss(X, y) # 데이터 부분 X가 long형일 수 없음
# RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Long'

위의 1번 오류와 다른 점이 있다면 여기서는 어떤 자료형으로 바꾸어주어야 하는지에 대한

내용이 명시적으로 제시되지는 않았습니다.

 

다만, 일반적으로 정수형인 경우는 실수형으로 재시도해보고 실수형이었다면 정수형으로

재시도해보는 것처럼 반대의 자료형 시도부터 해보신다면 오류 해결 확률이 높은 편입니다.

(오히려 여러 종류의 타입의 텐서가 input으로 가능할 수도 있습니다.)

X = torch.tensor([[1, 2, 0], [1, 0, 1]]).double() ### 데이터 부분 : double형 ###
y = torch.tensor([1, 2]).long() # 라벨 부분 : long형

loss = nn.CrossEntropyLoss()
loss(X, y) # tensor(0.6348, dtype=torch.float64) -> 정상 작동

위의 함수의 경우에는 float, double 등 실수 자료형으로 데이터 텐서 X를 캐스팅하여

다시 시도해주신다면 정상 작동이 되는 것을 보실 수 있습니다.