인공지능 논문정리/Vision 논문

[술술 읽히는 논문 요약] Supervised Contrastive Learning

jimmy_AI 2021. 10. 31. 22:20
반응형

Supervised Contrastive Learning

저자 : Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, Dilip Krishnan 외

학회 : Neural Information Processing Systems(NIPS)

연도 : 2020년

논문 링크 : https://arxiv.org/abs/2004.11362

 

실험 목적

Contrastive Learning : 가까운 대상은 가깝게, 먼 대상은 멀게 가상의 공간 내에 Mapping하는 모델을 학습

ex) 자연어 처리 -> Word2Vec, 비전 분야에서도 활발히 연구

 

비전 분야에서 기존 Contrastive Learning

고양이 A 사진 Augmentation 1 적용 <-> 고양이 A 사진 Augmentation 2 적용 : Positive

고양이 A 사진 Augmentation 1 적용 <-> 고양이 B 사진 Augmentation 1 적용 : Negative

 

-> 같은 label(고양이)를 가지는 이미지인데 거리를 멀게 Mapping하게 만드는 것은 이상하다!

 

즉, 지도 학습 상의 label을 고려하여 Positive / Negative를 다시 정의하여 Contrastive Learning을 해보자는 것이 이 논문의 실험 목적

 

Methods

 

실험 세팅

 

한 미니 배치 내 이미지 n개 : x1,x2,...,xn

각 이미지 마다 랜덤하게 2개의 Augmentation 실행 : x¯ -> 총 2n개의 이미지가 미니 배치 내 존재

각 이미지에 대한 encoder : r=Enc(x¯) -> ResNet을 주로 사용

각 encoding 결과에 대한 projector : z=Proj(r) -> single layer MLP 사용(차원수 감소 ex. 2048 -> 128), 최종 활용시에는 버림

 

Loss 함수(Contrastive Loss)

1. Self-Supervised Contrastive Loss(기존 방법)

같은 이미지에서 등장한 output만을 positive로 간주하는 기존 방법

zi,zj(i)가 positive라고 가정, I는 미니 배치 내 전체 output 집합, τ는 temperature hyperparameter(0.1을 주로 사용)

Cross Entropy와 유사한 형태의 Contrastive Loss 사용

Lself=iIlogexp(zizj(i)/τ)aIiexp(ziza/τ)

 

2. Supervised Contrastive Loss(지도 학습 Label 이용 방법)

같은 Label을 가진 이미지는 모두 positive set이었다고 간주하는 새로운 방법

P(i)는 미니 배치 내 zi와 같은 Class에 속했던 데이터의 output 집합

log의 위치에 따라 변형된 2가지의 loss 함수 구상

Loutsup=iI1|P(i)|pP(i)logexp(zizp/τ)aIiexp(ziza/τ)

Linsup=iIlog(1|P(i)|pP(i)exp(zizp/τ)aIiexp(ziza/τ))

 

해당 Loss 함수의 특징 3가지

1. positive의 개수가 랜덤으로 설정되는 원리라 일반화에서 유리하게 작용

2. 미니 배치의 크기가 커질수록(negative 데이터가 많아질수록) 학습 성능이 높아짐 -> 기존 방법과 동일한 특징

3. hard positive / negative mining에서도 학습에 유리함을 보임

 

결과

Loss 함수 비교

LoutsupLinsup보다 약 10% 이상 높은 Top-1 정확도를 기록했는데,

log함수의 concave 성질에 의한 Jensen's Inequality에 따라 항상 LinsupLoutsup을 보장하기에,

Loutsup가 더 많은 내재된 정보를 학습시킬 수 있는 것으로 분석해볼 수 있었음

 

기존 기법과의 비교

ResNet-50을 Encoder로 사용했을 때, CIFAR10, CIFAR100, ImageNet 데이터 셋에서 Top-1과 Top-5 정확도 모두 SimCLR, Self-supervised loss 사용 버전 모델, Max-margin 모델을 압도했음

 

다른 모델 구조(ResNet-200 등), 다른 조합의 Augmentation 적용 시에도 기존 기법과 비슷하거나 더 좋은 성능을 보장했음을 보임

 

추가 실험 결과

Image Corruption에 비교적 강인한 모습을 보임

하이퍼 파라미터의 변동에도 다른 모델에 비해서 비교적 준수한 성능을 유지함

transfer learning의 활용에서는 추가 연구가 필요해보임