Python/Numpy

[Numpy] np.squeeze 함수 사용법과 의미

jimmy_AI 2022. 1. 14. 17:08
반응형

파이썬 넘파이 np.squeeze() 함수 : 크기가 1인 axis 제거

이번 시간에는 파이썬 넘파이 배열에서 크기가 1인 추가 axis를 제거하는

np.squeeze 함수의 사용법과 의미에 대해서 간단히 살펴보도록 하겠습니다.

 

먼저, 다음과 같은 2차원 배열처럼 보이는 3차원 배열이 있다고 가정해보겠습니다.

import numpy as np

a = np.array([[[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]]])

a.shape # (1, 3, 3)

겉으로 보기에는 3 * 3 크기의 2차원 배열처럼 생겼으나,

괄호가 가장 바깥쪽에 1개가 추가로 있어 실제 shape은 (3, 3)이 아닌 (1, 3, 3)입니다.

 

[1, 2, 3, 4]와 [[1, 2, 3, 4]]의 관계로 생각하시면 이해가 쉽습니다.

 

np.squeeze 함수의 기능은 [[1, 2, 3, 4]]를 [1, 2, 3, 4]의 형태로 바꿔주는 것으로

이해해주시면 됩니다. 이제, 실제 예제를 살펴보도록 하겠습니다.

 

 

np.squeeze 함수 기본 사용법

np.squeeze 함수의 사용법은 np.squeeze(배열) 혹은 배열.squeeze() 형태

지정해주시면 간단히 완료됩니다.

a.squeeze() # np.squeeze(a) 처럼도 가능
'''array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])'''

a.squeeze().shape # (3, 3)

a 배열의 바깥쪽 추가 axis가 제거되고 2차원 배열로 변환이 되었습니다.

 

이번에는 더 복잡한 경우로 2차원 배열처럼 보이는 4차원 배열의 예시입니다.

b = np.array([[[[1, 2, 3]],
              [[4, 5, 6]],
              [[7, 8, 9]]]])

b.shape # (1, 3, 1, 3)

b.squeeze()
'''array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]]), shape = (3, 3)'''

a 배열보다 안쪽 축에도 괄호가 1개씩 더 있는 상태로,

(1, 3, 1, 3) 형태의 shape를 가지고 있었습니다.(axis = 0, 2의 자리 크기가 1)

 

이 경우에도 squeeze 함수 1번으로 크기가 1인 axis 자리를 모두 제거하여

2차원 배열로의 변환이 성공적으로 완료되었습니다.

 

 

np.squeeze 함수 axis 설정

b 배열에서 크기가 1인 모든 추가 axis자리가 전부 제거되었는데,

일부 axis만 삭제하고 싶은 경우 axis 인자를 설정해주면 됩니다.

 

np.squeeze(배열, axis) 혹은 배열.squeeze(axis) 형태로 지정이 가능합니다.

 

b 배열에서 일부 axis만 제거하는 예시를 살펴보겠습니다.

b.squeeze(axis = 0) # np.squeeze(b, axis = 0) 처럼도 가능
'''array([[[1, 2, 3]],

       [[4, 5, 6]],

       [[7, 8, 9]]]), shape = (3, 1, 3)'''

b.squeeze(axis = 2)
'''array([[[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]]), shape = (1, 3, 3)'''

b.squeeze(axis = 1) # ValueError: cannot select an axis to squeeze out which has size not equal to one
b.squeeze(axis = 3) # ValueError: cannot select an axis to squeeze out which has size not equal to one

b 배열의 shape는 (1, 3, 1, 3)이었기에, axis = 0, 2자리는 크기가 1이고,

axis = 1, 3자리는 크기가 1이 아니었습니다.

 

axis = 0, 2에 해당하는 자리를 각각 지우는 것은 위의 예시처럼 가능했으나,

axis = 1, 3에 해당하는 자리를 지우려는 시도는 squeeze 함수의 기능에서

지원하지 않는 상황으로 ValueError가 발생하게 됩니다.

 

(1, 3, 1, 3)에서 axis = 0자리를 지우면 (3, 1, 3) shape가,

axis = 2자리를 지우면 (1, 3, 3) shape가 되고 있는 점도 참고해주시면 좋습니다.