Python/Pytorch

[Pytorch] 파이토치 과적합 방지(Early Stopping) 구현 방법 정리

jimmy_AI 2022. 9. 8. 23:26
반응형

파이토치 학습 과정에서 각 epoch가 끝나는 시점에서 validation loss 혹은

validation accuracy 등의 성능 지표를 측정하여 과적합이 의심되는 특정 시점을 넘어가는 경우

조기에 학습을 종료하는 early stopping 기능을 구현하는 방법들을 정리해보도록 하겠습니다.

 

 

1. 직접 구현

각 epoch가 끝나는 시점마다 evaluation을 진행 후, 성능 개선 여부를 감시하는 식의 코드를

간단하게 구현해볼 수 있습니다.

 

예시 pseudo-code의 형태는 다음과 같습니다.

best_loss = 10 ** 9 # 매우 큰 값으로 초기값 가정
patience_limit = 3 # 몇 번의 epoch까지 지켜볼지를 결정
patience_check = 0 # 현재 몇 epoch 연속으로 loss 개선이 안되는지를 기록

### 전체 학습 코드 스니펫 ###
for i in range(epochs):
	
    ### 각 epoch의 train 부분 ###
    model.train()

    for X, y in train_dataloader:
        y_pred = model(X)
        loss = loss_function(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    ### 각 epoch train 이후 evaluation 진행 ###
    model.eval()
    val_loss = 0

    for X, y in eval_dataloader:
        
        y_pred = model(X)
        loss = loss_function(y_pred, y)

        val_loss += loss.item()   
        
    ### early stopping 여부를 체크하는 부분 ###
    if val_loss > best_loss: # loss가 개선되지 않은 경우
        patience_check += 1

        if patience_check >= patience_limit: # early stopping 조건 만족 시 조기 종료
            break

    else: # loss가 개선된 경우
        best_loss = val_loss
        patience_check = 0

 

반응형

 

2. pytorchtools 모듈 이용

아래 깃허브 사이트에 있는 py 파일을 다운로드 받은 뒤, 해당 py 파일을 모듈로 import하여

EarlyStopping 함수를 사용할 수 있습니다.

https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py

 

GitHub - Bjarten/early-stopping-pytorch: Early stopping for PyTorch

Early stopping for PyTorch . Contribute to Bjarten/early-stopping-pytorch development by creating an account on GitHub.

github.com

 

해당 모듈의 EarlyStopping 함수를 적용한 예시 코드 스니펫은 아래와 같습니다.

1번의 직접 구현 방법에 비하여 상대적으로 간단한 편인 듯 합니다.

from pytorchtools import EarlyStopping # 위 링크의 깃허브 파일에서 임포트

# early_stopping 객체 선언(3번의 epoch 연속으로 loss 미개선 시에 조기 종료 예시)
early_stopping = EarlyStopping(patience = 3, verbose = True)

### 전체 학습 코드 스니펫 ###
for i in range(epochs):
	
    ### 각 epoch의 train 부분 ###
    model.train()

    for X, y in train_dataloader:
        y_pred = model(X)
        loss = loss_function(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    ### 각 epoch train 이후 evaluation 진행 ###
    model.eval()
    val_loss = 0

    for X, y in eval_dataloader:
        
        y_pred = model(X)
        loss = loss_function(y_pred, y)

        val_loss += loss.item()
        
    ### early stopping 여부를 체크하는 부분 ###
    early_stopping(val_loss, model) # 현재 과적합 상황 추적
    
    if early_stopping.early_stop: # 조건 만족 시 조기 종료
        break

 

 

3. torchsample 모듈 이용

또다른 구현 모듈인 torchsample의 코드 파일들을 아래의 깃허브 사이트에서 다운받아

여기에서 구현된 EarlyStopping 기능도 활용이 가능한 듯 합니다.

https://github.com/ncullen93/torchsample

 

GitHub - ncullen93/torchsample: High-Level Training, Data Augmentation, and Utilities for Pytorch

High-Level Training, Data Augmentation, and Utilities for Pytorch - GitHub - ncullen93/torchsample: High-Level Training, Data Augmentation, and Utilities for Pytorch

github.com

 

해당 깃허브 사이트를 참조해보았을 때, 대략적인 코드 스니펫은 아래의 형태인 것으로 보입니다.

from torchsample.modules import ModuleTrainer
from torchsample.callbacks import EarlyStopping

# 모델 객체를 moduletrainer에 선언
model = Network()
trainer = ModuleTrainer(model)

# early stopping callback 선언
callbacks = [EarlyStopping(monitor='val_loss', patience=5)]
trainer.set_callbacks(callbacks)

# fit 형태로 학습 선언이 가능
trainer.fit(x_train, y_train, 
            val_data=(x_test, y_test),
            num_epoch=20, 
            batch_size=128,
            verbose=1)