Python/Pytorch

[Pytorch] squeeze와 unsqueeze 함수 사용법 정리

jimmy_AI 2022. 1. 25. 22:48
반응형

torch squeeze vs unsqueeze

이번 글에서는 파이토치에서 squeeze와 unsqueeze 함수의

용도와 사용 예시에 대해서 살펴보도록 하겠습니다.

 

이 글은 파이토치의 squeeze, unsqueeze 함수에 대한

공식 문서를 바탕으로 작성되었습니다.

 

 

Pytorch squeeze 함수 사용 방법

참고로, 지난 번에 numpy의 squeeze 함수 사용법에 대하여 다룬 글

있었는데, torch 라이브러리에서도 원리는 거의 비슷합니다.

 

[Numpy] np.squeeze 함수 사용법과 의미

파이썬 넘파이 np.squeeze() 함수 : 크기가 1인 axis 제거 이번 시간에는 파이썬 넘파이 배열에서 크기가 1인 추가 axis를 제거하는 np.squeeze 함수의 사용법과 의미에 대해서 간단히 살펴보도록 하겠습

jimmy-ai.tistory.com

torch squeeze 함수의 원리는 (A x B x 1 x C x 1) 형태의 텐서에서

차원이 1인 부분을 제거하여 (A x B x C) 형태로 만들어 주는 것입니다.

 

또한, 원하는 dimension 위치를 따로 선택하면, 해당 위치의 1만 삭제가 가능합니다.

단, 해당 차원 위치의 size가 1이 아니라면, 삭제가 불가능합니다.

import torch

# (A, B, 1, C, 1) 차원 형태 텐서
x = torch.ones(10, 5, 1, 3, 1)

# size가 1인 차원 전체 삭제 : (A, B, C) 차원 형태
x1 = x.squeeze() # torch.squeeze(x) 가능
x1.shape # torch.Size([10, 5, 3])

# size가 1인 차원 일부 삭제 : (A, B, 1, C) 차원 형태
x2 = x.squeeze(dim = 2) # x.squeeze(2) 가능
x2.shape # torch.Size([10, 5, 3, 1])

x3 = x.squeeze(dim = -1) # dim = 4와 동일한 결과
x3.shape # torch.Size([10, 5, 1, 3])

# size가 1이 아닌 차원 삭제 시도(불가능)
x4 = x.squeeze(dim = 1)
x4.shape # torch.Size([10, 5, 1, 3, 1])

크기가 1인 차원인 2, 4번 dim은 정상 삭제가 가능했으나,

dim = 1처럼 size가 1이 아닌 경우는 삭제가 이루어지지 않았습니다.

 

참고로, dim = -1처럼 뒤쪽 차원부터 접근도 가능하며,

torch.squeeze(tensor, dim)형태 사용과 dim = 2 대신 2처럼만 작성하는 것도 가능합니다.

 

 

Pytorch unsqueeze 함수 사용법

unsqueeze 메소드는 위의 함수와 반대의 기능을 가지고 있습니다.

지정한 dimension 자리에 size가 1인 빈 공간을 채워주면서 차원을 확장합니다.

 

여기서도 torch.unsqueeze(tensor, dim) 형태로도 사용이 가능하며,

'dim ='을 생략하고 첫 번째 인자에 차원의 위치를 바로 적어주어도 됩니다.

x = torch.ones(3, 5, 7)

# 1번과 2번 사이에 dimension 추가
x1 = x.unsqueeze(dim = 1)
x1.shape # torch.Size([3, 1, 5, 7])

# 마지막 자리에 dimension 추가
x2 = x.unsqueeze(dim = -1) # dim = 3과 동일한 결과
x2.shape # torch.Size([3, 5, 7, 1])

# 오류가 발생하는 경우
x3 = x.unsqueeze(dim = 4)
# IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 4)

dim = 1처럼 지정할 경우, 기존 텐서의 1번과 2번 차원 사이에

크기가 1인 dimension이 추가되는 것을 확인할 수 있습니다.

 

만일, 원래 텐서의 차원보다 큰 숫자를 넣은 경우에는(3차원 텐서에서 dim = 4 예시)

해당 위치에 dimension 추가가 불가능하므로, IndexError가 발생합니다.