산탄데르 고객만족 예측

https://www.kaggle.com/c/santander-customer-satisfaction

370개의 피처로 이루어진 데이터 세트 기반에서 고객만족 여부를 예측하는 것
피처이름은 모두 익명화되어 있어 어떤 속성인지는 알 수 없습니다.

클래스 레이블명은 TARGET이며 1이면 불만, 0이면 만족한 고객입니다.

모델의 성능평가는 ROC-AUC로 평가합니다.
대부분이 만족이고, 불만족인 데이터는 일부이기 때문에 단순 정확도 수치보다 ROU-AUC가 더 적합합니다.

데이터 전처리

클래스 값 컬럼을 포함하면 총 371개의 피처가 존재합니다.
피처들의 타입과 Null값을 더 살펴보겠습니다.

111개의 피처가 float형, 260개의 피처가 int형이며 NULL 값은 존재하지 않습니다.
다음으로 레이블인 TARGET의 분포를 살펴보겠습니다.

데이터프레임의 describe() 메소드를 사용해 각 피처의 값 분포를 간단히 확인해보겠습니다.

var3의 경우 최소값이 -999999로 나왔는데 1, 2, 3분위 수나 최댓값으로 미루어 보았을 때 결측치로 보입니다.

-999999의 값은 116개로 전체 데이터에 대해 비중이 적으므로 가장 값이 많은 2로 대체하겠습니다.
그리고 ID는 단순 식별자이므로 피처에서 드롭하겠습니다.
그 후에 피처들과 레이블 데이터를 분리해서 별도로 저장하겠습니다.

이후 성능평가를 위해서 현재 데이터 세트를 학습과 테스트 데이터 세트로 분리하겠습니다.
비대칭 데이터 세트이기 때문에 레이블의 분포가 원 데이터와 유사하게 추출되었는지 확인해보겠습니다.

데이터 세트 분할 결과 원 데이터셋과 유사하게 unsatisfied의 비율이 4%에 근접하는 수준이 되었습니다.

XGBoost 모델 학습과 하이퍼 파라미터 튜닝

XGBoost로 기본 세팅을 아래와 설정하고, 예측결과를 ROC-AUC로 평가해보겠습니다.

n_estimators = 500

early_stopping_rounds = 100

eval_metric = 'auc'

평가 데이터세트는 앞에서 분리한 테스트 데이터 세트를 이용하겠습니다.
테스트 데이터 세트를 XGBoost의 평가 데이터 세트로 사용하면 과적합이 될 우려가 있지만, 일단 진행하겠습니다.

eval_set = [(X_train, y_train), (X_test, y_test)]

테스트 데이터 세트로 예측시 ROU AUC는 약 0.8419가 나왔습니다.
다음으론 XGBoost의 하이퍼 파라미터 튜닝을 수행해보겠습니다.
컬럼의 수가 많아 과적합 가능성을 가정하고, max_depth, min_child_weight, colsample_bytree만
일차적으로 튜닝하겠습니다.
학습시간이 많이 필요한 ML모델의 경우 2~ 3개 정도의 파라미터를 결합해 최적파라미터를 찾아낸 뒤
이 파라미터를 기반으로 1~2개의 파라미터를 결합해 파라미터 튜닝을 수행하는 것입니다.

뒤의 예제 코드에서는 수행시간이 오래 걸리므로
n_estimators = 100, early_stopping_rounds = 30으로 줄여서 테스트하겠습니다.

파라미터가 위와 같을 때, 처음실행한 0.8419에서 0.8438로 성능이 소폭 좋아졌습니다.
이를 기반으로 n_estimator를 1000으로 증가시키고, learning_rate = 0.02로 감소시켰습니다.
그리고 reg_alpha = 0.03을 추가했습니다.

ROC AUC가 0.8441로 이전보다 조금 더 향상되었습니다.
튜닝된 모델에서 피처 중요도를 그래프로 나타내보겠습니다.

  1. 데이터 2020.12.09 17:07

    안녕하세요, 올려주신 내용 잘 보았습니다 감사합니다^^ 혹시 산탄데르 데이터는 어디서 받을 수 있을까요? 캐글은 닫혀 있는 것 같아서요~

  2. 박찬규 2021.02.28 13:06

    위키북스 담당자입니다. 현재 위키북스에서 출간한 책에 있는 내용을 무단으로 게제하고 있습니다. 바로 관련된 글을 삭제하지 않을 경우 저작권 위반으로 고발조치 됨을 알려드립니다. 빠른 조치해주시기 바랍니다.

1. LightGBM의 장단점

LightGBM의 장점

(1) XGBoost 대비 더 빠른 학습과 예측 수행 시간
(2) 더 작은 메무리 사용량
(3) 카테고리형 피처의 자동 변환과 최적 분할
: 원-핫인코딩 등을 사용하지 않고도 카테고리형 피처를 최적으로 변환하고 이에 따른 노드분할 수행

LightGBM의 단점

적은 데이터 세트에 적용할 경우 과적합이 발생하기 쉽습니다.
(공식 문서상 대략 10,000건 이하의 데이터 세트)

기존 GBM과의 차이점

일반적인 균형트리분할 (Level Wise) 방식과 달리 리프중심 트리분할(Leaf Wise) 방식을 사용합니다.

  • 균형트리분할은 최대한 균형 잡힌 트리를 유지하며 분할하여 트리의 깊이를 최소화하여
    오버피팅에 강한구조이지만 균형을 맞추기 위한 시간이 필요합니다.
  • 리프중심 트리분할의 경우 최대 손실 값을 가지는 리프노드를 지속적으로 분할하면서
    트리가 깊어지고 비대칭적으로 생성합니다. 이로써 예측 오류 손실을 최소화하고자 합니다.

 

2. LightGBM의 하이퍼 파라미터

하이퍼 파라미터 튜닝방안

num_leaves의 개수를 중심으로 min_child_sampes(min_data_in_leaf), max_depth를
함께 조절하면서 모델의 복잡도를 줄이는 것이 기본 튜닝 방안입니다.

  • num_leaves를 늘리면 정확도가 높아지지만 트리가 깊어지고 과접합되기 쉬움
  • min_child_samples(min_data_in_leaf)를 크게 설정하면 트리가 깊어지는 것을 방지
  • max_depth는 명시적으로 깊이를 제한. 위의 두 파라미터와 함꼐 과적합을 개선하는데 사용

또한, learning_rate을 줄이면서 n_estimator를 크게하는 것은 부스팅에서의 기본적인 튜닝 방안

 

3. LightGBM 적용 - 위스콘신 유방암 예측

LightGBM에서도 위스콘신 유방암 데이터 세트를 이용해 예측을 해보겠습니다.

feature importance 시각화

LightGBM도 plot_importance() 를 통해 시각화 가능

  1. 쿵짝작 2021.10.07 16:39 신고

    파라미터 설명까지 잘되어있어서 너무좋네요 감사합니다

4-5. Boosting Algorithm(XGBoost)
In [1]:
from IPython.core.display import display, HTML
display(HTML("<style> .container{width:90% !important;}</style>"))

1. XGBoost(eXtra Gradient Boost)의 개요

트리 기반의 알고리즘의 앙상블 학습에서 각광받는 알고리즘 중 하나입니다.
GBM에 기반하고 있지만, GBM의 단점인 느린 수행시간, 과적합 규제 등을 해결한 알고리즘입니다.

XGBoost의 주요장점

(1) 뛰어난 예측 성능
(2) GBM 대비 빠른 수행 시간
(3) 과적합 규제(Overfitting Regularization)
(4) Tree pruning(트리 가지치기) : 긍정 이득이 없는 분할을 가지치기해서 분할 수를 줄임
(5) 자체 내장된 교차 검증

  • 반복 수행시마다 내부적으로 교차검증을 수행해 최적회된 반복 수행횟수를 가질 수 있음
  • 지정된 반복횟수가 아니라 교차검증을 통해 평가 데이트세트의 평가 값이 최적화되면 반복을 중간에 멈출 수 있는 기능이 있음

(6) 결손값 자체 처리

XGBoost는 독자적인 XGBoost 모듈과 사이킷런 프레임워크 기반의 모듈이 존재합니다.
독자적인 모듈은 고유의 API와 하이퍼파라미터를 사용하지만, 사이킷런 기반 모듈에서는 다른 Estimator와 동일한 사용법을 가지고 있습니다.

2. XGBoost의 하이퍼 파라미터

일반 파라미터

: 일반적으로 실행 시 스레드의 개수나 silent 모드 등의 선택을 위한 파라미터, default 값을 바꾸는 일은 거의 없음

파라미터 명 설명
booster - gbtree(tree based model) 또는 gblinear(linear model) 중 선택
- Default = 'gbtree'
silent - Default = 1
- 출력 메시지를 나타내고 싶지 않을 경우 1로 설정
nthread - CPU 실행 스레드 개수 조정
- Default는 전체 다 사용하는 것
- 멀티코어/스레드 CPU 시스템에서 일부CPU만 사용할 때 변경

주요 부스터 파라미터

: 트리 최적화, 부스팅, regularization 등과 관련된 파라미터를 지칭

파라미터 명
(파이썬 래퍼)
파라미터명
(사이킷런 래퍼)
설명
eta
(0.3)
learning rate
(0.1)
- GBM의 learning rate와 같은 파라미터
- 범위: 0 ~ 1
num_boost_around
(10)
n_estimators
(100)
- 생성할 weak learner의 수
min_child_weight
(1)
min_child_weight
(1)
- GBM의 min_samples_leaf와 유사
- 관측치에 대한 가중치 합의 최소를 말하지만
GBM에서는 관측치 수에 대한 최소를 의미
- 과적합 조절 용도
- 범위: 0 ~ ∞
gamma
(0)
min_split_loss
(0)
- 리프노드의 추가분할을 결정할 최소손실 감소값
- 해당값보다 손실이 크게 감소할 때 분리
- 값이 클수록 과적합 감소효과
- 범위: 0 ~ ∞
max_depth
(6)
max_depth
(3)
- 트리 기반 알고리즘의 max_depth와 동일
- 0을 지정하면 깊이의 제한이 없음
- 너무 크면 과적합(통상 3~10정도 적용)
- 범위: 0 ~ ∞
sub_sample
(1)
subsample
(1)
- GBM의 subsample과 동일
- 데이터 샘플링 비율 지정(과적합 제어)
- 일반적으로 0.5~1 사이의 값을 사용
- 범위: 0 ~ 1
colsample_bytree
(1)
colsample_bytree
(1)
- GBM의 max_features와 유사
- 트리 생성에 필요한 피처의 샘플링에 사용
- 피처가 많을 때 과적합 조절에 사용
- 범위: 0 ~ 1
lambda
(1)
reg_lambda
(1)
- L2 Regularization 적용 값
- 피처 개수가 많을 때 적용을 검토
- 클수록 과적합 감소 효과
alpha
(0)
reg_alpha
(0)
- L1 Regularization 적용 값
- 피처 개수가 많을 때 적용을 검토
- 클수록 과적합 감소 효과
scale_pos_weight
(1)
scale_pos_weight
(1)
- 불균형 데이터셋의 균형을 유지

학습 태스크 파라미터

: 학습 수행 시의 객체함수, 평가를 위한 지표 등을 설정하는 파라미터

파라미터 명 설명
objective - ‘reg:linear’ : 회귀
- binary:logistic : 이진분류
- multi:softmax : 다중분류, 클래스 반환
- multi:softprob : 다중분류, 확륣반환
eval_metric - 검증에 사용되는 함수정의
- 회귀 분석인 경우 'rmse'를, 클래스 분류 문제인 경우 'error'
----------------------------------------------------
- rmse : Root Mean Squared Error
- mae : mean absolute error
- logloss : Negative log-likelihood
- error : binary classification error rate
- merror : multiclass classification error rate
- mlogloss: Multiclass logloss
- auc: Area Under Curve

과적합 제어

  • eta 값을 낮춥니다.(0.01 ~ 0.1) → eta 값을 낮추면 num_boost_round(n_estimator)를 반대로 높여주어야 합니다.
  • max_depth 값을 낮춥니다.
  • min_child_weight 값을 높입니다.
  • gamma 값을 높입니다.
  • subsample과 colsample_bytree를 낮춥니다.

Early Stopping 기능 :

GBM의 경우 n_estimators에 지정된 횟수만큼 학습을 끝까지 수행하지만, XGB의 경우 오류가 더 이상 개선되지 않으면 수행을 중지
n_estimators 를 200으로 설정하고, 조기 중단 파라미터 값을 50으로 설정하면, 1부터 200회까지 부스팅을 반복하다가
50회를 반복하는 동안 학습오류가 감소하지 않으면 더 이상 부스팅을 진행하지 않고 종료합니다.
(가령 100회에서 학습오류 값이 0.8인데 101~150회 반복하는 동안 예측 오류가 0.8보다 작은 값이 하나도 없으면 부스팅을 종료)

3. 파이썬 래퍼 XGBoost 적용

위스콘신 유방암 데이터 세트를 활용한 API 사용법

In [2]:
import xgboost as xgb ## XGBoost 불러오기
from xgboost import plot_importance ## Feature Importance를 불러오기 위함
import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score
import warnings
warnings.filterwarnings('ignore')

dataset = load_breast_cancer()
X_features = dataset.data
y_label = dataset.target

cancer_df = pd.DataFrame(data=X_features, columns = dataset.feature_names)
cancer_df['target'] = y_label
cancer_df.head(3)
Out[2]:
mean radius mean texture mean perimeter mean area mean smoothness mean compactness mean concavity mean concave points mean symmetry mean fractal dimension ... worst texture worst perimeter worst area worst smoothness worst compactness worst concavity worst concave points worst symmetry worst fractal dimension target
0 17.99 10.38 122.8 1001.0 0.11840 0.27760 0.3001 0.14710 0.2419 0.07871 ... 17.33 184.6 2019.0 0.1622 0.6656 0.7119 0.2654 0.4601 0.11890 0
1 20.57 17.77 132.9 1326.0 0.08474 0.07864 0.0869 0.07017 0.1812 0.05667 ... 23.41 158.8 1956.0 0.1238 0.1866 0.2416 0.1860 0.2750 0.08902 0
2 19.69 21.25 130.0 1203.0 0.10960 0.15990 0.1974 0.12790 0.2069 0.05999 ... 25.53 152.5 1709.0 0.1444 0.4245 0.4504 0.2430 0.3613 0.08758 0

3 rows × 31 columns

위의 데이터셋에서 악성종양은 0, 양성은 1 값으로 되어 있음

In [3]:
print(dataset.target_names)
print(cancer_df['target'].value_counts())
['malignant' 'benign']
1    357
0    212
Name: target, dtype: int64
In [4]:
# 전체 데이터셋을 학습용 80%, 테스트용 20%로 분할
X_train, X_test, y_train, y_test = train_test_split(X_features, y_label, test_size=0.2, random_state=156)
print(X_train.shape, X_test.shape)
(455, 30) (114, 30)

파이썬래퍼 XGBoost와 사이킷런래퍼 XGBoost의 가장 큰 차이는
파이썬래퍼는 학습용과 테스트 데이터 세트를 위해 별도의 DMatrix를 생성한다는 것입니다.
DMatrix : 넘파이 입력 파라미터를 받아서 만들어지는 XGBoost만의 전용 데이터 세트

  • 주요 입력 파라미터는 data(피처 데이터 세트)와 label
    (분류: 레이블 데이터 세트/회귀: 숫자형인 종속값 데이터 세트)
  • 판다스의 DataFrame으로 데이터 인터페이스를 하기 위해서는 DataFrame.values를 이용해 넘파이로 일차변환 한 뒤에 DMatrix 변환을 적용
In [5]:
# 넘파이 형태의 학습 데이터 세트와 테스트 데이터를 DMatrix로 변환하는 예제
dtrain = xgb.DMatrix(data=X_train, label = y_train)
dtest = xgb.DMatrix(data=X_test, label=y_test)
In [6]:
# max_depth = 3, 학습률은 0.1, 예제가 이진분류이므로 목적함수(objective)는 binary:logistic(이진 로지스틱)
# 오류함수의 평가성능지표는 logloss
# 부스팅 반복횟수는 400
# 조기중단을 위한 최소 반복횟수는 100

params = {'max_depth' : 3,
         'eta' : 0.1, 
         'objective' : 'binary:logistic',
         'eval_metric' : 'logloss',
         'early_stoppings' : 100 }

num_rounds = 400

파이썬래퍼 XGBoost에서 하이퍼 파라미터를 xgboost 모듈의 train( ) 함수에 파라미터로 전달합니다.
(사이킷런 래퍼는 Estimator 생성자를 하이퍼 파라미터로 전달)

early_stopping_rounds 파라미터 : 조기 중단을 위한 라운드를 설정합니다.
조기 중단 기능 수행을 위해서는 반드시 eval_set과 eval_metric이 함께 설정되어야 합니다.

  • eval_set : 성능평가를 위한 평가용 데이터 세트를 설정
  • eval_metric : 평가 세트에 적용할 성능 평가 방법
    (반복마다 eval_set으로 지정된 데이터 세트에서 eval_metric의 지정된 평가 지표로 예측 오류를 측정)

train() 함수를 호출하면 xgboost가 반복 시마다 evals에 표시된 데이터 세트에 대해 평가 지표를 출력합니다.
그 후 학습이 완료된 모델 객체를 반환합니다.

In [7]:
# train 데이터 세트는 'train', evaluation(test) 데이터 세트는 'eval' 로 명기
wlist = [(dtrain, 'train'), (dtest,'eval')]
# 하이퍼 파라미터와 early stopping 파라미터를 train() 함수의 파라미터로 전달
xgb_model = xgb.train(params = params, dtrain=dtrain, num_boost_round=num_rounds, evals=wlist)
[0]	train-logloss:0.609685	eval-logloss:0.61352
[1]	train-logloss:0.540804	eval-logloss:0.547842
[2]	train-logloss:0.483755	eval-logloss:0.494247
[3]	train-logloss:0.434455	eval-logloss:0.447986
[4]	train-logloss:0.390549	eval-logloss:0.409109
[5]	train-logloss:0.354145	eval-logloss:0.374977
[6]	train-logloss:0.321222	eval-logloss:0.345714
[7]	train-logloss:0.292592	eval-logloss:0.320529
[8]	train-logloss:0.267467	eval-logloss:0.29721
[9]	train-logloss:0.245152	eval-logloss:0.277991
[10]	train-logloss:0.225694	eval-logloss:0.260302
[11]	train-logloss:0.207937	eval-logloss:0.246037
[12]	train-logloss:0.192183	eval-logloss:0.231556
[13]	train-logloss:0.177917	eval-logloss:0.22005
[14]	train-logloss:0.165221	eval-logloss:0.208572
[15]	train-logloss:0.153622	eval-logloss:0.199993
[16]	train-logloss:0.14333	eval-logloss:0.190118
[17]	train-logloss:0.133985	eval-logloss:0.181818
[18]	train-logloss:0.125599	eval-logloss:0.174729
[19]	train-logloss:0.117286	eval-logloss:0.167657
[20]	train-logloss:0.109688	eval-logloss:0.158202
[21]	train-logloss:0.102975	eval-logloss:0.154725
[22]	train-logloss:0.097068	eval-logloss:0.148947
[23]	train-logloss:0.091428	eval-logloss:0.143308
[24]	train-logloss:0.086335	eval-logloss:0.136344
[25]	train-logloss:0.081311	eval-logloss:0.132778
[26]	train-logloss:0.076857	eval-logloss:0.127912
[27]	train-logloss:0.072836	eval-logloss:0.125263
[28]	train-logloss:0.069248	eval-logloss:0.119978
[29]	train-logloss:0.065549	eval-logloss:0.116412
[30]	train-logloss:0.062414	eval-logloss:0.114502
[31]	train-logloss:0.059591	eval-logloss:0.112572
[32]	train-logloss:0.057096	eval-logloss:0.11154
[33]	train-logloss:0.054407	eval-logloss:0.108681
[34]	train-logloss:0.052036	eval-logloss:0.106681
[35]	train-logloss:0.049751	eval-logloss:0.104207
[36]	train-logloss:0.04775	eval-logloss:0.102962
[37]	train-logloss:0.045854	eval-logloss:0.100576
[38]	train-logloss:0.044015	eval-logloss:0.098683
[39]	train-logloss:0.042263	eval-logloss:0.096444
[40]	train-logloss:0.040649	eval-logloss:0.095869
[41]	train-logloss:0.039126	eval-logloss:0.094242
[42]	train-logloss:0.037377	eval-logloss:0.094715
[43]	train-logloss:0.036106	eval-logloss:0.094272
[44]	train-logloss:0.034941	eval-logloss:0.093894
[45]	train-logloss:0.033654	eval-logloss:0.094184
[46]	train-logloss:0.032528	eval-logloss:0.09402
[47]	train-logloss:0.031485	eval-logloss:0.09236
[48]	train-logloss:0.030389	eval-logloss:0.093012
[49]	train-logloss:0.029467	eval-logloss:0.091272
[50]	train-logloss:0.028545	eval-logloss:0.090051
[51]	train-logloss:0.027525	eval-logloss:0.089605
[52]	train-logloss:0.026555	eval-logloss:0.089577
[53]	train-logloss:0.025682	eval-logloss:0.090703
[54]	train-logloss:0.025004	eval-logloss:0.089579
[55]	train-logloss:0.024297	eval-logloss:0.090357
[56]	train-logloss:0.023574	eval-logloss:0.091587
[57]	train-logloss:0.022965	eval-logloss:0.091527
[58]	train-logloss:0.022488	eval-logloss:0.091986
[59]	train-logloss:0.021854	eval-logloss:0.091951
[60]	train-logloss:0.021316	eval-logloss:0.091939
[61]	train-logloss:0.020794	eval-logloss:0.091461
[62]	train-logloss:0.020218	eval-logloss:0.090311
[63]	train-logloss:0.019701	eval-logloss:0.089407
[64]	train-logloss:0.01918	eval-logloss:0.089719
[65]	train-logloss:0.018724	eval-logloss:0.089743
[66]	train-logloss:0.018325	eval-logloss:0.089622
[67]	train-logloss:0.017867	eval-logloss:0.088734
[68]	train-logloss:0.017598	eval-logloss:0.088621
[69]	train-logloss:0.017243	eval-logloss:0.089739
[70]	train-logloss:0.01688	eval-logloss:0.089981
[71]	train-logloss:0.016641	eval-logloss:0.089782
[72]	train-logloss:0.016287	eval-logloss:0.089584
[73]	train-logloss:0.015983	eval-logloss:0.089533
[74]	train-logloss:0.015658	eval-logloss:0.088748
[75]	train-logloss:0.015393	eval-logloss:0.088597
[76]	train-logloss:0.015151	eval-logloss:0.08812
[77]	train-logloss:0.01488	eval-logloss:0.088396
[78]	train-logloss:0.014637	eval-logloss:0.088736
[79]	train-logloss:0.014491	eval-logloss:0.088153
[80]	train-logloss:0.014185	eval-logloss:0.087577
[81]	train-logloss:0.014005	eval-logloss:0.087412
[82]	train-logloss:0.013772	eval-logloss:0.08849
[83]	train-logloss:0.013567	eval-logloss:0.088575
[84]	train-logloss:0.013414	eval-logloss:0.08807
[85]	train-logloss:0.013253	eval-logloss:0.087641
[86]	train-logloss:0.013109	eval-logloss:0.087416
[87]	train-logloss:0.012926	eval-logloss:0.087611
[88]	train-logloss:0.012714	eval-logloss:0.087065
[89]	train-logloss:0.012544	eval-logloss:0.08727
[90]	train-logloss:0.012353	eval-logloss:0.087161
[91]	train-logloss:0.012226	eval-logloss:0.086962
[92]	train-logloss:0.012065	eval-logloss:0.087166
[93]	train-logloss:0.011927	eval-logloss:0.087067
[94]	train-logloss:0.011821	eval-logloss:0.086592
[95]	train-logloss:0.011649	eval-logloss:0.086116
[96]	train-logloss:0.011482	eval-logloss:0.087139
[97]	train-logloss:0.01136	eval-logloss:0.086768
[98]	train-logloss:0.011239	eval-logloss:0.086694
[99]	train-logloss:0.011132	eval-logloss:0.086547
[100]	train-logloss:0.011002	eval-logloss:0.086498
[101]	train-logloss:0.010852	eval-logloss:0.08641
[102]	train-logloss:0.010755	eval-logloss:0.086288
[103]	train-logloss:0.010636	eval-logloss:0.086258
[104]	train-logloss:0.0105	eval-logloss:0.086835
[105]	train-logloss:0.010395	eval-logloss:0.086767
[106]	train-logloss:0.010305	eval-logloss:0.087321
[107]	train-logloss:0.010197	eval-logloss:0.087304
[108]	train-logloss:0.010072	eval-logloss:0.08728
[109]	train-logloss:0.01	eval-logloss:0.087298
[110]	train-logloss:0.009914	eval-logloss:0.087289
[111]	train-logloss:0.009798	eval-logloss:0.088002
[112]	train-logloss:0.00971	eval-logloss:0.087936
[113]	train-logloss:0.009628	eval-logloss:0.087843
[114]	train-logloss:0.009558	eval-logloss:0.088066
[115]	train-logloss:0.009483	eval-logloss:0.087649
[116]	train-logloss:0.009416	eval-logloss:0.087298
[117]	train-logloss:0.009306	eval-logloss:0.087799
[118]	train-logloss:0.009228	eval-logloss:0.087751
[119]	train-logloss:0.009154	eval-logloss:0.08768
[120]	train-logloss:0.009118	eval-logloss:0.087626
[121]	train-logloss:0.009016	eval-logloss:0.08757
[122]	train-logloss:0.008972	eval-logloss:0.087547
[123]	train-logloss:0.008904	eval-logloss:0.087156
[124]	train-logloss:0.008837	eval-logloss:0.08767
[125]	train-logloss:0.008803	eval-logloss:0.087737
[126]	train-logloss:0.008709	eval-logloss:0.088275
[127]	train-logloss:0.008645	eval-logloss:0.088309
[128]	train-logloss:0.008613	eval-logloss:0.088266
[129]	train-logloss:0.008555	eval-logloss:0.087886
[130]	train-logloss:0.008463	eval-logloss:0.088861
[131]	train-logloss:0.008416	eval-logloss:0.088675
[132]	train-logloss:0.008385	eval-logloss:0.088743
[133]	train-logloss:0.0083	eval-logloss:0.089218
[134]	train-logloss:0.00827	eval-logloss:0.089179
[135]	train-logloss:0.008218	eval-logloss:0.088821
[136]	train-logloss:0.008157	eval-logloss:0.088512
[137]	train-logloss:0.008076	eval-logloss:0.08848
[138]	train-logloss:0.008047	eval-logloss:0.088386
[139]	train-logloss:0.007973	eval-logloss:0.089145
[140]	train-logloss:0.007946	eval-logloss:0.08911
[141]	train-logloss:0.007898	eval-logloss:0.088765
[142]	train-logloss:0.007872	eval-logloss:0.088678
[143]	train-logloss:0.007847	eval-logloss:0.088389
[144]	train-logloss:0.007776	eval-logloss:0.089271
[145]	train-logloss:0.007752	eval-logloss:0.089238
[146]	train-logloss:0.007728	eval-logloss:0.089139
[147]	train-logloss:0.007689	eval-logloss:0.088907
[148]	train-logloss:0.007621	eval-logloss:0.089416
[149]	train-logloss:0.007598	eval-logloss:0.089388
[150]	train-logloss:0.007575	eval-logloss:0.089108
[151]	train-logloss:0.007521	eval-logloss:0.088735
[152]	train-logloss:0.007498	eval-logloss:0.088717
[153]	train-logloss:0.007464	eval-logloss:0.088484
[154]	train-logloss:0.00741	eval-logloss:0.088471
[155]	train-logloss:0.007389	eval-logloss:0.088545
[156]	train-logloss:0.007367	eval-logloss:0.088521
[157]	train-logloss:0.007345	eval-logloss:0.088547
[158]	train-logloss:0.007323	eval-logloss:0.088275
[159]	train-logloss:0.007303	eval-logloss:0.0883
[160]	train-logloss:0.007282	eval-logloss:0.08828
[161]	train-logloss:0.007261	eval-logloss:0.088013
[162]	train-logloss:0.007241	eval-logloss:0.087758
[163]	train-logloss:0.007221	eval-logloss:0.087784
[164]	train-logloss:0.0072	eval-logloss:0.087777
[165]	train-logloss:0.00718	eval-logloss:0.087517
[166]	train-logloss:0.007161	eval-logloss:0.087542
[167]	train-logloss:0.007142	eval-logloss:0.087642
[168]	train-logloss:0.007122	eval-logloss:0.08739
[169]	train-logloss:0.007103	eval-logloss:0.087377
[170]	train-logloss:0.007084	eval-logloss:0.087298
[171]	train-logloss:0.007065	eval-logloss:0.087368
[172]	train-logloss:0.007047	eval-logloss:0.087395
[173]	train-logloss:0.007028	eval-logloss:0.087385
[174]	train-logloss:0.007009	eval-logloss:0.087132
[175]	train-logloss:0.006991	eval-logloss:0.087159
[176]	train-logloss:0.006973	eval-logloss:0.086955
[177]	train-logloss:0.006955	eval-logloss:0.087053
[178]	train-logloss:0.006937	eval-logloss:0.08697
[179]	train-logloss:0.00692	eval-logloss:0.086973
[180]	train-logloss:0.006901	eval-logloss:0.087038
[181]	train-logloss:0.006884	eval-logloss:0.086799
[182]	train-logloss:0.006866	eval-logloss:0.086826
[183]	train-logloss:0.006849	eval-logloss:0.086582
[184]	train-logloss:0.006831	eval-logloss:0.086588
[185]	train-logloss:0.006815	eval-logloss:0.086614
[186]	train-logloss:0.006798	eval-logloss:0.086372
[187]	train-logloss:0.006781	eval-logloss:0.086369
[188]	train-logloss:0.006764	eval-logloss:0.086297
[189]	train-logloss:0.006747	eval-logloss:0.086104
[190]	train-logloss:0.00673	eval-logloss:0.086023
[191]	train-logloss:0.006714	eval-logloss:0.08605
[192]	train-logloss:0.006698	eval-logloss:0.086149
[193]	train-logloss:0.006682	eval-logloss:0.085916
[194]	train-logloss:0.006666	eval-logloss:0.085915
[195]	train-logloss:0.00665	eval-logloss:0.085984
[196]	train-logloss:0.006634	eval-logloss:0.086012
[197]	train-logloss:0.006618	eval-logloss:0.085922
[198]	train-logloss:0.006603	eval-logloss:0.085853
[199]	train-logloss:0.006587	eval-logloss:0.085874
[200]	train-logloss:0.006572	eval-logloss:0.085888
[201]	train-logloss:0.006556	eval-logloss:0.08595
[202]	train-logloss:0.006542	eval-logloss:0.08573
[203]	train-logloss:0.006527	eval-logloss:0.08573
[204]	train-logloss:0.006512	eval-logloss:0.085753
[205]	train-logloss:0.006497	eval-logloss:0.085821
[206]	train-logloss:0.006483	eval-logloss:0.08584
[207]	train-logloss:0.006469	eval-logloss:0.085776
[208]	train-logloss:0.006455	eval-logloss:0.085686
[209]	train-logloss:0.00644	eval-logloss:0.08571
[210]	train-logloss:0.006427	eval-logloss:0.085806
[211]	train-logloss:0.006413	eval-logloss:0.085593
[212]	train-logloss:0.006399	eval-logloss:0.085801
[213]	train-logloss:0.006385	eval-logloss:0.085806
[214]	train-logloss:0.006372	eval-logloss:0.085744
[215]	train-logloss:0.006359	eval-logloss:0.085658
[216]	train-logloss:0.006345	eval-logloss:0.085843
[217]	train-logloss:0.006332	eval-logloss:0.085632
[218]	train-logloss:0.006319	eval-logloss:0.085726
[219]	train-logloss:0.006306	eval-logloss:0.085783
[220]	train-logloss:0.006293	eval-logloss:0.085791
[221]	train-logloss:0.00628	eval-logloss:0.085817
[222]	train-logloss:0.006268	eval-logloss:0.085757
[223]	train-logloss:0.006255	eval-logloss:0.085674
[224]	train-logloss:0.006242	eval-logloss:0.08586
[225]	train-logloss:0.00623	eval-logloss:0.085871
[226]	train-logloss:0.006218	eval-logloss:0.085927
[227]	train-logloss:0.006206	eval-logloss:0.085954
[228]	train-logloss:0.006194	eval-logloss:0.085874
[229]	train-logloss:0.006182	eval-logloss:0.086057
[230]	train-logloss:0.00617	eval-logloss:0.086002
[231]	train-logloss:0.006158	eval-logloss:0.085922
[232]	train-logloss:0.006147	eval-logloss:0.086102
[233]	train-logloss:0.006135	eval-logloss:0.086115
[234]	train-logloss:0.006124	eval-logloss:0.086169
[235]	train-logloss:0.006112	eval-logloss:0.086263
[236]	train-logloss:0.006101	eval-logloss:0.086292
[237]	train-logloss:0.00609	eval-logloss:0.086217
[238]	train-logloss:0.006079	eval-logloss:0.086395
[239]	train-logloss:0.006068	eval-logloss:0.086342
[240]	train-logloss:0.006057	eval-logloss:0.08618
[241]	train-logloss:0.006046	eval-logloss:0.086195
[242]	train-logloss:0.006036	eval-logloss:0.086248
[243]	train-logloss:0.006025	eval-logloss:0.086263
[244]	train-logloss:0.006014	eval-logloss:0.086293
[245]	train-logloss:0.006004	eval-logloss:0.086222
[246]	train-logloss:0.005993	eval-logloss:0.086398
[247]	train-logloss:0.005983	eval-logloss:0.086347
[248]	train-logloss:0.005972	eval-logloss:0.086276
[249]	train-logloss:0.005962	eval-logloss:0.086448
[250]	train-logloss:0.005952	eval-logloss:0.086294
[251]	train-logloss:0.005942	eval-logloss:0.086312
[252]	train-logloss:0.005932	eval-logloss:0.086364
[253]	train-logloss:0.005922	eval-logloss:0.086394
[254]	train-logloss:0.005912	eval-logloss:0.08649
[255]	train-logloss:0.005903	eval-logloss:0.086441
[256]	train-logloss:0.005893	eval-logloss:0.08629
[257]	train-logloss:0.005883	eval-logloss:0.08646
[258]	train-logloss:0.005874	eval-logloss:0.086391
[259]	train-logloss:0.005864	eval-logloss:0.086441
[260]	train-logloss:0.005855	eval-logloss:0.086461
[261]	train-logloss:0.005845	eval-logloss:0.086491
[262]	train-logloss:0.005836	eval-logloss:0.086445
[263]	train-logloss:0.005827	eval-logloss:0.086466
[264]	train-logloss:0.005818	eval-logloss:0.086319
[265]	train-logloss:0.005809	eval-logloss:0.086488
[266]	train-logloss:0.0058	eval-logloss:0.086538
[267]	train-logloss:0.005791	eval-logloss:0.086471
[268]	train-logloss:0.005782	eval-logloss:0.086501
[269]	train-logloss:0.005773	eval-logloss:0.086522
[270]	train-logloss:0.005764	eval-logloss:0.086689
[271]	train-logloss:0.005755	eval-logloss:0.086738
[272]	train-logloss:0.005747	eval-logloss:0.08683
[273]	train-logloss:0.005738	eval-logloss:0.086684
[274]	train-logloss:0.005729	eval-logloss:0.08664
[275]	train-logloss:0.005721	eval-logloss:0.086496
[276]	train-logloss:0.005712	eval-logloss:0.086355
[277]	train-logloss:0.005704	eval-logloss:0.086519
[278]	train-logloss:0.005696	eval-logloss:0.086567
[279]	train-logloss:0.005687	eval-logloss:0.08659
[280]	train-logloss:0.005679	eval-logloss:0.086679
[281]	train-logloss:0.005671	eval-logloss:0.086637
[282]	train-logloss:0.005663	eval-logloss:0.086499
[283]	train-logloss:0.005655	eval-logloss:0.086356
[284]	train-logloss:0.005646	eval-logloss:0.086405
[285]	train-logloss:0.005639	eval-logloss:0.086429
[286]	train-logloss:0.005631	eval-logloss:0.086456
[287]	train-logloss:0.005623	eval-logloss:0.086504
[288]	train-logloss:0.005615	eval-logloss:0.08637
[289]	train-logloss:0.005608	eval-logloss:0.086457
[290]	train-logloss:0.0056	eval-logloss:0.086453
[291]	train-logloss:0.005593	eval-logloss:0.086322
[292]	train-logloss:0.005585	eval-logloss:0.086284
[293]	train-logloss:0.005577	eval-logloss:0.086148
[294]	train-logloss:0.00557	eval-logloss:0.086196
[295]	train-logloss:0.005563	eval-logloss:0.086221
[296]	train-logloss:0.005556	eval-logloss:0.086308
[297]	train-logloss:0.005548	eval-logloss:0.086178
[298]	train-logloss:0.005541	eval-logloss:0.086263
[299]	train-logloss:0.005534	eval-logloss:0.086131
[300]	train-logloss:0.005527	eval-logloss:0.086179
[301]	train-logloss:0.005519	eval-logloss:0.086052
[302]	train-logloss:0.005512	eval-logloss:0.086016
[303]	train-logloss:0.005505	eval-logloss:0.086101
[304]	train-logloss:0.005498	eval-logloss:0.085977
[305]	train-logloss:0.005491	eval-logloss:0.086059
[306]	train-logloss:0.005484	eval-logloss:0.085971
[307]	train-logloss:0.005478	eval-logloss:0.085998
[308]	train-logloss:0.005471	eval-logloss:0.085999
[309]	train-logloss:0.005464	eval-logloss:0.085877
[310]	train-logloss:0.005457	eval-logloss:0.085923
[311]	train-logloss:0.00545	eval-logloss:0.085948
[312]	train-logloss:0.005444	eval-logloss:0.086028
[313]	train-logloss:0.005437	eval-logloss:0.086112
[314]	train-logloss:0.00543	eval-logloss:0.085989
[315]	train-logloss:0.005424	eval-logloss:0.085903
[316]	train-logloss:0.005417	eval-logloss:0.085949
[317]	train-logloss:0.005411	eval-logloss:0.085977
[318]	train-logloss:0.005404	eval-logloss:0.086002
[319]	train-logloss:0.005398	eval-logloss:0.085883
[320]	train-logloss:0.005392	eval-logloss:0.085967
[321]	train-logloss:0.005385	eval-logloss:0.086046
[322]	train-logloss:0.005379	eval-logloss:0.086091
[323]	train-logloss:0.005373	eval-logloss:0.085977
[324]	train-logloss:0.005366	eval-logloss:0.085978
[325]	train-logloss:0.00536	eval-logloss:0.085896
[326]	train-logloss:0.005354	eval-logloss:0.08578
[327]	train-logloss:0.005348	eval-logloss:0.085857
[328]	train-logloss:0.005342	eval-logloss:0.085939
[329]	train-logloss:0.005336	eval-logloss:0.085825
[330]	train-logloss:0.00533	eval-logloss:0.085869
[331]	train-logloss:0.005324	eval-logloss:0.085893
[332]	train-logloss:0.005318	eval-logloss:0.085922
[333]	train-logloss:0.005312	eval-logloss:0.085842
[334]	train-logloss:0.005306	eval-logloss:0.085735
[335]	train-logloss:0.0053	eval-logloss:0.085816
[336]	train-logloss:0.005294	eval-logloss:0.085892
[337]	train-logloss:0.005288	eval-logloss:0.085936
[338]	train-logloss:0.005283	eval-logloss:0.08583
[339]	train-logloss:0.005277	eval-logloss:0.085909
[340]	train-logloss:0.005271	eval-logloss:0.085831
[341]	train-logloss:0.005265	eval-logloss:0.085727
[342]	train-logloss:0.00526	eval-logloss:0.085678
[343]	train-logloss:0.005254	eval-logloss:0.085721
[344]	train-logloss:0.005249	eval-logloss:0.085796
[345]	train-logloss:0.005243	eval-logloss:0.085819
[346]	train-logloss:0.005237	eval-logloss:0.085715
[347]	train-logloss:0.005232	eval-logloss:0.085793
[348]	train-logloss:0.005227	eval-logloss:0.085835
[349]	train-logloss:0.005221	eval-logloss:0.085734
[350]	train-logloss:0.005216	eval-logloss:0.085658
[351]	train-logloss:0.00521	eval-logloss:0.08573
[352]	train-logloss:0.005205	eval-logloss:0.085807
[353]	train-logloss:0.0052	eval-logloss:0.085706
[354]	train-logloss:0.005195	eval-logloss:0.085659
[355]	train-logloss:0.005189	eval-logloss:0.085701
[356]	train-logloss:0.005184	eval-logloss:0.085628
[357]	train-logloss:0.005179	eval-logloss:0.085529
[358]	train-logloss:0.005174	eval-logloss:0.085604
[359]	train-logloss:0.005169	eval-logloss:0.085676
[360]	train-logloss:0.005164	eval-logloss:0.085579
[361]	train-logloss:0.005159	eval-logloss:0.085601
[362]	train-logloss:0.005153	eval-logloss:0.085643
[363]	train-logloss:0.005149	eval-logloss:0.085713
[364]	train-logloss:0.005144	eval-logloss:0.085787
[365]	train-logloss:0.005139	eval-logloss:0.085689
[366]	train-logloss:0.005134	eval-logloss:0.08573
[367]	train-logloss:0.005129	eval-logloss:0.085684
[368]	train-logloss:0.005124	eval-logloss:0.085589
[369]	train-logloss:0.005119	eval-logloss:0.085516
[370]	train-logloss:0.005114	eval-logloss:0.085588
[371]	train-logloss:0.00511	eval-logloss:0.085495
[372]	train-logloss:0.005105	eval-logloss:0.085564
[373]	train-logloss:0.0051	eval-logloss:0.085605
[374]	train-logloss:0.005096	eval-logloss:0.085626
[375]	train-logloss:0.005091	eval-logloss:0.085535
[376]	train-logloss:0.005086	eval-logloss:0.085606
[377]	train-logloss:0.005082	eval-logloss:0.085674
[378]	train-logloss:0.005077	eval-logloss:0.085714
[379]	train-logloss:0.005073	eval-logloss:0.085624
[380]	train-logloss:0.005068	eval-logloss:0.085579
[381]	train-logloss:0.005064	eval-logloss:0.085618
[382]	train-logloss:0.00506	eval-logloss:0.085639
[383]	train-logloss:0.005055	eval-logloss:0.08555
[384]	train-logloss:0.005051	eval-logloss:0.085617
[385]	train-logloss:0.005047	eval-logloss:0.085621
[386]	train-logloss:0.005042	eval-logloss:0.085551
[387]	train-logloss:0.005038	eval-logloss:0.085463
[388]	train-logloss:0.005034	eval-logloss:0.085502
[389]	train-logloss:0.005029	eval-logloss:0.085459
[390]	train-logloss:0.005025	eval-logloss:0.085321
[391]	train-logloss:0.005021	eval-logloss:0.085389
[392]	train-logloss:0.005017	eval-logloss:0.085303
[393]	train-logloss:0.005013	eval-logloss:0.085369
[394]	train-logloss:0.005009	eval-logloss:0.085301
[395]	train-logloss:0.005005	eval-logloss:0.085368
[396]	train-logloss:0.005	eval-logloss:0.085283
[397]	train-logloss:0.004996	eval-logloss:0.08532
[398]	train-logloss:0.004992	eval-logloss:0.085279
[399]	train-logloss:0.004988	eval-logloss:0.085196

train( ) 을 통해 학습을 수행하면 반복시 train-logloss와 eval-logloss가 지속적으로 감소합니다.
xgboost를 이용해 학습이 완료됐으면 predict() 메서드를 이용해 예측을 수행합니다.
여기서 파이썬 래퍼는 예측 결과를 추정할 수 있는 호가률 값을 반환합니다.(반면 사이킷런 래퍼는 클래스 값을 반환)

In [8]:
pred_probs = xgb_model.predict(dtest)
print('predict() 수행 결과값을 10개만 표시, 예측 확률 값으로 표시됨')
print(np.round(pred_probs[:10], 3))

# 예측 확률이 0.5보다 크면 1, 그렇지 않으면 0으로 예측값 결정해 리스트 객체인 preds에 저장
preds = [ 1 if x > 0.5 else 0 for x in pred_probs]
print('예측값 10개만 표시: ', preds[:10])
predict() 수행 결과값을 10개만 표시, 예측 확률 값으로 표시됨
[0.95  0.003 0.9   0.086 0.993 1.    1.    0.999 0.998 0.   ]
예측값 10개만 표시:  [1, 0, 1, 0, 1, 1, 1, 1, 1, 0]
In [9]:
# 혼동행렬, 정확도, 정밀도, 재현율, F1, AUC 불러오기
def get_clf_eval(y_test, y_pred):
    confusion = confusion_matrix(y_test, y_pred)
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    F1 = f1_score(y_test, y_pred)
    AUC = roc_auc_score(y_test, y_pred)
    print('오차행렬:\n', confusion)
    print('\n정확도: {:.4f}'.format(accuracy))
    print('정밀도: {:.4f}'.format(precision))
    print('재현율: {:.4f}'.format(recall))
    print('F1: {:.4f}'.format(F1))
    print('AUC: {:.4f}'.format(AUC))
In [10]:
get_clf_eval(y_test, preds)
오차행렬:
 [[35  2]
 [ 1 76]]

정확도: 0.9737
정밀도: 0.9744
재현율: 0.9870
F1: 0.9806
AUC: 0.9665

Feature importance를 시각화할 때,
→ 기본 평가 지료로 f1스코어를 기반으로 각 feature의 중요도를 나타냅니다.
→ 사이킷런 래퍼는 estimator 객체의 featureimportances 속성을 이용해 시각화 코드를 직접 작성해야 합니다.
→ 반면, 파이썬 래퍼는 plot_importance()를 이용해 바로 피처 중요 코드를 시각화 할 수 있습니다.

In [11]:
from xgboost import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots(figsize=(10, 12))
plot_importance(xgb_model, ax=ax)
Out[11]:
<matplotlib.axes._subplots.AxesSubplot at 0x1a20bbf7f0>

다만, xgboost 넘파이 기반의 피처 데이터로 학습 시에 피처명을 제대로 알 수 없으므로
피처별로 f자 뒤에 순서를 붙여 X축에 피처들로 나열합니다.(f0는 첫번째 피처, f1는 두번째 피처를 의미)

파이썬 래퍼의 교차 검증 수행 및 최적 파라미터 구하기

: xgboost는 사이킷런의 GridSearchCV와 유사하게 cv( )를 API로 제공합니다.

xgb.cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False,
folds=None, metrics=(),obj=None, feval=None, maximize=False,
early_stopping_rounds=None, fpreproc=None, as_pandas=True,
verbose_eval=None, show_stdv=True, seed=0, callbacks=None, shuffle=True)

  • params(dict): 부스터 파라미터
  • dtrain(DMatrix) : 학습 데이터
  • num_boost_round(int) : 부스팅 반복횟수
  • nfold(int) : CV폴드 개수
  • stratified(bool) : CV수행시 샘플을 균등하게 추출할지 여부
  • metrics(string or list of strings) : CV 수행시 모니터링할 성능 평가 지표
  • early_stopping_rounds(int) : 조기중단을 활성화시킴. 반복횟수 지정

xgv.cv의 반환 값은 데이터프레임 형태입니다.

4. 사이킷런 래퍼 XGBoost의 개요 및 적용

특징

  • 사이킷런의 기본 Estimator를 이용해 만들어 fit()과 predict()만으로 학습과 예측이 가능
  • GridSearchCV,Pipeline 등 사이킷런의 유틸리티를 그대로 사용 가능
  • 분류 : XGBClassifier / 회귀 : XGBRegressor

파이썬 래퍼와 비교시 달라진 파라미터

  • eta → learning_rate
  • sub_sample → subsample
  • lambda → reg_lambda
  • alpha → reg_alpha
  • num_boost_round → n_estimators

위와 동일하게 위스콘신 유방암 데이터를 통한 예측

In [12]:
# max_depth = 3, 학습률은 0.1, 예제가 이진분류이므로 목적함수(objective)는 binary:logistic(이진 로지스틱)
# 부스팅 반복횟수는 400

from xgboost import XGBClassifier

xgb_wrapper = XGBClassifier(n_estimators = 400, learning_rate = 0.1, max_depth = 3)
xgb_wrapper.fit(X_train, y_train)
w_preds = xgb_wrapper.predict(X_test)
In [13]:
# 예측 결과 확인
get_clf_eval(y_test, w_preds)
오차행렬:
 [[35  2]
 [ 1 76]]

정확도: 0.9737
정밀도: 0.9744
재현율: 0.9870
F1: 0.9806
AUC: 0.9665

앞선 파이썬 래퍼 XGBoost와 동일한 결과가 나옵니다.

사이킷런 래퍼 XGBoost에서도 조기 중단 기능을 수행할 수 있는데 fit( )에 해당 파라미터를 입력하면 됩니다.
→ early_stopping_rounds, eval_metrics, eval_set

In [14]:
# max_depth = 3, 학습률은 0.1, 예제가 이진분류이므로 목적함수(objective)는 binary:logistic(이진 로지스틱)
# 오류함수의 평가성능지표는 logloss
# 부스팅 반복횟수는 400
# 조기중단을 위한 최소 반복횟수는 100

# 아래 예제에서는 평가를 위한 데이터 세트로 테스트 데이터 세트를 사용했지만, 바람직하진 않습니다.
# 테스트 데이터 세트는 학습에 완전히 알려지지 않은 데이터 세트를 사용해야 합니다.
# 평가에 테스트 데이터 세트를 사용하면 학습시에 미리 참고가 되어 과적합할 수 있기 때문입니다.

xgb_wrapper = XGBClassifier(n_estimators = 400, learning_rate = 0.1 , max_depth = 3)
evals = [(X_test, y_test)]
xgb_wrapper.fit(X_train, y_train, early_stopping_rounds = 100, 
                eval_metric="logloss", eval_set = evals, verbose=True)
ws100_preds = xgb_wrapper.predict(X_test)
[0]	validation_0-logloss:0.61352
Will train until validation_0-logloss hasn't improved in 100 rounds.
[1]	validation_0-logloss:0.547842
[2]	validation_0-logloss:0.494247
[3]	validation_0-logloss:0.447986
[4]	validation_0-logloss:0.409109
[5]	validation_0-logloss:0.374977
[6]	validation_0-logloss:0.345714
[7]	validation_0-logloss:0.320529
[8]	validation_0-logloss:0.29721
[9]	validation_0-logloss:0.277991
[10]	validation_0-logloss:0.260302
[11]	validation_0-logloss:0.246037
[12]	validation_0-logloss:0.231556
[13]	validation_0-logloss:0.22005
[14]	validation_0-logloss:0.208572
[15]	validation_0-logloss:0.199993
[16]	validation_0-logloss:0.190118
[17]	validation_0-logloss:0.181818
[18]	validation_0-logloss:0.174729
[19]	validation_0-logloss:0.167657
[20]	validation_0-logloss:0.158202
[21]	validation_0-logloss:0.154725
[22]	validation_0-logloss:0.148947
[23]	validation_0-logloss:0.143308
[24]	validation_0-logloss:0.136344
[25]	validation_0-logloss:0.132778
[26]	validation_0-logloss:0.127912
[27]	validation_0-logloss:0.125263
[28]	validation_0-logloss:0.119978
[29]	validation_0-logloss:0.116412
[30]	validation_0-logloss:0.114502
[31]	validation_0-logloss:0.112572
[32]	validation_0-logloss:0.11154
[33]	validation_0-logloss:0.108681
[34]	validation_0-logloss:0.106681
[35]	validation_0-logloss:0.104207
[36]	validation_0-logloss:0.102962
[37]	validation_0-logloss:0.100576
[38]	validation_0-logloss:0.098683
[39]	validation_0-logloss:0.096444
[40]	validation_0-logloss:0.095869
[41]	validation_0-logloss:0.094242
[42]	validation_0-logloss:0.094715
[43]	validation_0-logloss:0.094272
[44]	validation_0-logloss:0.093894
[45]	validation_0-logloss:0.094184
[46]	validation_0-logloss:0.09402
[47]	validation_0-logloss:0.09236
[48]	validation_0-logloss:0.093012
[49]	validation_0-logloss:0.091272
[50]	validation_0-logloss:0.090051
[51]	validation_0-logloss:0.089605
[52]	validation_0-logloss:0.089577
[53]	validation_0-logloss:0.090703
[54]	validation_0-logloss:0.089579
[55]	validation_0-logloss:0.090357
[56]	validation_0-logloss:0.091587
[57]	validation_0-logloss:0.091527
[58]	validation_0-logloss:0.091986
[59]	validation_0-logloss:0.091951
[60]	validation_0-logloss:0.091939
[61]	validation_0-logloss:0.091461
[62]	validation_0-logloss:0.090311
[63]	validation_0-logloss:0.089407
[64]	validation_0-logloss:0.089719
[65]	validation_0-logloss:0.089743
[66]	validation_0-logloss:0.089622
[67]	validation_0-logloss:0.088734
[68]	validation_0-logloss:0.088621
[69]	validation_0-logloss:0.089739
[70]	validation_0-logloss:0.089981
[71]	validation_0-logloss:0.089782
[72]	validation_0-logloss:0.089584
[73]	validation_0-logloss:0.089533
[74]	validation_0-logloss:0.088748
[75]	validation_0-logloss:0.088597
[76]	validation_0-logloss:0.08812
[77]	validation_0-logloss:0.088396
[78]	validation_0-logloss:0.088736
[79]	validation_0-logloss:0.088153
[80]	validation_0-logloss:0.087577
[81]	validation_0-logloss:0.087412
[82]	validation_0-logloss:0.08849
[83]	validation_0-logloss:0.088575
[84]	validation_0-logloss:0.08807
[85]	validation_0-logloss:0.087641
[86]	validation_0-logloss:0.087416
[87]	validation_0-logloss:0.087611
[88]	validation_0-logloss:0.087065
[89]	validation_0-logloss:0.08727
[90]	validation_0-logloss:0.087161
[91]	validation_0-logloss:0.086962
[92]	validation_0-logloss:0.087166
[93]	validation_0-logloss:0.087067
[94]	validation_0-logloss:0.086592
[95]	validation_0-logloss:0.086116
[96]	validation_0-logloss:0.087139
[97]	validation_0-logloss:0.086768
[98]	validation_0-logloss:0.086694
[99]	validation_0-logloss:0.086547
[100]	validation_0-logloss:0.086498
[101]	validation_0-logloss:0.08641
[102]	validation_0-logloss:0.086288
[103]	validation_0-logloss:0.086258
[104]	validation_0-logloss:0.086835
[105]	validation_0-logloss:0.086767
[106]	validation_0-logloss:0.087321
[107]	validation_0-logloss:0.087304
[108]	validation_0-logloss:0.08728
[109]	validation_0-logloss:0.087298
[110]	validation_0-logloss:0.087289
[111]	validation_0-logloss:0.088002
[112]	validation_0-logloss:0.087936
[113]	validation_0-logloss:0.087843
[114]	validation_0-logloss:0.088066
[115]	validation_0-logloss:0.087649
[116]	validation_0-logloss:0.087298
[117]	validation_0-logloss:0.087799
[118]	validation_0-logloss:0.087751
[119]	validation_0-logloss:0.08768
[120]	validation_0-logloss:0.087626
[121]	validation_0-logloss:0.08757
[122]	validation_0-logloss:0.087547
[123]	validation_0-logloss:0.087156
[124]	validation_0-logloss:0.08767
[125]	validation_0-logloss:0.087737
[126]	validation_0-logloss:0.088275
[127]	validation_0-logloss:0.088309
[128]	validation_0-logloss:0.088266
[129]	validation_0-logloss:0.087886
[130]	validation_0-logloss:0.088861
[131]	validation_0-logloss:0.088675
[132]	validation_0-logloss:0.088743
[133]	validation_0-logloss:0.089218
[134]	validation_0-logloss:0.089179
[135]	validation_0-logloss:0.088821
[136]	validation_0-logloss:0.088512
[137]	validation_0-logloss:0.08848
[138]	validation_0-logloss:0.088386
[139]	validation_0-logloss:0.089145
[140]	validation_0-logloss:0.08911
[141]	validation_0-logloss:0.088765
[142]	validation_0-logloss:0.088678
[143]	validation_0-logloss:0.088389
[144]	validation_0-logloss:0.089271
[145]	validation_0-logloss:0.089238
[146]	validation_0-logloss:0.089139
[147]	validation_0-logloss:0.088907
[148]	validation_0-logloss:0.089416
[149]	validation_0-logloss:0.089388
[150]	validation_0-logloss:0.089108
[151]	validation_0-logloss:0.088735
[152]	validation_0-logloss:0.088717
[153]	validation_0-logloss:0.088484
[154]	validation_0-logloss:0.088471
[155]	validation_0-logloss:0.088545
[156]	validation_0-logloss:0.088521
[157]	validation_0-logloss:0.088547
[158]	validation_0-logloss:0.088275
[159]	validation_0-logloss:0.0883
[160]	validation_0-logloss:0.08828
[161]	validation_0-logloss:0.088013
[162]	validation_0-logloss:0.087758
[163]	validation_0-logloss:0.087784
[164]	validation_0-logloss:0.087777
[165]	validation_0-logloss:0.087517
[166]	validation_0-logloss:0.087542
[167]	validation_0-logloss:0.087642
[168]	validation_0-logloss:0.08739
[169]	validation_0-logloss:0.087377
[170]	validation_0-logloss:0.087298
[171]	validation_0-logloss:0.087368
[172]	validation_0-logloss:0.087395
[173]	validation_0-logloss:0.087385
[174]	validation_0-logloss:0.087132
[175]	validation_0-logloss:0.087159
[176]	validation_0-logloss:0.086955
[177]	validation_0-logloss:0.087053
[178]	validation_0-logloss:0.08697
[179]	validation_0-logloss:0.086973
[180]	validation_0-logloss:0.087038
[181]	validation_0-logloss:0.086799
[182]	validation_0-logloss:0.086826
[183]	validation_0-logloss:0.086582
[184]	validation_0-logloss:0.086588
[185]	validation_0-logloss:0.086614
[186]	validation_0-logloss:0.086372
[187]	validation_0-logloss:0.086369
[188]	validation_0-logloss:0.086297
[189]	validation_0-logloss:0.086104
[190]	validation_0-logloss:0.086023
[191]	validation_0-logloss:0.08605
[192]	validation_0-logloss:0.086149
[193]	validation_0-logloss:0.085916
[194]	validation_0-logloss:0.085915
[195]	validation_0-logloss:0.085984
[196]	validation_0-logloss:0.086012
[197]	validation_0-logloss:0.085922
[198]	validation_0-logloss:0.085853
[199]	validation_0-logloss:0.085874
[200]	validation_0-logloss:0.085888
[201]	validation_0-logloss:0.08595
[202]	validation_0-logloss:0.08573
[203]	validation_0-logloss:0.08573
[204]	validation_0-logloss:0.085753
[205]	validation_0-logloss:0.085821
[206]	validation_0-logloss:0.08584
[207]	validation_0-logloss:0.085776
[208]	validation_0-logloss:0.085686
[209]	validation_0-logloss:0.08571
[210]	validation_0-logloss:0.085806
[211]	validation_0-logloss:0.085593
[212]	validation_0-logloss:0.085801
[213]	validation_0-logloss:0.085806
[214]	validation_0-logloss:0.085744
[215]	validation_0-logloss:0.085658
[216]	validation_0-logloss:0.085843
[217]	validation_0-logloss:0.085632
[218]	validation_0-logloss:0.085726
[219]	validation_0-logloss:0.085783
[220]	validation_0-logloss:0.085791
[221]	validation_0-logloss:0.085817
[222]	validation_0-logloss:0.085757
[223]	validation_0-logloss:0.085674
[224]	validation_0-logloss:0.08586
[225]	validation_0-logloss:0.085871
[226]	validation_0-logloss:0.085927
[227]	validation_0-logloss:0.085954
[228]	validation_0-logloss:0.085874
[229]	validation_0-logloss:0.086057
[230]	validation_0-logloss:0.086002
[231]	validation_0-logloss:0.085922
[232]	validation_0-logloss:0.086102
[233]	validation_0-logloss:0.086115
[234]	validation_0-logloss:0.086169
[235]	validation_0-logloss:0.086263
[236]	validation_0-logloss:0.086292
[237]	validation_0-logloss:0.086217
[238]	validation_0-logloss:0.086395
[239]	validation_0-logloss:0.086342
[240]	validation_0-logloss:0.08618
[241]	validation_0-logloss:0.086195
[242]	validation_0-logloss:0.086248
[243]	validation_0-logloss:0.086263
[244]	validation_0-logloss:0.086293
[245]	validation_0-logloss:0.086222
[246]	validation_0-logloss:0.086398
[247]	validation_0-logloss:0.086347
[248]	validation_0-logloss:0.086276
[249]	validation_0-logloss:0.086448
[250]	validation_0-logloss:0.086294
[251]	validation_0-logloss:0.086312
[252]	validation_0-logloss:0.086364
[253]	validation_0-logloss:0.086394
[254]	validation_0-logloss:0.08649
[255]	validation_0-logloss:0.086441
[256]	validation_0-logloss:0.08629
[257]	validation_0-logloss:0.08646
[258]	validation_0-logloss:0.086391
[259]	validation_0-logloss:0.086441
[260]	validation_0-logloss:0.086461
[261]	validation_0-logloss:0.086491
[262]	validation_0-logloss:0.086445
[263]	validation_0-logloss:0.086466
[264]	validation_0-logloss:0.086319
[265]	validation_0-logloss:0.086488
[266]	validation_0-logloss:0.086538
[267]	validation_0-logloss:0.086471
[268]	validation_0-logloss:0.086501
[269]	validation_0-logloss:0.086522
[270]	validation_0-logloss:0.086689
[271]	validation_0-logloss:0.086738
[272]	validation_0-logloss:0.08683
[273]	validation_0-logloss:0.086684
[274]	validation_0-logloss:0.08664
[275]	validation_0-logloss:0.086496
[276]	validation_0-logloss:0.086355
[277]	validation_0-logloss:0.086519
[278]	validation_0-logloss:0.086567
[279]	validation_0-logloss:0.08659
[280]	validation_0-logloss:0.086679
[281]	validation_0-logloss:0.086637
[282]	validation_0-logloss:0.086499
[283]	validation_0-logloss:0.086356
[284]	validation_0-logloss:0.086405
[285]	validation_0-logloss:0.086429
[286]	validation_0-logloss:0.086456
[287]	validation_0-logloss:0.086504
[288]	validation_0-logloss:0.08637
[289]	validation_0-logloss:0.086457
[290]	validation_0-logloss:0.086453
[291]	validation_0-logloss:0.086322
[292]	validation_0-logloss:0.086284
[293]	validation_0-logloss:0.086148
[294]	validation_0-logloss:0.086196
[295]	validation_0-logloss:0.086221
[296]	validation_0-logloss:0.086308
[297]	validation_0-logloss:0.086178
[298]	validation_0-logloss:0.086263
[299]	validation_0-logloss:0.086131
[300]	validation_0-logloss:0.086179
[301]	validation_0-logloss:0.086052
[302]	validation_0-logloss:0.086016
[303]	validation_0-logloss:0.086101
[304]	validation_0-logloss:0.085977
[305]	validation_0-logloss:0.086059
[306]	validation_0-logloss:0.085971
[307]	validation_0-logloss:0.085998
[308]	validation_0-logloss:0.085999
[309]	validation_0-logloss:0.085877
[310]	validation_0-logloss:0.085923
[311]	validation_0-logloss:0.085948
Stopping. Best iteration:
[211]	validation_0-logloss:0.085593

위의 결과에서는 211번 반복시 logloss가 0.085593 이었는데 이후 100번 반복되는 311번째까지
성능평가 지수가 향상되지 않았기 때문에 더 이상 반복하지 않고 멈추게 되었습니다.

In [15]:
get_clf_eval(y_test, ws100_preds)
오차행렬:
 [[34  3]
 [ 1 76]]

정확도: 0.9649
정밀도: 0.9620
재현율: 0.9870
F1: 0.9744
AUC: 0.9530


조기 중단값을 너무 급격하게 줄이면 성능이 향상될 여지가 있음에도 학습을 멈춰 예측 성능이 저하될 수 있습니다.

In [16]:
# early_stopping_rounds = 10 으로 설정하고 재학습
xgb_wrapper.fit(X_train, y_train, early_stopping_rounds = 10, 
                eval_metric='logloss', eval_set=evals , verbose=True)

ws10_preds = xgb_wrapper.predict(X_test)
get_clf_eval(y_test, ws10_preds)
[0]	validation_0-logloss:0.61352
Will train until validation_0-logloss hasn't improved in 10 rounds.
[1]	validation_0-logloss:0.547842
[2]	validation_0-logloss:0.494247
[3]	validation_0-logloss:0.447986
[4]	validation_0-logloss:0.409109
[5]	validation_0-logloss:0.374977
[6]	validation_0-logloss:0.345714
[7]	validation_0-logloss:0.320529
[8]	validation_0-logloss:0.29721
[9]	validation_0-logloss:0.277991
[10]	validation_0-logloss:0.260302
[11]	validation_0-logloss:0.246037
[12]	validation_0-logloss:0.231556
[13]	validation_0-logloss:0.22005
[14]	validation_0-logloss:0.208572
[15]	validation_0-logloss:0.199993
[16]	validation_0-logloss:0.190118
[17]	validation_0-logloss:0.181818
[18]	validation_0-logloss:0.174729
[19]	validation_0-logloss:0.167657
[20]	validation_0-logloss:0.158202
[21]	validation_0-logloss:0.154725
[22]	validation_0-logloss:0.148947
[23]	validation_0-logloss:0.143308
[24]	validation_0-logloss:0.136344
[25]	validation_0-logloss:0.132778
[26]	validation_0-logloss:0.127912
[27]	validation_0-logloss:0.125263
[28]	validation_0-logloss:0.119978
[29]	validation_0-logloss:0.116412
[30]	validation_0-logloss:0.114502
[31]	validation_0-logloss:0.112572
[32]	validation_0-logloss:0.11154
[33]	validation_0-logloss:0.108681
[34]	validation_0-logloss:0.106681
[35]	validation_0-logloss:0.104207
[36]	validation_0-logloss:0.102962
[37]	validation_0-logloss:0.100576
[38]	validation_0-logloss:0.098683
[39]	validation_0-logloss:0.096444
[40]	validation_0-logloss:0.095869
[41]	validation_0-logloss:0.094242
[42]	validation_0-logloss:0.094715
[43]	validation_0-logloss:0.094272
[44]	validation_0-logloss:0.093894
[45]	validation_0-logloss:0.094184
[46]	validation_0-logloss:0.09402
[47]	validation_0-logloss:0.09236
[48]	validation_0-logloss:0.093012
[49]	validation_0-logloss:0.091272
[50]	validation_0-logloss:0.090051
[51]	validation_0-logloss:0.089605
[52]	validation_0-logloss:0.089577
[53]	validation_0-logloss:0.090703
[54]	validation_0-logloss:0.089579
[55]	validation_0-logloss:0.090357
[56]	validation_0-logloss:0.091587
[57]	validation_0-logloss:0.091527
[58]	validation_0-logloss:0.091986
[59]	validation_0-logloss:0.091951
[60]	validation_0-logloss:0.091939
[61]	validation_0-logloss:0.091461
[62]	validation_0-logloss:0.090311
Stopping. Best iteration:
[52]	validation_0-logloss:0.089577

오차행렬:
 [[34  3]
 [ 2 75]]

정확도: 0.9561
정밀도: 0.9615
재현율: 0.9740
F1: 0.9677
AUC: 0.9465

62번째까지만 수행이 되고 종료되었는데, 이렇게 학습된 모델로 예측한 결과, 정확도는 약 0.9561로
ealry_stopping_rounds = 100일 때의 정확도인 0.9649보다 낮게 나왔습니다.

모델 예측 후 피처 중요도를 동일하게 plot_importance() API를 통해 시각화할 수 있습니다.

In [17]:
from xgboost import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots(figsize=(10, 12))

plot_importance(xgb_wrapper, ax=ax)
Out[17]:
<matplotlib.axes._subplots.AxesSubplot at 0x1a20d54f60>
  1. 지나가던참견꾼 2021.01.09 16:40

    시간이 제법 지난 글이지만... 많은 도움이 되었습니다.
    정성스레 글 써주셔서 감사합니다!!
    다름이 아니라, XGBoost 는 eXtra Gradient Boosting 보다는 eXtreme Gradient Boosting 의 약자로 쓰이는 게 보통인 것 같습니다.
    (Google 기준: eXtra Gradient Boosting 418건, eXtreme Gradient Boosting 16200건)
    글쓴이님의 글 정독하다가 문득 full name 이 알고 싶어, 검색하다 우연히 찾게 되었네요..ㅎㅎ

  2. dddd 2021.04.21 16:24

    많은 정보 얻어갑니다. 좋은 글 감사드립니다.

4-4. Boosting Algorithm(AdaBoost & GBM)
In [1]:
from IPython.core.display import display, HTML
display(HTML("<style> .container{width:90% !important;}</style>"))

1. Boosting Algorithm

부스팅 알고리즘은 여러 개의 약한 학습기(weak learner)를 순차적으로 학습-예측하면서 잘못 예측한 데이터에 가중치를 부여해 오류를 개선해나가며 학습하는 방식

부스팅 알고리즘은 대표적으로 아래와 같은 알고리즘들이 있음

  • AdaBoost
  • Gradient Booting Machine(GBM)
  • XGBoost
  • LightGBM
  • CatBoost

2. AdaBoost

Adaptive Boost의 줄임말로서 약한 학습기(weak learner)의 오류 데이터에 가중치를 부여하면서 부스팅을 수행하는 대표적인 알고리즘
속도나 성능적인 측면에서 decision tree를 약한 학습기로 사용함

AdaBoost의 학습

Step 1) 첫 번째 약한 학습기가 첫번째 분류기준(D1)으로 + 와 - 를 분류

Step 2) 잘못 분류된 데이터에 대해 가중치를 부여(두 번쨰 그림에서 커진 + 표시)

Step 3) 두 번째 약한 학습기가 두번째 분류기준(D2)으로 +와 - 를 다시 분류

Step 4) 잘못 분류된 데이터에 대해 가중치를 부여(세 번째 그림에서 커진 - 표시)

Step 5) 세 번째 약한 학습기가 세번째 분류기준으로(D3) +와 -를 다시 분류해서 오류 데이터를 찾음

Step 6) 마지막으로 분류기들을 결합하여 최종 예측 수행

→ 약한 학습기를 순차적으로 학습시켜, 개별 학습기에 가중치를 부여하여 모두 결합함으로써 개별 약한 학습기보다 높은 정확도의 예측 결과를 만듦

AdaBoost의 실행코드

In [2]:
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings('ignore')

# 데이터셋을 구하는 함수 설정
def get_human_dataset():
    
    # 각 데이터 파일들은 공백으로 분리되어 있으므로 read_csv에서 공백문자를 sep으로 할당
    feature_name_df = pd.read_csv('human_activity/features.txt', sep='\s+',
                                                     header=None, names=['column_index', 'column_name'])
    # 데이터프레임에 피처명을 컬럼으로 뷰여하기 위해 리스트 객체로 다시 반환
    feature_name = feature_name_df.iloc[:, 1].values.tolist()
    
    # 학습 피처 데이터세트와 테스트 피처 데이터를 데이터프레임으로 로딩
    # 컬럼명은 feature_name 적용
    X_train = pd.read_csv('human_activity/train/X_train.txt', sep='\s+', names=feature_name)
    X_test = pd.read_csv('human_activity/test/X_test.txt', sep='\s+', names=feature_name)
    
    # 학습 레이블과 테스트 레이블 데이터를 데이터 프레임으로 로딩, 컬럼명은 action으로 부여
    y_train = pd.read_csv('human_activity/train/y_train.txt', sep='\s+', names=['action'])
    y_test = pd.read_csv('human_activity/test/y_test.txt', sep='\s+', names=['action'])
    
    # 로드된 학습/테스트용 데이터프레임을 모두 반환
    return X_train, X_test, y_train, y_test

# 학습/테스트용 데이터 프레임 반환
X_train, X_test, y_train, y_test = get_human_dataset()
In [3]:
from sklearn.ensemble import AdaBoostClassifier
from sklearn.metrics import accuracy_score

clf = AdaBoostClassifier(n_estimators=30, 
                        random_state=10, 
                        learning_rate=0.1)
clf.fit(X_train, y_train)
pred = clf.predict(X_test)
print('AdaBoost 정확도: {:.4f}'.format(accuracy_score(y_test, pred)))
AdaBoost 정확도: 0.7720

AdaBoost의 하이퍼파라미터

파라미터 명 설명
base_estimators - 학습에 사용하는 알고리즘
- Default = None
→ DecisionTreeClassifier(max_depth=1)가 적용
n_estimators - 생성할 약한 학습기의 갯수를 지정
- Default = 50
learning_rate - 학습을 진행할 때마다 적용하는 학습률(0~1)
- Weak learner가 순차적으로 오류 값을 보정해나갈 때 적용하는 계수
- Default = 1.0
  • n_estimators를 늘린다면
    • 생성하는 weak learner의 수는 늘어남
    • 이 여러 학습기들의 decision boundary가 많아지면서 모델이 복잡해짐
  • learning_rate을 줄인다면
    • 가중치 갱신의 변동폭이 감소해서, 여러 학습기들의 decision boundary 차이가 줄어듦

위의 두 가지는 trade-off 관계입니다.
n_estimators(또는 learning_rate)를 늘리고, learning_rate(또는 n_estimators)을 줄인다면 서로 효과가 상쇄됩다.
→ 때문에 이 두 파라미터를 잘 조정하는 것이 알고리즘의 핵심입니다.

3. Gradient Boost Machine(GBM)

GBM의 학습 방식

AdaBoost와 유사하지만, 가중치 업데이트를 경사하강법(Gradient Descent)를 이용하여 최적화된 결과를 얻는 알고리즘입니다.
GBM은 예측 성능이 높지만 Greedy Algorithm으로 과적합이 빠르게되고, 시간이 오래 걸린다는 단점이 있습니다.

※ 경사하강법
분류의 실제값을 y, 피처에 기반한 예측함수를 F(x), 오류식을 h(x) = y-F(x)라고 하면 이 오류식을 최소화하는 방향성을 가지고 가중치 값을 업데이트

※ Greedy Algorithm(탐욕 알고리즘)
미래를 생각하지 않고 각 단계에서 가장 최선의 선택을 하는 기법으로
각 단계에서 최선의 선택을 한 것이 전체적으로도 최선이길 바라는 알고리즘입니다.
물론 모든 경우에서 그리디 알고리즘이 통하지는 않습니다.
가령 지금 선택하면 1개의 마시멜로를 받고, 1분 기다렸다 선택하면 2개의 마시멜로를 받는 문제에서는,
그리디 알고리즘을 사용하면 항상 마시멜로를 1개밖에 받지 못합니다.
지금 당장 최선의 선택은 마시멜로 1개를 받는 거지만, 결과적으로는 1분 기다렸다가 2개 받는 게 최선이기 때문입니다.
( 설명출처: https://www.zerocho.com/category/Algorithm/post/584ba5c9580277001862f188 )

In [4]:
# Gradient Boosting Classifier 불러오기
from sklearn.ensemble import GradientBoostingClassifier 
from sklearn.metrics import accuracy_score
import time

# GBM 수행시간 측정을 위함. 시작시간 설정
start_time = time.time()

# 예시 데이터셋 불러오기
gb_clf = GradientBoostingClassifier(random_state=0)
gb_clf.fit(X_train, y_train.values)
gb_pred = gb_clf.predict(X_test)
gb_accuracy = accuracy_score(y_test, gb_pred)

print('GBM 정확도: {:.4f}'.format(gb_accuracy))
print('GBM 수행 시간: {:.1f}초'.format(time.time() - start_time))
GBM 정확도: 0.9376
GBM 수행 시간: 141.2초

GBM의 하이퍼파라미터

Tree에 관한 하이퍼 파라미터

파라미터 명 설명
max_depth - 트리의 최대 깊이
- default = 3
- 깊이가 깊어지면 과적합될 수 있으므로 적절히 제어 필요
min_samples_split - 노드를 분할하기 위한 최소한의 샘플 데이터수
→ 과적합을 제어하는데 사용
- Default = 2 → 작게 설정할 수록 분할 노드가 많아져 과적합 가능성 증가
min_samples_leaf - 리프노드가 되기 위해 필요한 최소한의 샘플 데이터수
- min_samples_split과 함께 과적합 제어 용도
- default = 1
- 불균형 데이터의 경우 특정 클래스의 데이터가 극도로 작을 수 있으므로 작게 설정 필요
max_features - 최적의 분할을 위해 고려할 최대 feature 개수
- Default = 'none' → 모든 피처 사용
- int형으로 지정 →피처 갯수 / float형으로 지정 →비중
- sqrt 또는 auto : 전체 피처 중 √(피처개수) 만큼 선정
- log : 전체 피처 중 log2(전체 피처 개수) 만큼 선정
max_leaf_nodes - 리프노드의 최대 개수
- default = None → 제한없음

Boosting에 관한 하이퍼파라미터

파라미터 명 설명
loss - 경사하강법에서 사용할 cost function 지정
- 특별한 이유가 없으면 default 값인 deviance 적용
n_estimators - 생성할 트리의 갯수를 지정
- Default = 100
- 많을소록 성능은 좋아지지만 시간이 오래 걸림
learning_rate - 학습을 진행할 때마다 적용하는 학습률(0~1)
- Weak learner가 순차적으로 오류 값을 보정해나갈 때 적용하는 계수
- Default = 0.1
- 낮은 만큼 최소 오류 값을 찾아 예측성능이 높아질 수 있음
- 하지만 많은 수의 트리가 필요하고 시간이 많이 소요
subsample - 개별 트리가 학습에 사용하는 데이터 샘플링 비율(0~1)
- default=1 (전체 데이터 학습)
- 이 값을 조절하여 트리 간의 상관도를 줄일 수 있음

GridSearchCV를 통한 GBM의 하이퍼파라미터 튜닝

In [5]:
from sklearn.model_selection import GridSearchCV

param = {
    'n_estimators' : [100, 500],
    'learning_rate' : [0.05, 0.1]
}

grid_cv = GridSearchCV(gb_clf, param_grid=param, cv=2, verbose=1, n_jobs=-1)
grid_cv.fit(X_train, y_train.values)
print('최적 하이퍼 파라미터: \n', grid_cv.best_params_)
print('최고 예측 정확도: {0:.4f}'.format(grid_cv.best_score_))
Fitting 2 folds for each of 4 candidates, totalling 8 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 out of   8 | elapsed:  2.4min remaining:  7.1min
[Parallel(n_jobs=-1)]: Done   8 out of   8 | elapsed:  5.0min finished
최적 하이퍼 파라미터: 
 {'learning_rate': 0.1, 'n_estimators': 500}
최고 예측 정확도: 0.9014

위의 간단한 그리드서치에서는 learning_rate = 0.05, n_estimator = 500 일 때 , 90.1% 정확도가 최고로 도출되었습니다.

In [6]:
# GridSearchCV를 이용해 최적으로 학습된 estimators로 예측 수행
gb_pred = grid_cv.best_estimator_.predict(X_test)
gb_accuracy = accuracy_score(y_test, gb_pred)
print('GBM 정확도: {0:.4f}'.format(gb_accuracy))
GBM 정확도: 0.9403

테스트 데이터 세트에서 약 94.03%의 정확도를 나타냈습니다.

4-3. Random Forest