Python 사이킷런 나이브 베이즈(NB) 분류 모델 학습하기
파이썬에서 scikit-learn의 기능을 활용하여 나이브 베이즈 분류기 학습 진행의 과정을
구현해보는 예제를 다루어보겠습니다.
데이터셋 불러오기
먼저, 이번 글에서 사용할 붓꽃 데이터셋을 불러오도록 하겠습니다.
from sklearn.datasets import load_iris
import pandas as pd
# 데이터셋 로드
iris = load_iris()
df = pd.DataFrame(data= iris.data ,
columns= ['sepal length', 'sepal width', 'petal length', 'petal width'])
df['target'] = iris.target
df
0, 1, 2로 표시된 3가지 종류의 target(꽃의 종류)를 4가지 feature의 정보를 바탕으로
예측하는 문제이며, 각 종류 별로 50개의 데이터씩 총 150개의 데이터로 구성되어 있습니다.
학습용/테스트용 데이터 분리
이어서, 학습된 분류기의 성능 검증을 위하여 먼저 학습용 데이터와 테스트용 데이터를
train_test_split 함수를 통하여 분리해보겠습니다.
여기서는 모든 feature 정보를 전부 사용하는 상황을 가정하며, test set의 비율은 30%로
설정하여 분리를 진행했습니다.
from sklearn.model_selection import train_test_split
# train, test 데이터셋 분리
X = df[df.columns[:-1]]
y = df['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 62)
나이브 베이즈 분류기 학습 과정 및 NB 함수 종류 고르는 팁
나이브 베이즈 분류 모델을 학습하는 과정은 사이킷런에서 적절한 NB 함수를 불러온 뒤,
train 용 데이터를 input으로 넣어 fit 메소드를 적용해주시면 간단하게 완료됩니다.
from sklearn.naive_bayes import CategoricalNB
# 나이브 베이즈 모델 선언 및 학습
model = CategoricalNB()
model.fit(X_train, y_train)
여기서는 0, 1, 2의 3가지 카테고리에 대한 분류를 진행할 것이기에
CategoricalNB 함수를 골라서 가져왔습니다.
참고로, 사이킷런에서는 총 5가지 종류의 NB 함수를 지원하는데,
각 함수를 골라야하는 상황을 정리하면 다음과 같습니다.
BernoulliNB : 가장 기본적인 NB 함수로 이진 분류 시에 사용합니다.
CategoricalNB : 분류할 카테고리의 종류가 3가지 이상일 때 사용합니다.
MultinomialNB : 텍스트의 등장 횟수처럼 이산적인 값의 수를 예측할 때 사용합니다.
GaussianNB : 예측할 값이 연속적인 값인 경우에 사용합니다.
ComplementNB : target label의 balance가 맞지 않는 불균형한 상황에 사용합니다.
다른 종류의 NB 함수의 사용 방법도 sklearn.naive_bayes 모듈에서 불러와 같은 방식으로
사용해주시면 됩니다.
정확도 성능 평가
이제 학습이 완료된 분류기에 대하여 분리해두었던 테스트 데이터를 가지고
라벨 예측의 정확도로 성능 평가를 진행해보겠습니다.
사이킷런의 accuracy_score 함수를 활용하여 아래와 같이 간단하게 성능 평가가 가능합니다.
from sklearn.metrics import accuracy_score
y_pred = model.predict(X_test) # 예측 라벨
print(accuracy_score(y_test, y_pred)) # 0.9111111111111111
여기서는 약 91% 정도의 분류 정확도를 기록하였습니다.
'Python > Sklearn' 카테고리의 다른 글
[Sklearn] 파이썬 단어 개수 세기 예제 : CountVectorizer 함수 (2) | 2022.08.29 |
---|---|
[Sklearn] 파이썬 Regularization : Lasso, Ridge, ElasticNet 적용하기 (0) | 2022.06.03 |
[Sklearn] 파이썬 MNIST 데이터셋 불러오기, 숫자 시각화 예제 (2) | 2022.05.31 |