Python/Pytorch

[Pytorch] 텐서 쌓기 함수 torch.cat(), torch.stack() 비교

jimmy_AI 2022. 1. 27. 22:48
반응형

torch cat vs stack 함수 차이

이번 글에서는 파이토치에서 텐서를 쌓는 경우 사용하게 되는

cat과 stack 함수의 차이와 사용 방법에 대해서 살펴보도록 하겠습니다.

 

먼저, 다음과 같이 간단한 (2, 3) shape의 2차원 텐서 2개를 선언하겠습니다.

import torch

# (2, 3) 사이즈 2차원 텐서 2개 생성
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

b = torch.tensor([[7, 8, 9],
                  [10, 11, 12]])

이제 이 두 개의 텐서를 가지고 cat과 stack 함수를 적용해보겠습니다.

 

 

파이토치 cat 함수 사용법

cat 함수는 원하는 dimension 방향으로 텐서를 나란하게 쌓아줍니다.

 

예를 들어,

(a, x, y)와 (b, x, y) 차원의 두 텐서를 dim = 0 방향으로 쌓는 경우

(a + b, x, y) 차원의 텐서가 형성됩니다.(x, y는 두 텐서에서 동일해야 합니다.)

torch.cat([a, b], dim = 0)
'''
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]]), size = (2+2, 3) = (4, 3)'''

torch.cat([a, b], dim = 1)
'''
tensor([[ 1,  2,  3,  7,  8,  9],
        [ 4,  5,  6, 10, 11, 12]]), size = (2, 3+3) = (2, 6)'''

a와 b를 dim = 0으로 쌓은 경우는 (2+2, 3) = (4, 3) size의 텐서가 되고,

dim = 1으로 쌓은 경우에는 (2, 3+3) = (2, 6) size의 텐서가 되었습니다.

 

원리를 그림으로 표현하면 아래 그림과 같이 표현이 됩니다.

반응형

파이토치 stack 함수 사용법

반면, stack 함수는 텐서를 새로운 차원에 차곡차곡 쌓아주는 기능을 수행합니다.

 

예시로, (x, y, z) 사이즈의 텐서 3개를 dim = 2 방향으로 쌓는다면

(x, y, 3, z) 형태로 (x, y, z) 텐서 3개를 쌓게 됩니다.

(같은 사이즈의 텐서끼리만 쌓을 수 있습니다.)

 

cat 함수에서는 (x, y, 3 * z) 형태로 쌓였던 점을 참고해주시면

이 두 함수의 차이를 이해하기가 편리합니다.

torch.stack([a, b], dim = 0)
'''
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]]), size = (2, 2, 3)'''

torch.stack([a, b], dim = 1)
'''
tensor([[[ 1,  2,  3],
         [ 7,  8,  9]],

        [[ 4,  5,  6],
         [10, 11, 12]]]), size = (2, 2, 3)'''

torch.stack([a, b], dim = 2)
'''
tensor([[[ 1,  7],
         [ 2,  8],
         [ 3,  9]],

        [[ 4, 10],
         [ 5, 11],
         [ 6, 12]]]), size = (2, 3, 2)'''

dim = 0, 1 방향으로 쌓은 두 예제도 원소의 순서가 쌓인 방향의 차이에 의해

달라진 점을 확인해주시면 좋습니다.

 

여기서는 2개의 텐서를 쌓았기에 쌓인 방향의 dimension size가 2가 되었습니다.

또한, 마지막 차원의 다음 차원에 쌓을 공간을 만드는 것까지 가능하므로,

2차원 텐서의 경우 쌓는 공간의 위치를 dim = 0, 1, 2 중에서 고를 수 있었습니다.

 

stack 함수의 원리도 그림으로 표현하면 아래와 같이 나타낼 수 있습니다.