Python/Pytorch

[Pytorch] contiguous 원리와 의미

jimmy_AI 2022. 2. 8. 16:15
반응형

torch의 contiguous에 대해서

안녕하세요. 이번 시간에는

파이토치에서 메모리 내에서의 자료형 저장 상태로 등장하는 contiguous의

원리와 의미에 대해서 간단히 살펴보도록 하겠습니다.

 

 

contiguous 여부와 stride 의미

간단한 예시를 들어 설명하기 위해서

shape이 (4, 3)으로 동일한 두 tensor a, b를 다음과 같이 선언해보겠습니다.

import torch

a = torch.randn(3, 4)

a.transpose_(0, 1)

b = torch.randn(4, 3)

# 두 tensor는 모두 (4, 3) shape
print(a)
'''
tensor([[-0.7290,  0.7509,  1.1666],
        [-0.9321, -0.4360, -0.2715],
        [ 0.1232, -0.6812, -0.0358],
        [ 1.1923, -0.8931, -0.1995]])'''

print(b)
'''
tensor([[-0.1630,  0.1704,  1.8583],
        [-0.1231, -1.5241,  0.2243],
        [-1.3705,  1.2717, -0.6051],
        [ 0.0412,  1.3312, -1.2066]])'''

이제, a, b 텐서에 저장된 값들의 메모리 주소

axis 방향(오른쪽 방향 우선) 순서로 불러와보도록 하겠습니다.

# a 텐서 메모리 주소 예시
for i in range(4):
    for j in range(3):
        print(a[i][j].data_ptr())
'''
94418119497152
94418119497168
94418119497184
94418119497156
94418119497172
94418119497188
94418119497160
94418119497176
94418119497192
94418119497164
94418119497180
94418119497196'''

# b 텐서 메모리 주소 예시
for i in range(4):
    for j in range(3):
        print(b[i][j].data_ptr())
'''
94418119613696
94418119613700
94418119613704
94418119613708
94418119613712
94418119613716
94418119613720
94418119613724
94418119613728
94418119613732
94418119613736
94418119613740'''

각 데이터의 타입인 torch.float32 자료형은 4바이트이므로,

메모리 1칸 당 주소 값이 4씩 증가함을 알 수 있습니다.

 

그런데 자세히 보시면 b는 한 줄에 4씩 값이 증가하고 있지만,

a는 그렇지 않은 상황임을 알 수 있습니다.

 

위 상황을 요약하여 그림으로 표현하면 아래처럼 이해해볼 수 있습니다.

즉, b는 axis = 0인 오른쪽 방향으로 자료가 순서대로 저장됨에 비해,

a는 transpose 연산을 거치며 axis = 1인 아래 방향으로 자료가 저장되고 있었습니다.

 

여기서, b처럼 axis 순서대로 자료가 저장된 상태를 contiguous = True 상태라고 부르며,

a같이 자료 저장 순서가 원래 방향과 어긋난 경우를 contiguous = False 상태라고 합니다.

 

각 텐서에 stride() 메소드를 호출하여 데이터의 저장 방향을 조회할 수 있습니다.

또한, is_contiguous() 메소드로 contiguous = True 여부도 쉽게 파악할 수 있습니다.

a.stride() # (1, 4)
b.stride() # (3, 1)

a.is_contiguous() # False
b.is_contiguous() # True

여기에서 a.stride() 결과가 (1, 4)라는 것은

a[0][0] -> a[1][0]으로 증가할 때는 자료 1개 만큼의 메모리 주소가 이동되고,

a[0][0] -> a[0][1]로 증가할 때는 자료 4개 만큼의 메모리 주소가 바뀐다는 의미입니다.

 

 

contiguous 여부가 바뀌는 경우, contiguous() 메소드

텐서의 shape을 조작하는 과정에서 메모리 저장 상태가 변경되는 경우가 있습니다.

주로 narrow(), view(), expand(), transpose() 등 메소드를 사용하는 경우에

이 상태가 깨지는 것으로 알려져 있습니다.

 

해당 상태의 여부를 체크하지 않더라도 텐서를 다루는데 문제가 없는 경우가 많습니다.

다만, RuntimeError: input is not contiguous의 오류가 발생하는 경우에는

input tensor를 contiguous = True인 상태로 변경해주어야 할 수 있습니다.

 

이럴 때에는 아래 예시 코드처럼 contiguous() 메소드를 텐서에 적용하여

contiguous 여부가 True인 상태로 메모리 상 저장 구조를 바꿔줄 수 있습니다.

a.is_contiguous() # False

# 텐서를 contiguous = True 상태로 변경
a = a.contiguous()

a.is_contiguous() # True