Python/Sklearn

[Sklearn] 파이썬 나이브 베이즈 분류기 구현 예제

jimmy_AI 2022. 6. 13. 22:40
반응형

Python 사이킷런 나이브 베이즈(NB) 분류 모델 학습하기

파이썬에서 scikit-learn의 기능을 활용하여 나이브 베이즈 분류기 학습 진행의 과정을

구현해보는 예제를 다루어보겠습니다.

 

 

데이터셋 불러오기

먼저, 이번 글에서 사용할 붓꽃 데이터셋을 불러오도록 하겠습니다.

from sklearn.datasets import load_iris
import pandas as pd

# 데이터셋 로드
iris = load_iris()
df = pd.DataFrame(data= iris.data , 
                  columns= ['sepal length', 'sepal width', 'petal length', 'petal width'])
df['target'] = iris.target

df

0, 1, 2로 표시된 3가지 종류의 target(꽃의 종류)를 4가지 feature의 정보를 바탕으로

예측하는 문제이며, 각 종류 별로 50개의 데이터씩 총 150개의 데이터로 구성되어 있습니다.

 

 

학습용/테스트용 데이터 분리

이어서, 학습된 분류기의 성능 검증을 위하여 먼저 학습용 데이터와 테스트용 데이터를

train_test_split 함수를 통하여 분리해보겠습니다.

 

여기서는 모든 feature 정보를 전부 사용하는 상황을 가정하며, test set의 비율은 30%로

설정하여 분리를 진행했습니다.

from sklearn.model_selection import train_test_split

# train, test 데이터셋 분리
X = df[df.columns[:-1]]
y = df['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 62)

 

반응형

 

나이브 베이즈 분류기 학습 과정 및 NB 함수 종류 고르는 팁

나이브 베이즈 분류 모델을 학습하는 과정은 사이킷런에서 적절한 NB 함수를 불러온 뒤,

train 용 데이터를 input으로 넣어 fit 메소드를 적용해주시면 간단하게 완료됩니다.

from sklearn.naive_bayes import CategoricalNB

# 나이브 베이즈 모델 선언 및 학습
model = CategoricalNB()
model.fit(X_train, y_train)

여기서는 0, 1, 2의 3가지 카테고리에 대한 분류를 진행할 것이기에

CategoricalNB 함수를 골라서 가져왔습니다.

 

참고로, 사이킷런에서는 총 5가지 종류의 NB 함수를 지원하는데,

각 함수를 골라야하는 상황을 정리하면 다음과 같습니다.

 

BernoulliNB : 가장 기본적인 NB 함수로 이진 분류 시에 사용합니다.

CategoricalNB : 분류할 카테고리의 종류가 3가지 이상일 때 사용합니다.

MultinomialNB : 텍스트의 등장 횟수처럼 이산적인 값의 수를 예측할 때 사용합니다.

GaussianNB : 예측할 값이 연속적인 값인 경우에 사용합니다.

ComplementNB : target label의 balance가 맞지 않는 불균형한 상황에 사용합니다.

 

다른 종류의 NB 함수의 사용 방법도 sklearn.naive_bayes 모듈에서 불러와 같은 방식으로

사용해주시면 됩니다.

 

 

정확도 성능 평가

이제 학습이 완료된 분류기에 대하여 분리해두었던 테스트 데이터를 가지고

라벨 예측의 정확도로 성능 평가를 진행해보겠습니다.

 

사이킷런의 accuracy_score 함수를 활용하여 아래와 같이 간단하게 성능 평가가 가능합니다.

from sklearn.metrics import accuracy_score

y_pred = model.predict(X_test) # 예측 라벨
print(accuracy_score(y_test, y_pred)) # 0.9111111111111111

여기서는 약 91% 정도의 분류 정확도를 기록하였습니다.