Python 파이토치 SimCLR Contrastive Learning 실습
이번 글에서는 파이썬에서 파이토치 모듈을 활용하여 SimCLR 논문의
Contrastive Learning 사례를 간략한 버전으로 구현해보도록 하겠습니다.
이번 사례에서는 설명 간략화를 위하여 비교적 간단한 MNIST 데이터셋을 사용해 보았으며,
모델로는 아주 단순한 CNN 구조를 가정하고 기법을 구현해 보았습니다.
또한, data augmentation은 cutout 이후 회전을 수행하는 1가지 방법에 대해서만
학습을 진행해보는 예제로 글을 구성하였습니다.
Step 1 : 데이터셋 불러오기
MNIST 데이터셋을 불러올 수 있는 방법은 torchvision 모듈을 활용할 수도 있지만
여기서는 사이킷런을 이용하여 데이터를 불러오고 직접 텐서로 바꿔주었습니다.
라벨 순서는 데이터 순서와 상관없이 뒤섞여 있다는 점을 고려하여
7만개의 데이터 중 앞 6만개를 학습용, 뒤 1만개를 테스트용으로 간주하였습니다.
from sklearn.datasets import fetch_openml
import torch
import numpy as np
mnist = fetch_openml('mnist_784')
# GPU 사용 지정
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# 7만개 중 앞 6만개 train 데이터 가정
X_train = torch.tensor(np.array(mnist.data)).float().reshape(-1, 1, 28, 28)[:60000].to(device)
y_train = torch.tensor(np.array(list(map(np.int_, mnist.target))))[:60000].to(device)
# 7만개 중 뒤 1만개 test 데이터 가정
X_test = torch.tensor(np.array(mnist.data)).float().reshape(-1, 1, 28, 28)[60000:].to(device)
y_test = torch.tensor(np.array(list(map(np.int_, mnist.target))))[60000:].to(device)
print(X_train.shape) # torch.Size([60000, 1, 28, 28])
print(y_train.shape) # torch.Size([60000])
print(X_test.shape) # torch.Size([10000, 1, 28, 28])
print(y_test.shape) # torch.Size([10000])
데이터셋을 불러오는 과정은 수 분 가량이 소요될 수 있으며,
이미지 픽셀 수와 CNN의 input 형태를 고려하여
각 이미지를 미리 1 * 28 * 28 차원으로 변환해주었습니다.
Step 2 : Data Augmentation 구현
랜덤으로 10 * 10 픽셀 부위를 골라 회색으로 마킹한 뒤,
왼쪽으로 90도만큼 회전하는 Data Augmentation 과정을 구현한 함수는 다음과 같습니다.
def cutout_and_rotate(image):
image = image.clone().detach() # 얕은 복사 문제 주의(원본 유지)
x_start = np.random.randint(20) # cut out 시작할 x축 위치(0~19 중 1개)
y_start = np.random.randint(20) # cut out 시작할 y축 위치(0~19 중 1개)
image[..., x_start:x_start+9, y_start:y_start+9] = 255 / 2 # 해당 부분 회색 마킹
return torch.rot90(image, 1, [-2, -1]) # 마지막 두 axis 기준 90도 회전
첫 번째 이미지를 가져와 Data Augmentation이 잘 되는지 여부를
시각화하여 확인해보는 예시 코드는 아래와 같습니다.
import matplotlib.pyplot as plt
from matplotlib.pyplot import style
# 흰색 배경 및 크기 지정
style.use('default')
figure = plt.figure()
figure.set_size_inches(4, 2)
# 흑백으로 출력하기 위한 스타일 설정
style.use('grayscale')
# 1 * 2 사이즈의 격자 설정
axes = []
for i in range(1, 3):
axes.append(figure.add_subplot(1, 2, i))
# 첫 이미지에 대한 원본 이미지 및 augmentation 수행된 이미지 시각화
img_example = X_train[0].clone().detach().cpu()
original = np.array(img_example).reshape(-1, 28).astype(int)
aug_img = np.array(cutout_and_rotate(img_example)).reshape(-1, 28).astype(int)
axes[0].matshow(original)
axes[1].matshow(aug_img)
# 제목 설정 및 눈금 제거
axes[0].set_axis_off()
axes[0].set_title('original')
axes[1].set_axis_off()
axes[1].set_title('augmentation')
plt.show()
Step 3 : CNN 모델 구조 구현
더 복잡한 CNN 구조나 ResNet 등의 모델도 활용이 가능하지만
여기서는 Convolution layer가 2개만 존재하는 간단한 CNN 구조를 가정해보았습니다.
각 layer를 거칠 때 마다 변화하는 shape는 주석에 포함시켜 두었습니다.
(각 이미지의 output 벡터의 차원은 100차원으로 가정하였습니다.)
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1)
self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, stride=1)
self.fc = nn.Linear(4 * 4 * 20, 100)
def forward(self, x):
x = F.relu(self.conv1(x)) # (batch, 1, 28, 28) -> (batch, 10, 24, 24)
x = F.max_pool2d(x, kernel_size=2, stride=2) # (batch, 10, 24, 24) -> (batch, 10, 12, 12)
x = F.relu(self.conv2(x)) # (batch, 10, 12, 12) -> (batch, 20, 8, 8)
x = F.max_pool2d(x, kernel_size=2, stride=2) # (batch, 20, 8, 8) -> (batch, 20, 4, 4)
x = x.view(-1, 4 * 4 * 20) # (batch, 20, 4, 4) -> (batch, 320)
x = F.relu(self.fc(x)) # (batch, 320) -> (batch, 100)
return x # (batch, 100)
Step 4 : Loss 함수 구현
SimCLR 논문에서는 N개의 이미지로 구성된 배치에서 각 이미지에서 augmentation된 N개의
이미지를 합쳐 총 2N개의 이미지를 최종 배치로 구성합니다.
이후 해당 이미지 - augmentation된 이미지 pair만 positive data(분자 부분)으로
간주하고 해당 이미지 - 나머지 이미지의 2N-2개 pair들은 negative data(분모 부분)으로
간주하여 아래의 식처럼 contrastive loss를 계산하게 됩니다.
GPU의 연산 효율성을 위하여 배치 내 연산을 행렬 형태로 한 번에 수행되게 만드는 것이
중요한데, 이 과정이 포함된 해당 loss 함수의 구현은 아래 사이트의 코드를 가져왔습니다.
# 출처 : https://medium.com/the-owl/simclr-in-pytorch-5f290cb11dd7
class SimCLR_Loss(nn.Module):
def __init__(self, batch_size, temperature):
super().__init__()
self.batch_size = batch_size
self.temperature = temperature
self.mask = self.mask_correlated_samples(batch_size)
self.criterion = nn.CrossEntropyLoss(reduction="sum")
self.similarity_f = nn.CosineSimilarity(dim=2)
# loss 분모 부분의 negative sample 간의 내적 합만을 가져오기 위한 마스킹 행렬
def mask_correlated_samples(self, batch_size):
N = 2 * batch_size
mask = torch.ones((N, N), dtype=bool)
mask = mask.fill_diagonal_(0)
for i in range(batch_size):
mask[i, batch_size + i] = 0
mask[batch_size + i, i] = 0
return mask
def forward(self, z_i, z_j):
N = 2 * self.batch_size
z = torch.cat((z_i, z_j), dim=0)
sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
# loss 분자 부분의 원본 - augmentation 이미지 간의 내적 합을 가져오기 위한 부분
sim_i_j = torch.diag(sim, self.batch_size)
sim_j_i = torch.diag(sim, -self.batch_size)
positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
negative_samples = sim[self.mask].reshape(N, -1)
labels = torch.from_numpy(np.array([0]*N)).reshape(-1).to(positive_samples.device).long()
logits = torch.cat((positive_samples, negative_samples), dim=1)
loss = self.criterion(logits, labels)
loss /= N
return loss
Step 5 : Training
이제 모델 변수를 선언하고 dataloader를 만들어 학습을 진행해주면 됩니다.
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
X_train_aug = cutout_and_rotate(X_train) # 각 X_train 데이터에 대하여 augmentation
X_train_aug = X_train_aug.to(device) # 학습을 위하여 GPU에 선언
dataset = TensorDataset(X_train, X_train_aug) # augmentation된 데이터와 pair
batch_size = 32
dataloader = DataLoader(
dataset,
batch_size = batch_size)
model = CNN() # 모델 변수 선언
loss_func = SimCLR_Loss(batch_size, temperature = 0.5) # loss 함수 선언
# train 코드 예시
epochs = 10
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for i in range(1, epochs + 1):
total_loss = 0
for data in tqdm(dataloader):
origin_vec = model(data[0])
aug_vec = model(data[1])
loss = loss_func(origin_vec, aug_vec)
total_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch : %d, Avg Loss : %.4f'%(i, total_loss / len(dataloader)))
loss 함수가 epoch 8까지는 잘 줄어들다가 이후에는 소폭 증가하는 모습을
보이고 있는 점을 관찰할 수 있었습니다.
Step 6 : 분류를 위한 다운스트림 모델 선언
위에서 학습된 CNN 구조의 모델에 class 개수만큼의 차원으로 projection을
진행하는 mlp layer를 장착하여 최종 class 분류를 위한 다운스트림 모델을 선언해보겠습니다.
여기서는 단일 mlp layer만을 이용하여 projection하는 상황을 가정해보았습니다.
class CNN_classifier(nn.Module):
def __init__(self, model):
super().__init__()
self.CNN = model # contrastive learning으로 학습해둔 모델을 불러오기
self.mlp = nn.Linear(100, 10) # class 차원 개수로 projection
def forward(self, x):
x = self.CNN(x) # (batch, 100)으로 변환
x = self.mlp(x) # (batch, 10)으로 변환
return x # (batch, 10)
이후 학습을 위해서 augmentation된 이미지들이 필요하지는 않으며
여기서는 해당 이미지와 라벨 간의 pair를 이루어 dataloader를 선언해주어야 합니다.
class_dataset = TensorDataset(X_train, y_train) # 데이터와 라벨 간의 pair
batch_size = 32
class_dataloader = DataLoader(
class_dataset,
batch_size = batch_size)
Step 7 : 분류 다운스트림 모델 학습 및 테스트
분류 다운스트림 모델을 학습하는 코드는 아래와 같습니다.
여기서 loss 함수의 종류로는 nn.CrossEntropyLoss()를 활용하였으며,
각 epoch마다 정답 개수를 세어 train data 기준 정확도를 출력하게 해보았습니다.
classifier = CNN_classifier(model).to(device) # 모델 선언, GPU 활용 지정
classifier_loss = nn.CrossEntropyLoss() # 분류를 위한 loss 함수
epochs = 10
classifier.train()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-4)
for i in range(1, epochs + 1):
correct = 0
for data in tqdm(class_dataloader):
logits = classifier(data[0])
loss = classifier_loss(logits, data[1].long())
optimizer.zero_grad()
loss.backward()
optimizer.step()
correct += torch.sum(torch.argmax(logits, 1) == data[1]).item() # 정확도 산출을 위하여 정답 개수 누적
print('Epoch : %d, Train Accuracy : %.2f%%'%(i, correct * 100 / len(X_train)))
epoch가 거듭될수록 train accuracy는 꾸준히 상승하는 모습을 보여주고 있었습니다.
이제 새로운 데이터에 대해서도 정확도가 충분히 나오도록 일반화된 학습이 잘되었는지 여부를
테스트 데이터셋을 통하여 검증해보도록 하겠습니다.
test_dataset = TensorDataset(X_test, y_test) # 테스트 데이터와 라벨 pair
batch_size = 32
test_dataloader = DataLoader(
test_dataset,
batch_size = batch_size)
classifier.eval() # 테스트 모드로 전환
correct = 0
for data in tqdm(test_dataloader):
logits = classifier(data[0])
correct += torch.sum(torch.argmax(logits, 1) == data[1]).item() # 정확도 산출을 위하여 정답 개수 누적
print('Test Accuracy : %.2f%%'%(correct * 100 / len(X_test)))
약 97% 이상의 test accuracy가 산출되는 것으로 보아 학습이 충분히 잘되었다는 점을
살펴볼 수 있었습니다.
'Python > Pytorch' 카테고리의 다른 글
[Pytorch] 파이토치 허브(torch.hub) 사용법 (0) | 2022.07.28 |
---|---|
[Pytorch] 체크포인트(checkpoint) 설명, 저장 및 불러오기 예제(epoch별, step별, best) (0) | 2022.07.18 |
[Pytorch] 파이썬 파이토치 데이터 병렬 처리 적용 예제 : nn.DataParallel (0) | 2022.07.14 |