torch expand, expand_as, repeat 함수 차이 비교
이번 글에서는 파이토치에서 원소를 반복하여 텐서의 차원을 확장하는 함수들인
expand, expand_as, repeat 함수들의 사용법 차이를 분석해보는 시간을
가져보도록 하겠습니다.
파이토치 expand 함수
expand() 메소드는 원하는 차원 크기를 input으로 받아
텐서의 값들을 뒤쪽 axis에서 반복하여
확장된 차원의 반복 텐서를 생성합니다.
반복을 원하는 텐서의 size가 (x, y, z) 식이라면,
input으로는 (a, b, c, ... , x, y, z) 처럼
마지막 axis들의 크기는 input으로 넣은 차원과 동일한 크기여야 합니다.
첫 번째로, 1차원 텐서의 확장 예시입니다.
import torch
x = torch.tensor([1, 2, 3, 4]) # size = 4
# (A, B, C, ... , 4) 형태로 마지막 axis size만 4이면 확장 가능
x.expand(3, 4)
'''
tensor([[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]])'''
x.expand(2, 2, 4)
'''
tensor([[[1, 2, 3, 4],
[1, 2, 3, 4]],
[[1, 2, 3, 4],
[1, 2, 3, 4]]])'''
x.expand(1, 2, 1, 4)
'''
tensor([[[[1, 2, 3, 4]],
[[1, 2, 3, 4]]]])'''
x.expand(1, 2, 8) # last axis size = 4가 아니라서 오류
# RuntimeError: The expanded size of the tensor (3) must match the existing size (4) at non-singleton dimension 2. Target sizes: [1, 2, 3]. Tensor sizes: [4]
크기가 4인 텐서를 반복하려면 마지막 axis 자리가 4가 되어야 하는 것을
확인할 수 있었습니다.
만일, 마지막 axis 자리가 불일치하는 경우에는 RuntimeError가 발생합니다.
이번에는 2차원 텐서의 확장 예시도 살펴보도록 하겠습니다.
x = torch.tensor([[1, 2],
[3, 4]]) # size = (2, 2)
# (A, B, C , ..., 2, 2) 꼴이면 허용
x.expand(3, 2, 2)
'''
tensor([[[1, 2],
[3, 4]],
[[1, 2],
[3, 4]],
[[1, 2],
[3, 4]]])'''
x.expand(4, 2, 3) # 마지막 두 자리가 (2, 2)가 아니라서 불가능
# RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 2. Target sizes: [4, 2, 3]. Tensor sizes: [2, 2]
(2, 2) 차원의 텐서를 확장시키기 위해서는
마지막 두 axis 크기가 2, 2로 끝나야하는 점을 확인할 수 있었습니다.
파이토치 expand_as 함수
파이토치의 expand_as() 메소드는 expand() 함수와 동일한 기능을 수행합니다.
다만, input으로 shape가 아닌 tensor가 직접 들어가는 점이 다릅니다.
x = torch.tensor([1, 2, 3, 4])
y = torch.ones(3, 4)
x.expand_as(y)
'''
tensor([[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]])'''
z = torch.ones(2, 2, 4)
x.expand_as(z)
'''
tensor([[[1, 2, 3, 4],
[1, 2, 3, 4]],
[[1, 2, 3, 4],
[1, 2, 3, 4]]])'''
x.expand(3, 4) 형태 대신 x.expand_as(y) 형태로 사용된 점을
눈여겨보시면 좋을 듯 합니다.
파이토치 repeat 함수
repeat() 메소드는 expand와 동일하게 shape를 input으로 받는데,
해당 shape 크기로 확장하는 것이 아니라
해당 shape 만큼 타일 형태로 tensor를 쌓는 기능을 수행합니다.
1차원과 2차원 텐서에 대해서 repeat 함수를 적용한 예시를 살펴보겠습니다.
# 1차원 텐서 예시
x = torch.tensor([1, 2, 3, 4])
x.repeat(3, 4) # 해당 모양을 3 * 4 크기로 반복
'''
tensor([[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]])'''
x.repeat(2, 2, 1) # x.expand(2, 2, 4)와 같은 결과
'''
tensor([[[1, 2, 3, 4],
[1, 2, 3, 4]],
[[1, 2, 3, 4],
[1, 2, 3, 4]]])'''
# 2차원 텐서 예시
y = torch.tensor([[1, 2],
[3, 4]])
y.repeat(2, 3)
'''
tensor([[1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4],
[1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4]])'''
repeat(3, 4)는 (3, 4) shape으로 텐서를 만드는 것이 아니라,
쌓기를 원하는 ([1, 2, 3, 4]) 전체를 (3, 4) 만큼 쌓아 총 (3, 12)의 size로 만들어주고
있는 상황을 이해해주시면 좋습니다.
2차원 텐서의 경우도 원래 텐서의 위상을 유지한채로
타일을 쌓는 듯한 원리로 input shape만큼 쌓아 차원을 확장합니다.
'Python > Pytorch' 카테고리의 다른 글
[Pytorch] 텐서 쌓기 함수 torch.cat(), torch.stack() 비교 (0) | 2022.01.27 |
---|---|
[Pytorch] squeeze와 unsqueeze 함수 사용법 정리 (2) | 2022.01.25 |
[Pytorch] 쿠다 버전 확인, 파이토치 버전 체크, 업데이트 방법 (0) | 2022.01.23 |