Python/Pytorch

[Pytorch] 실시간 파일 불러오기 방식 Dataset/Dataloader 구현 예제

jimmy_AI 2023. 9. 13. 22:57
반응형

파이토치에서 전체 데이터셋의 크기가 너무 크거나 다양한 경로에 나눠져 있는 경우

real-time으로 파일을 불러오는 방식의 데이터셋을 구현할 필요가 있습니다.

 

이 글에서는 실시간으로 파일을 불러오는 데이터셋의 형태를 구현하는 예제를

간략하게 정리해보도록 하겠습니다.

 

 

상황 가정

이해를 돕기 위하여 아래와 같이 images 디렉토리 내에 여러 이미지 파일

저장된 상황을 가정해보도록 하겠습니다.

지금은 이미지 파일이 4개뿐이지만, 아주 많은 수의 이미지가 있는 경우

모든 데이터를 메모리에 동시에 올리는 것이 어려울 수 있습니다.

 

이런 문제를 해결하기 위하여 real-time 방식의 데이터셋 및 데이터로더를 정의하게 되면

메모리 부족 문제도 해결되면서 iteration 마다 다른 Data Augmentation 기법을 적용하는

Dynamic Augmentation이 가능해져 학습 효과를 향상할 수도 있습니다.

또한, 디렉토리 내 내용이 바뀌면 다음 번에 학습 데이터로 가져올 수 있어

Online 학습의 상황에서도 유리할 것입니다.

 

우선, 불러올 파일의 디렉토리의 목록을 리스트로 가져와 보겠습니다.

import os

dir_name = 'images'
dir_list = os.listdir(dir_name)
dir_list = [os.path.join(dir_name, x) for x in dir_list]

print(dir_list) # ['images/1.png', 'images/2.png', 'images/3.png', 'images/4.png']

 

각 이미지에 대한 label 정보도 아래와 같이 있다고 가정해 보겠습니다.

labels = [0, 0, 1, 1]

 

반응형

 

Real-Time 방식의 Dataset/Dataloader 정의 방법

각 이미지 디렉토리를 기준으로 파일을 불러와 전처리(resize 등)을 수행하고,
Data Augmentation을 랜덤으로 적용하기 위한 함수들을 정의해보겠습니다.

여기서는 50% 확률로 flip을 적용하는 간단한 Augmentation을 예시로 들었습니다.

import numpy as np
import cv2

# 이미지 불러오기 + 전처리 함수
def image_read(img_dir):
    img_arr = cv2.imread(img_dir)
    img_arr = cv2.resize(img_arr, (256, 256)) # input size = 256 가정
    return img_arr

# Random Augmentation 함수
def image_aug(img_arr):
    is_flip = np.random.randint(2, size = 1)[0] # 0 or 1
    if is_flip == 1:
        img_arr = cv2.flip(img_arr, 0)
    return img_arr

 

Dataset 선언을 위하여 초기화, 길이 조회, iteration 과정을 init, len, getitem 매직 메소드
각각 정의해주어야 합니다.

 

데이터셋 선언에 대한 세부적인 내용이 궁금하시다면 아래 글을 참고해보세요.

 

03-07 커스텀 데이터셋(Custom Dataset)

앞 내용을 잠깐 복습해봅시다. 파이토치에서는 데이터셋을 좀 더 쉽게 다룰 수 있도록 유용한 도구로서 torch.utils.data.Dataset과 torch.utils.data.…

wikidocs.net

 

이번 예제에서는 이미지 파일 디렉토리 목록 및 라벨 목록을 input으로 받아
실시간 불러오기 방식의 Dataset을 생성하는 예시 코드를 작성해 보았습니다.

from torch.utils.data import Dataset
import torch

class MyDataset(Dataset):
    # 초기화: 디렉토리 목록 및 라벨 목록을 input으로 받음
    def __init__(self, dir_list, labels):
        self.dir_list = dir_list
        self.labels = labels
	
    # 길이: 디렉토리 목록의 길이를 데이터셋의 길이로 취급
    def __len__(self):
        return len(self.dir_list)
    
    # 순회: 목록에서 idx번째를 불러와 불러오기 및 증강 수행
    def __getitem__(self, idx):
        img_arr = image_read(self.dir_list[idx])
        img_arr = image_aug(img_arr)
        X = torch.FloatTensor(img_arr) / 255.0 # numpy array -> tensor 및 정규화
        y = self.labels[idx]

        return X, y

 

이렇게 정의된 Dataset을 DataLoader를 통하여 학습에 실제로 활용하는 법
아래의 코드를 참고해주시면 됩니다.

from torch.utils.data import DataLoader

dataset = MyDataset(dir_list, labels)
dataloader = DataLoader(dataset, batch_size=2)

 

위에서 선언한 DataLoader의 iteration 결과는 다음과 같이 나타납니다.

for batch in dataloader:
    X, y = batch
    print(X.shape)
    print(y)

# 출력 기록
torch.Size([2, 256, 256, 3])
tensor([0, 0])
torch.Size([2, 256, 256, 3])
tensor([1, 1])