1. Introduction

Lazy Evaluation 때문에 텐서플로 프로그램을 디버깅하는 것은 조금 까다롭다.
그래프에 에러가 있더라도, 그래프의 에러 부분이 실행되기 전까지는 알 수가 없다.
이 때문에 출력 결과를 바로바로 볼 수 있는 eager 모드가 개발 시 유용하게 쓰인다.

텐서플로 프로그램을 디버깅하는 것은 다른 소프트웨어와 크게 다르지 않다.

(1) 에러 메시지를 잘 봐야한다.

에러메시지는 stack trace(좌측 그림), error message(우측 그림)의 두 파트로 나뉘어진다.
좌측의 그림에서는 어디에서 에러가 발생했는지 알 수 있다.
이번의 경우에는 a와 c를 더하는 s에서 에러가 발생했다.

에러가 어디인지 파악이 됐으면 error message를 읽어 봐야 한다.
이번의 경우에는 더하는 두 텐서 간의 shape가 동일하지 않아 에러가 발생했다.

(2) 문제에서 메소드를 분리시킨다.

텐서플로 전체 프로그램과 데이터를 실행시키는 대신,
특정 부분에 대해서만 집중해서 문제를 해결한다.

(3) 가짜 데이터로 문제가 되는 메소드를 불러와서 확인해라.

(4) 이를 바탕으로 일반적인 문제해결은 어떻게 하는지 파악하라.

2. Shape Problems

Shape의 미스매치로 인한 오류는 가장 흔히 발생하는 유형 중 하나이다.

위의 예제도 shape가 맞지 않는 케이스이다.
어떤 shape이 맞을지는 코드의 목적에 따라 다르다.

위의 예제에서는 c = data[:, 1:3] 으로 변경해줌으로써 코드를 올바르게 수정하였다.

텐서플로에서 shape 에러는 배치 사이즈에 의해 발생할 수도 있습니다.
위의 예제에서는 input으로 (?, 3) 형태인 2차원 텐서가 들어가야 하는데
1차원 텐서를 넣었기 때문에 에러가 발생했습니다.
때문에 2차원 데이터로 변경을 해주어야 코드가 작동합니다.

이러한 shape 오류는 아래의 방법을 사용해서 해결할 수 있습니다.
(1) tf.reshape()
(2) tf.expand_dims()
(3) tf.slice()
(4) tf.squeeze

3. Fixing shape problems

(1) tf.reshape(tensor, shape, name=None)
: 입력된 텐서의 형태를 변형하는데 사용합니다.

(2) tf.expand_dims(input, dim, name=None)
: tf.expand_dims() 크기 1인 차원을 텐서의 구조(shape)에 삽입합니다.
이 때 차원 인덱스 dim은 0부터 시작합니다.

위의 예시를 보면 (3,2)의 모양을 가지는 2차원 텐서에 expanded_dim()의 인자를 0, 1, 2를
추가하면 각각 (1, 3, 2) , (3, 1, 2) , (3, 2, 1) 의 모양을 가지는 3차원 텐서가 된다.

(3) tf.slice(input_, begin, size, name=None)
: 이 함수는 텐서 input에서 begin 위치에서 시작해 크기 size인 부분을 추출합니다.

(4) tf.squeeze(input, squeeze_dims=None, name=None)
: 텐서에서 크기 1인 차원을 제거합니다.
input 텐서가 주어졌을 때, 이 함수는 그와 같은 자료형의 크기 1인 차원이 모두 제거된 텐서를 반환합니다.
만약 모든 크기 1인 차원을 제거하고 싶은 것이 아니라면, 제거하고 싶은 특정한 크기 1인 차원들을
squeeze_dims으로 지정할 수 있습니다.

4. Data Type Problems

위와 같은 코드를 실행하면 아래와 같은 오류 메시지를 보게된다.

a는 float 타입, b는 int 타입이라 데이터의 형이 다르기 때문에 연산하지 못해 오류가 발생하는 것이다.

이럴 때는 아래와 같이 tf.cast()를 이용해 데이터형 변환을 해주면 오류를 해결할 수 있다.

# 5. Debugging Full Programs
https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/deepdive/03_tensorflow/debug_demo.ipynb

위 링크 노트북 실습

'구글 머신러닝 스터디잼(중급) > Introduction to TensorFlow' 카테고리의 다른 글

Estimator API  (0) 2019.11.01
TensorFlow 실습 1  (0) 2019.10.28
Tensor and Variable  (0) 2019.10.25
Graph and Session  (0) 2019.10.25
TensorFlow API Hierarchy  (0) 2019.10.25

Getting Started with TensorFlow (구글퀵랩 실습)

아래와 같이 텐서플로를 불러오고, 버전을 확인합니다.
그리고 뒤의 실습을 위해 넘파이도 같이 불러 옵니다.

(1) 두 개의 tensor를 add하기(더하기)

텐서플로 코드를 이용하기 전에, 입력한 결과를 즉시 볼 수 있는 넘파이를 먼저 사용해보겠습니다.

이와 동일한 결과를 내기 위한 텐서플로 코드는 두 단계로 이루어집니다.

Step(1) : 그래프 만들기(Build the Graph)

위의 코드에서 c는 (3,)형태의 int32 속성을 가진 tensor를 반환하는 연산을 나타냅니다.
(넘파이와는 달리 print(c)를 실행한다고 해서 값이 합쳐진 [8, 2, 10] 의 결과가 나오지 않습니다.)

Step(2) : 그래프를 실행하기(Run the Graph)

(2) feed_dict를 사용하기

같은 그래프 이지만, 그래프를 만드는 단계에서 입력값을 하드코딩 하지 않고 placeholder를 사용합니다.

(3) TensorFlow로 Heron's Formula 구현하기

Heron's Formula에서는 삼각형의 세 변(a,b,c)이 주어지고 s=(a+b+c)/2 라 할 때,
삼각형의 면적을 ( s * (s-a) * (s-b) * (s-c) ) ^(1/2) 로 구합니다.

(4) Placeholder 와 feed_dict

텐서플로에서 프로그램 입력값을 placeholder로 정의하고,
세션 실행시 실제 값을 feed_dict를 이용해서 넣어주는 것은 흔한 방식입니다.
(3)번 예제에서 사용했던 코드에서는 area 그래프에서 입력 값이 값으로 바로 하드코딩 되있지만,
아래의 코드에서는 placeholder로 정의한 다음 run-time에서 입력값이 들어갑니다.

(5) tf.eager

tf.eager는 build-then-run 단계를 건너뛰게 해줍니다.
하지만 lazy evaluation의 패러다임 하에서 텐서플로가 멀티디바이스 지원, 활용이 가능하기 때문에
많은 생산코드들은 lazy evaluation을 사용하고 있습니다.

tf.eager는 프로그램의 문제점들을 수정해나갈 때 많이 사용하며,
tf.eager를 통해서 개발한 다음 eager 실행을 주석처리하고, 세션 관리코드를 추가해야 합니다.

다음의 eager모드를 실행하기 위해서는 노트북의 런타임을 재시작해야 합니다.

'구글 머신러닝 스터디잼(중급) > Introduction to TensorFlow' 카테고리의 다른 글

Estimator API  (0) 2019.11.01
Debugging TensorFlow Programs  (0) 2019.10.29
Tensor and Variable  (0) 2019.10.25
Graph and Session  (0) 2019.10.25
TensorFlow API Hierarchy  (0) 2019.10.25
4-5. Boosting Algorithm(XGBoost)
In [1]:
from IPython.core.display import display, HTML
display(HTML("<style> .container{width:90% !important;}</style>"))

1. XGBoost(eXtra Gradient Boost)의 개요

트리 기반의 알고리즘의 앙상블 학습에서 각광받는 알고리즘 중 하나입니다.
GBM에 기반하고 있지만, GBM의 단점인 느린 수행시간, 과적합 규제 등을 해결한 알고리즘입니다.

XGBoost의 주요장점

(1) 뛰어난 예측 성능
(2) GBM 대비 빠른 수행 시간
(3) 과적합 규제(Overfitting Regularization)
(4) Tree pruning(트리 가지치기) : 긍정 이득이 없는 분할을 가지치기해서 분할 수를 줄임
(5) 자체 내장된 교차 검증

  • 반복 수행시마다 내부적으로 교차검증을 수행해 최적회된 반복 수행횟수를 가질 수 있음
  • 지정된 반복횟수가 아니라 교차검증을 통해 평가 데이트세트의 평가 값이 최적화되면 반복을 중간에 멈출 수 있는 기능이 있음

(6) 결손값 자체 처리

XGBoost는 독자적인 XGBoost 모듈과 사이킷런 프레임워크 기반의 모듈이 존재합니다.
독자적인 모듈은 고유의 API와 하이퍼파라미터를 사용하지만, 사이킷런 기반 모듈에서는 다른 Estimator와 동일한 사용법을 가지고 있습니다.

2. XGBoost의 하이퍼 파라미터

일반 파라미터

: 일반적으로 실행 시 스레드의 개수나 silent 모드 등의 선택을 위한 파라미터, default 값을 바꾸는 일은 거의 없음

파라미터 명 설명
booster - gbtree(tree based model) 또는 gblinear(linear model) 중 선택
- Default = 'gbtree'
silent - Default = 1
- 출력 메시지를 나타내고 싶지 않을 경우 1로 설정
nthread - CPU 실행 스레드 개수 조정
- Default는 전체 다 사용하는 것
- 멀티코어/스레드 CPU 시스템에서 일부CPU만 사용할 때 변경

주요 부스터 파라미터

: 트리 최적화, 부스팅, regularization 등과 관련된 파라미터를 지칭

파라미터 명
(파이썬 래퍼)
파라미터명
(사이킷런 래퍼)
설명
eta
(0.3)
learning rate
(0.1)
- GBM의 learning rate와 같은 파라미터
- 범위: 0 ~ 1
num_boost_around
(10)
n_estimators
(100)
- 생성할 weak learner의 수
min_child_weight
(1)
min_child_weight
(1)
- GBM의 min_samples_leaf와 유사
- 관측치에 대한 가중치 합의 최소를 말하지만
GBM에서는 관측치 수에 대한 최소를 의미
- 과적합 조절 용도
- 범위: 0 ~ ∞
gamma
(0)
min_split_loss
(0)
- 리프노드의 추가분할을 결정할 최소손실 감소값
- 해당값보다 손실이 크게 감소할 때 분리
- 값이 클수록 과적합 감소효과
- 범위: 0 ~ ∞
max_depth
(6)
max_depth
(3)
- 트리 기반 알고리즘의 max_depth와 동일
- 0을 지정하면 깊이의 제한이 없음
- 너무 크면 과적합(통상 3~10정도 적용)
- 범위: 0 ~ ∞
sub_sample
(1)
subsample
(1)
- GBM의 subsample과 동일
- 데이터 샘플링 비율 지정(과적합 제어)
- 일반적으로 0.5~1 사이의 값을 사용
- 범위: 0 ~ 1
colsample_bytree
(1)
colsample_bytree
(1)
- GBM의 max_features와 유사
- 트리 생성에 필요한 피처의 샘플링에 사용
- 피처가 많을 때 과적합 조절에 사용
- 범위: 0 ~ 1
lambda
(1)
reg_lambda
(1)
- L2 Regularization 적용 값
- 피처 개수가 많을 때 적용을 검토
- 클수록 과적합 감소 효과
alpha
(0)
reg_alpha
(0)
- L1 Regularization 적용 값
- 피처 개수가 많을 때 적용을 검토
- 클수록 과적합 감소 효과
scale_pos_weight
(1)
scale_pos_weight
(1)
- 불균형 데이터셋의 균형을 유지

학습 태스크 파라미터

: 학습 수행 시의 객체함수, 평가를 위한 지표 등을 설정하는 파라미터

파라미터 명 설명
objective - ‘reg:linear’ : 회귀
- binary:logistic : 이진분류
- multi:softmax : 다중분류, 클래스 반환
- multi:softprob : 다중분류, 확륣반환
eval_metric - 검증에 사용되는 함수정의
- 회귀 분석인 경우 'rmse'를, 클래스 분류 문제인 경우 'error'
----------------------------------------------------
- rmse : Root Mean Squared Error
- mae : mean absolute error
- logloss : Negative log-likelihood
- error : binary classification error rate
- merror : multiclass classification error rate
- mlogloss: Multiclass logloss
- auc: Area Under Curve

과적합 제어

  • eta 값을 낮춥니다.(0.01 ~ 0.1) → eta 값을 낮추면 num_boost_round(n_estimator)를 반대로 높여주어야 합니다.
  • max_depth 값을 낮춥니다.
  • min_child_weight 값을 높입니다.
  • gamma 값을 높입니다.
  • subsample과 colsample_bytree를 낮춥니다.

Early Stopping 기능 :

GBM의 경우 n_estimators에 지정된 횟수만큼 학습을 끝까지 수행하지만, XGB의 경우 오류가 더 이상 개선되지 않으면 수행을 중지
n_estimators 를 200으로 설정하고, 조기 중단 파라미터 값을 50으로 설정하면, 1부터 200회까지 부스팅을 반복하다가
50회를 반복하는 동안 학습오류가 감소하지 않으면 더 이상 부스팅을 진행하지 않고 종료합니다.
(가령 100회에서 학습오류 값이 0.8인데 101~150회 반복하는 동안 예측 오류가 0.8보다 작은 값이 하나도 없으면 부스팅을 종료)

3. 파이썬 래퍼 XGBoost 적용

위스콘신 유방암 데이터 세트를 활용한 API 사용법

In [2]:
import xgboost as xgb ## XGBoost 불러오기
from xgboost import plot_importance ## Feature Importance를 불러오기 위함
import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score
import warnings
warnings.filterwarnings('ignore')

dataset = load_breast_cancer()
X_features = dataset.data
y_label = dataset.target

cancer_df = pd.DataFrame(data=X_features, columns = dataset.feature_names)
cancer_df['target'] = y_label
cancer_df.head(3)
Out[2]:
mean radius mean texture mean perimeter mean area mean smoothness mean compactness mean concavity mean concave points mean symmetry mean fractal dimension ... worst texture worst perimeter worst area worst smoothness worst compactness worst concavity worst concave points worst symmetry worst fractal dimension target
0 17.99 10.38 122.8 1001.0 0.11840 0.27760 0.3001 0.14710 0.2419 0.07871 ... 17.33 184.6 2019.0 0.1622 0.6656 0.7119 0.2654 0.4601 0.11890 0
1 20.57 17.77 132.9 1326.0 0.08474 0.07864 0.0869 0.07017 0.1812 0.05667 ... 23.41 158.8 1956.0 0.1238 0.1866 0.2416 0.1860 0.2750 0.08902 0
2 19.69 21.25 130.0 1203.0 0.10960 0.15990 0.1974 0.12790 0.2069 0.05999 ... 25.53 152.5 1709.0 0.1444 0.4245 0.4504 0.2430 0.3613 0.08758 0

3 rows × 31 columns

위의 데이터셋에서 악성종양은 0, 양성은 1 값으로 되어 있음

In [3]:
print(dataset.target_names)
print(cancer_df['target'].value_counts())
['malignant' 'benign']
1    357
0    212
Name: target, dtype: int64
In [4]:
# 전체 데이터셋을 학습용 80%, 테스트용 20%로 분할
X_train, X_test, y_train, y_test = train_test_split(X_features, y_label, test_size=0.2, random_state=156)
print(X_train.shape, X_test.shape)
(455, 30) (114, 30)

파이썬래퍼 XGBoost와 사이킷런래퍼 XGBoost의 가장 큰 차이는
파이썬래퍼는 학습용과 테스트 데이터 세트를 위해 별도의 DMatrix를 생성한다는 것입니다.
DMatrix : 넘파이 입력 파라미터를 받아서 만들어지는 XGBoost만의 전용 데이터 세트

  • 주요 입력 파라미터는 data(피처 데이터 세트)와 label
    (분류: 레이블 데이터 세트/회귀: 숫자형인 종속값 데이터 세트)
  • 판다스의 DataFrame으로 데이터 인터페이스를 하기 위해서는 DataFrame.values를 이용해 넘파이로 일차변환 한 뒤에 DMatrix 변환을 적용
In [5]:
# 넘파이 형태의 학습 데이터 세트와 테스트 데이터를 DMatrix로 변환하는 예제
dtrain = xgb.DMatrix(data=X_train, label = y_train)
dtest = xgb.DMatrix(data=X_test, label=y_test)
In [6]:
# max_depth = 3, 학습률은 0.1, 예제가 이진분류이므로 목적함수(objective)는 binary:logistic(이진 로지스틱)
# 오류함수의 평가성능지표는 logloss
# 부스팅 반복횟수는 400
# 조기중단을 위한 최소 반복횟수는 100

params = {'max_depth' : 3,
         'eta' : 0.1, 
         'objective' : 'binary:logistic',
         'eval_metric' : 'logloss',
         'early_stoppings' : 100 }

num_rounds = 400

파이썬래퍼 XGBoost에서 하이퍼 파라미터를 xgboost 모듈의 train( ) 함수에 파라미터로 전달합니다.
(사이킷런 래퍼는 Estimator 생성자를 하이퍼 파라미터로 전달)

early_stopping_rounds 파라미터 : 조기 중단을 위한 라운드를 설정합니다.
조기 중단 기능 수행을 위해서는 반드시 eval_set과 eval_metric이 함께 설정되어야 합니다.

  • eval_set : 성능평가를 위한 평가용 데이터 세트를 설정
  • eval_metric : 평가 세트에 적용할 성능 평가 방법
    (반복마다 eval_set으로 지정된 데이터 세트에서 eval_metric의 지정된 평가 지표로 예측 오류를 측정)

train() 함수를 호출하면 xgboost가 반복 시마다 evals에 표시된 데이터 세트에 대해 평가 지표를 출력합니다.
그 후 학습이 완료된 모델 객체를 반환합니다.

In [7]:
# train 데이터 세트는 'train', evaluation(test) 데이터 세트는 'eval' 로 명기
wlist = [(dtrain, 'train'), (dtest,'eval')]
# 하이퍼 파라미터와 early stopping 파라미터를 train() 함수의 파라미터로 전달
xgb_model = xgb.train(params = params, dtrain=dtrain, num_boost_round=num_rounds, evals=wlist)
[0]	train-logloss:0.609685	eval-logloss:0.61352
[1]	train-logloss:0.540804	eval-logloss:0.547842
[2]	train-logloss:0.483755	eval-logloss:0.494247
[3]	train-logloss:0.434455	eval-logloss:0.447986
[4]	train-logloss:0.390549	eval-logloss:0.409109
[5]	train-logloss:0.354145	eval-logloss:0.374977
[6]	train-logloss:0.321222	eval-logloss:0.345714
[7]	train-logloss:0.292592	eval-logloss:0.320529
[8]	train-logloss:0.267467	eval-logloss:0.29721
[9]	train-logloss:0.245152	eval-logloss:0.277991
[10]	train-logloss:0.225694	eval-logloss:0.260302
[11]	train-logloss:0.207937	eval-logloss:0.246037
[12]	train-logloss:0.192183	eval-logloss:0.231556
[13]	train-logloss:0.177917	eval-logloss:0.22005
[14]	train-logloss:0.165221	eval-logloss:0.208572
[15]	train-logloss:0.153622	eval-logloss:0.199993
[16]	train-logloss:0.14333	eval-logloss:0.190118
[17]	train-logloss:0.133985	eval-logloss:0.181818
[18]	train-logloss:0.125599	eval-logloss:0.174729
[19]	train-logloss:0.117286	eval-logloss:0.167657
[20]	train-logloss:0.109688	eval-logloss:0.158202
[21]	train-logloss:0.102975	eval-logloss:0.154725
[22]	train-logloss:0.097068	eval-logloss:0.148947
[23]	train-logloss:0.091428	eval-logloss:0.143308
[24]	train-logloss:0.086335	eval-logloss:0.136344
[25]	train-logloss:0.081311	eval-logloss:0.132778
[26]	train-logloss:0.076857	eval-logloss:0.127912
[27]	train-logloss:0.072836	eval-logloss:0.125263
[28]	train-logloss:0.069248	eval-logloss:0.119978
[29]	train-logloss:0.065549	eval-logloss:0.116412
[30]	train-logloss:0.062414	eval-logloss:0.114502
[31]	train-logloss:0.059591	eval-logloss:0.112572
[32]	train-logloss:0.057096	eval-logloss:0.11154
[33]	train-logloss:0.054407	eval-logloss:0.108681
[34]	train-logloss:0.052036	eval-logloss:0.106681
[35]	train-logloss:0.049751	eval-logloss:0.104207
[36]	train-logloss:0.04775	eval-logloss:0.102962
[37]	train-logloss:0.045854	eval-logloss:0.100576
[38]	train-logloss:0.044015	eval-logloss:0.098683
[39]	train-logloss:0.042263	eval-logloss:0.096444
[40]	train-logloss:0.040649	eval-logloss:0.095869
[41]	train-logloss:0.039126	eval-logloss:0.094242
[42]	train-logloss:0.037377	eval-logloss:0.094715
[43]	train-logloss:0.036106	eval-logloss:0.094272
[44]	train-logloss:0.034941	eval-logloss:0.093894
[45]	train-logloss:0.033654	eval-logloss:0.094184
[46]	train-logloss:0.032528	eval-logloss:0.09402
[47]	train-logloss:0.031485	eval-logloss:0.09236
[48]	train-logloss:0.030389	eval-logloss:0.093012
[49]	train-logloss:0.029467	eval-logloss:0.091272
[50]	train-logloss:0.028545	eval-logloss:0.090051
[51]	train-logloss:0.027525	eval-logloss:0.089605
[52]	train-logloss:0.026555	eval-logloss:0.089577
[53]	train-logloss:0.025682	eval-logloss:0.090703
[54]	train-logloss:0.025004	eval-logloss:0.089579
[55]	train-logloss:0.024297	eval-logloss:0.090357
[56]	train-logloss:0.023574	eval-logloss:0.091587
[57]	train-logloss:0.022965	eval-logloss:0.091527
[58]	train-logloss:0.022488	eval-logloss:0.091986
[59]	train-logloss:0.021854	eval-logloss:0.091951
[60]	train-logloss:0.021316	eval-logloss:0.091939
[61]	train-logloss:0.020794	eval-logloss:0.091461
[62]	train-logloss:0.020218	eval-logloss:0.090311
[63]	train-logloss:0.019701	eval-logloss:0.089407
[64]	train-logloss:0.01918	eval-logloss:0.089719
[65]	train-logloss:0.018724	eval-logloss:0.089743
[66]	train-logloss:0.018325	eval-logloss:0.089622
[67]	train-logloss:0.017867	eval-logloss:0.088734
[68]	train-logloss:0.017598	eval-logloss:0.088621
[69]	train-logloss:0.017243	eval-logloss:0.089739
[70]	train-logloss:0.01688	eval-logloss:0.089981
[71]	train-logloss:0.016641	eval-logloss:0.089782
[72]	train-logloss:0.016287	eval-logloss:0.089584
[73]	train-logloss:0.015983	eval-logloss:0.089533
[74]	train-logloss:0.015658	eval-logloss:0.088748
[75]	train-logloss:0.015393	eval-logloss:0.088597
[76]	train-logloss:0.015151	eval-logloss:0.08812
[77]	train-logloss:0.01488	eval-logloss:0.088396
[78]	train-logloss:0.014637	eval-logloss:0.088736
[79]	train-logloss:0.014491	eval-logloss:0.088153
[80]	train-logloss:0.014185	eval-logloss:0.087577
[81]	train-logloss:0.014005	eval-logloss:0.087412
[82]	train-logloss:0.013772	eval-logloss:0.08849
[83]	train-logloss:0.013567	eval-logloss:0.088575
[84]	train-logloss:0.013414	eval-logloss:0.08807
[85]	train-logloss:0.013253	eval-logloss:0.087641
[86]	train-logloss:0.013109	eval-logloss:0.087416
[87]	train-logloss:0.012926	eval-logloss:0.087611
[88]	train-logloss:0.012714	eval-logloss:0.087065
[89]	train-logloss:0.012544	eval-logloss:0.08727
[90]	train-logloss:0.012353	eval-logloss:0.087161
[91]	train-logloss:0.012226	eval-logloss:0.086962
[92]	train-logloss:0.012065	eval-logloss:0.087166
[93]	train-logloss:0.011927	eval-logloss:0.087067
[94]	train-logloss:0.011821	eval-logloss:0.086592
[95]	train-logloss:0.011649	eval-logloss:0.086116
[96]	train-logloss:0.011482	eval-logloss:0.087139
[97]	train-logloss:0.01136	eval-logloss:0.086768
[98]	train-logloss:0.011239	eval-logloss:0.086694
[99]	train-logloss:0.011132	eval-logloss:0.086547
[100]	train-logloss:0.011002	eval-logloss:0.086498
[101]	train-logloss:0.010852	eval-logloss:0.08641
[102]	train-logloss:0.010755	eval-logloss:0.086288
[103]	train-logloss:0.010636	eval-logloss:0.086258
[104]	train-logloss:0.0105	eval-logloss:0.086835
[105]	train-logloss:0.010395	eval-logloss:0.086767
[106]	train-logloss:0.010305	eval-logloss:0.087321
[107]	train-logloss:0.010197	eval-logloss:0.087304
[108]	train-logloss:0.010072	eval-logloss:0.08728
[109]	train-logloss:0.01	eval-logloss:0.087298
[110]	train-logloss:0.009914	eval-logloss:0.087289
[111]	train-logloss:0.009798	eval-logloss:0.088002
[112]	train-logloss:0.00971	eval-logloss:0.087936
[113]	train-logloss:0.009628	eval-logloss:0.087843
[114]	train-logloss:0.009558	eval-logloss:0.088066
[115]	train-logloss:0.009483	eval-logloss:0.087649
[116]	train-logloss:0.009416	eval-logloss:0.087298
[117]	train-logloss:0.009306	eval-logloss:0.087799
[118]	train-logloss:0.009228	eval-logloss:0.087751
[119]	train-logloss:0.009154	eval-logloss:0.08768
[120]	train-logloss:0.009118	eval-logloss:0.087626
[121]	train-logloss:0.009016	eval-logloss:0.08757
[122]	train-logloss:0.008972	eval-logloss:0.087547
[123]	train-logloss:0.008904	eval-logloss:0.087156
[124]	train-logloss:0.008837	eval-logloss:0.08767
[125]	train-logloss:0.008803	eval-logloss:0.087737
[126]	train-logloss:0.008709	eval-logloss:0.088275
[127]	train-logloss:0.008645	eval-logloss:0.088309
[128]	train-logloss:0.008613	eval-logloss:0.088266
[129]	train-logloss:0.008555	eval-logloss:0.087886
[130]	train-logloss:0.008463	eval-logloss:0.088861
[131]	train-logloss:0.008416	eval-logloss:0.088675
[132]	train-logloss:0.008385	eval-logloss:0.088743
[133]	train-logloss:0.0083	eval-logloss:0.089218
[134]	train-logloss:0.00827	eval-logloss:0.089179
[135]	train-logloss:0.008218	eval-logloss:0.088821
[136]	train-logloss:0.008157	eval-logloss:0.088512
[137]	train-logloss:0.008076	eval-logloss:0.08848
[138]	train-logloss:0.008047	eval-logloss:0.088386
[139]	train-logloss:0.007973	eval-logloss:0.089145
[140]	train-logloss:0.007946	eval-logloss:0.08911
[141]	train-logloss:0.007898	eval-logloss:0.088765
[142]	train-logloss:0.007872	eval-logloss:0.088678
[143]	train-logloss:0.007847	eval-logloss:0.088389
[144]	train-logloss:0.007776	eval-logloss:0.089271
[145]	train-logloss:0.007752	eval-logloss:0.089238
[146]	train-logloss:0.007728	eval-logloss:0.089139
[147]	train-logloss:0.007689	eval-logloss:0.088907
[148]	train-logloss:0.007621	eval-logloss:0.089416
[149]	train-logloss:0.007598	eval-logloss:0.089388
[150]	train-logloss:0.007575	eval-logloss:0.089108
[151]	train-logloss:0.007521	eval-logloss:0.088735
[152]	train-logloss:0.007498	eval-logloss:0.088717
[153]	train-logloss:0.007464	eval-logloss:0.088484
[154]	train-logloss:0.00741	eval-logloss:0.088471
[155]	train-logloss:0.007389	eval-logloss:0.088545
[156]	train-logloss:0.007367	eval-logloss:0.088521
[157]	train-logloss:0.007345	eval-logloss:0.088547
[158]	train-logloss:0.007323	eval-logloss:0.088275
[159]	train-logloss:0.007303	eval-logloss:0.0883
[160]	train-logloss:0.007282	eval-logloss:0.08828
[161]	train-logloss:0.007261	eval-logloss:0.088013
[162]	train-logloss:0.007241	eval-logloss:0.087758
[163]	train-logloss:0.007221	eval-logloss:0.087784
[164]	train-logloss:0.0072	eval-logloss:0.087777
[165]	train-logloss:0.00718	eval-logloss:0.087517
[166]	train-logloss:0.007161	eval-logloss:0.087542
[167]	train-logloss:0.007142	eval-logloss:0.087642
[168]	train-logloss:0.007122	eval-logloss:0.08739
[169]	train-logloss:0.007103	eval-logloss:0.087377
[170]	train-logloss:0.007084	eval-logloss:0.087298
[171]	train-logloss:0.007065	eval-logloss:0.087368
[172]	train-logloss:0.007047	eval-logloss:0.087395
[173]	train-logloss:0.007028	eval-logloss:0.087385
[174]	train-logloss:0.007009	eval-logloss:0.087132
[175]	train-logloss:0.006991	eval-logloss:0.087159
[176]	train-logloss:0.006973	eval-logloss:0.086955
[177]	train-logloss:0.006955	eval-logloss:0.087053
[178]	train-logloss:0.006937	eval-logloss:0.08697
[179]	train-logloss:0.00692	eval-logloss:0.086973
[180]	train-logloss:0.006901	eval-logloss:0.087038
[181]	train-logloss:0.006884	eval-logloss:0.086799
[182]	train-logloss:0.006866	eval-logloss:0.086826
[183]	train-logloss:0.006849	eval-logloss:0.086582
[184]	train-logloss:0.006831	eval-logloss:0.086588
[185]	train-logloss:0.006815	eval-logloss:0.086614
[186]	train-logloss:0.006798	eval-logloss:0.086372
[187]	train-logloss:0.006781	eval-logloss:0.086369
[188]	train-logloss:0.006764	eval-logloss:0.086297
[189]	train-logloss:0.006747	eval-logloss:0.086104
[190]	train-logloss:0.00673	eval-logloss:0.086023
[191]	train-logloss:0.006714	eval-logloss:0.08605
[192]	train-logloss:0.006698	eval-logloss:0.086149
[193]	train-logloss:0.006682	eval-logloss:0.085916
[194]	train-logloss:0.006666	eval-logloss:0.085915
[195]	train-logloss:0.00665	eval-logloss:0.085984
[196]	train-logloss:0.006634	eval-logloss:0.086012
[197]	train-logloss:0.006618	eval-logloss:0.085922
[198]	train-logloss:0.006603	eval-logloss:0.085853
[199]	train-logloss:0.006587	eval-logloss:0.085874
[200]	train-logloss:0.006572	eval-logloss:0.085888
[201]	train-logloss:0.006556	eval-logloss:0.08595
[202]	train-logloss:0.006542	eval-logloss:0.08573
[203]	train-logloss:0.006527	eval-logloss:0.08573
[204]	train-logloss:0.006512	eval-logloss:0.085753
[205]	train-logloss:0.006497	eval-logloss:0.085821
[206]	train-logloss:0.006483	eval-logloss:0.08584
[207]	train-logloss:0.006469	eval-logloss:0.085776
[208]	train-logloss:0.006455	eval-logloss:0.085686
[209]	train-logloss:0.00644	eval-logloss:0.08571
[210]	train-logloss:0.006427	eval-logloss:0.085806
[211]	train-logloss:0.006413	eval-logloss:0.085593
[212]	train-logloss:0.006399	eval-logloss:0.085801
[213]	train-logloss:0.006385	eval-logloss:0.085806
[214]	train-logloss:0.006372	eval-logloss:0.085744
[215]	train-logloss:0.006359	eval-logloss:0.085658
[216]	train-logloss:0.006345	eval-logloss:0.085843
[217]	train-logloss:0.006332	eval-logloss:0.085632
[218]	train-logloss:0.006319	eval-logloss:0.085726
[219]	train-logloss:0.006306	eval-logloss:0.085783
[220]	train-logloss:0.006293	eval-logloss:0.085791
[221]	train-logloss:0.00628	eval-logloss:0.085817
[222]	train-logloss:0.006268	eval-logloss:0.085757
[223]	train-logloss:0.006255	eval-logloss:0.085674
[224]	train-logloss:0.006242	eval-logloss:0.08586
[225]	train-logloss:0.00623	eval-logloss:0.085871
[226]	train-logloss:0.006218	eval-logloss:0.085927
[227]	train-logloss:0.006206	eval-logloss:0.085954
[228]	train-logloss:0.006194	eval-logloss:0.085874
[229]	train-logloss:0.006182	eval-logloss:0.086057
[230]	train-logloss:0.00617	eval-logloss:0.086002
[231]	train-logloss:0.006158	eval-logloss:0.085922
[232]	train-logloss:0.006147	eval-logloss:0.086102
[233]	train-logloss:0.006135	eval-logloss:0.086115
[234]	train-logloss:0.006124	eval-logloss:0.086169
[235]	train-logloss:0.006112	eval-logloss:0.086263
[236]	train-logloss:0.006101	eval-logloss:0.086292
[237]	train-logloss:0.00609	eval-logloss:0.086217
[238]	train-logloss:0.006079	eval-logloss:0.086395
[239]	train-logloss:0.006068	eval-logloss:0.086342
[240]	train-logloss:0.006057	eval-logloss:0.08618
[241]	train-logloss:0.006046	eval-logloss:0.086195
[242]	train-logloss:0.006036	eval-logloss:0.086248
[243]	train-logloss:0.006025	eval-logloss:0.086263
[244]	train-logloss:0.006014	eval-logloss:0.086293
[245]	train-logloss:0.006004	eval-logloss:0.086222
[246]	train-logloss:0.005993	eval-logloss:0.086398
[247]	train-logloss:0.005983	eval-logloss:0.086347
[248]	train-logloss:0.005972	eval-logloss:0.086276
[249]	train-logloss:0.005962	eval-logloss:0.086448
[250]	train-logloss:0.005952	eval-logloss:0.086294
[251]	train-logloss:0.005942	eval-logloss:0.086312
[252]	train-logloss:0.005932	eval-logloss:0.086364
[253]	train-logloss:0.005922	eval-logloss:0.086394
[254]	train-logloss:0.005912	eval-logloss:0.08649
[255]	train-logloss:0.005903	eval-logloss:0.086441
[256]	train-logloss:0.005893	eval-logloss:0.08629
[257]	train-logloss:0.005883	eval-logloss:0.08646
[258]	train-logloss:0.005874	eval-logloss:0.086391
[259]	train-logloss:0.005864	eval-logloss:0.086441
[260]	train-logloss:0.005855	eval-logloss:0.086461
[261]	train-logloss:0.005845	eval-logloss:0.086491
[262]	train-logloss:0.005836	eval-logloss:0.086445
[263]	train-logloss:0.005827	eval-logloss:0.086466
[264]	train-logloss:0.005818	eval-logloss:0.086319
[265]	train-logloss:0.005809	eval-logloss:0.086488
[266]	train-logloss:0.0058	eval-logloss:0.086538
[267]	train-logloss:0.005791	eval-logloss:0.086471
[268]	train-logloss:0.005782	eval-logloss:0.086501
[269]	train-logloss:0.005773	eval-logloss:0.086522
[270]	train-logloss:0.005764	eval-logloss:0.086689
[271]	train-logloss:0.005755	eval-logloss:0.086738
[272]	train-logloss:0.005747	eval-logloss:0.08683
[273]	train-logloss:0.005738	eval-logloss:0.086684
[274]	train-logloss:0.005729	eval-logloss:0.08664
[275]	train-logloss:0.005721	eval-logloss:0.086496
[276]	train-logloss:0.005712	eval-logloss:0.086355
[277]	train-logloss:0.005704	eval-logloss:0.086519
[278]	train-logloss:0.005696	eval-logloss:0.086567
[279]	train-logloss:0.005687	eval-logloss:0.08659
[280]	train-logloss:0.005679	eval-logloss:0.086679
[281]	train-logloss:0.005671	eval-logloss:0.086637
[282]	train-logloss:0.005663	eval-logloss:0.086499
[283]	train-logloss:0.005655	eval-logloss:0.086356
[284]	train-logloss:0.005646	eval-logloss:0.086405
[285]	train-logloss:0.005639	eval-logloss:0.086429
[286]	train-logloss:0.005631	eval-logloss:0.086456
[287]	train-logloss:0.005623	eval-logloss:0.086504
[288]	train-logloss:0.005615	eval-logloss:0.08637
[289]	train-logloss:0.005608	eval-logloss:0.086457
[290]	train-logloss:0.0056	eval-logloss:0.086453
[291]	train-logloss:0.005593	eval-logloss:0.086322
[292]	train-logloss:0.005585	eval-logloss:0.086284
[293]	train-logloss:0.005577	eval-logloss:0.086148
[294]	train-logloss:0.00557	eval-logloss:0.086196
[295]	train-logloss:0.005563	eval-logloss:0.086221
[296]	train-logloss:0.005556	eval-logloss:0.086308
[297]	train-logloss:0.005548	eval-logloss:0.086178
[298]	train-logloss:0.005541	eval-logloss:0.086263
[299]	train-logloss:0.005534	eval-logloss:0.086131
[300]	train-logloss:0.005527	eval-logloss:0.086179
[301]	train-logloss:0.005519	eval-logloss:0.086052
[302]	train-logloss:0.005512	eval-logloss:0.086016
[303]	train-logloss:0.005505	eval-logloss:0.086101
[304]	train-logloss:0.005498	eval-logloss:0.085977
[305]	train-logloss:0.005491	eval-logloss:0.086059
[306]	train-logloss:0.005484	eval-logloss:0.085971
[307]	train-logloss:0.005478	eval-logloss:0.085998
[308]	train-logloss:0.005471	eval-logloss:0.085999
[309]	train-logloss:0.005464	eval-logloss:0.085877
[310]	train-logloss:0.005457	eval-logloss:0.085923
[311]	train-logloss:0.00545	eval-logloss:0.085948
[312]	train-logloss:0.005444	eval-logloss:0.086028
[313]	train-logloss:0.005437	eval-logloss:0.086112
[314]	train-logloss:0.00543	eval-logloss:0.085989
[315]	train-logloss:0.005424	eval-logloss:0.085903
[316]	train-logloss:0.005417	eval-logloss:0.085949
[317]	train-logloss:0.005411	eval-logloss:0.085977
[318]	train-logloss:0.005404	eval-logloss:0.086002
[319]	train-logloss:0.005398	eval-logloss:0.085883
[320]	train-logloss:0.005392	eval-logloss:0.085967
[321]	train-logloss:0.005385	eval-logloss:0.086046
[322]	train-logloss:0.005379	eval-logloss:0.086091
[323]	train-logloss:0.005373	eval-logloss:0.085977
[324]	train-logloss:0.005366	eval-logloss:0.085978
[325]	train-logloss:0.00536	eval-logloss:0.085896
[326]	train-logloss:0.005354	eval-logloss:0.08578
[327]	train-logloss:0.005348	eval-logloss:0.085857
[328]	train-logloss:0.005342	eval-logloss:0.085939
[329]	train-logloss:0.005336	eval-logloss:0.085825
[330]	train-logloss:0.00533	eval-logloss:0.085869
[331]	train-logloss:0.005324	eval-logloss:0.085893
[332]	train-logloss:0.005318	eval-logloss:0.085922
[333]	train-logloss:0.005312	eval-logloss:0.085842
[334]	train-logloss:0.005306	eval-logloss:0.085735
[335]	train-logloss:0.0053	eval-logloss:0.085816
[336]	train-logloss:0.005294	eval-logloss:0.085892
[337]	train-logloss:0.005288	eval-logloss:0.085936
[338]	train-logloss:0.005283	eval-logloss:0.08583
[339]	train-logloss:0.005277	eval-logloss:0.085909
[340]	train-logloss:0.005271	eval-logloss:0.085831
[341]	train-logloss:0.005265	eval-logloss:0.085727
[342]	train-logloss:0.00526	eval-logloss:0.085678
[343]	train-logloss:0.005254	eval-logloss:0.085721
[344]	train-logloss:0.005249	eval-logloss:0.085796
[345]	train-logloss:0.005243	eval-logloss:0.085819
[346]	train-logloss:0.005237	eval-logloss:0.085715
[347]	train-logloss:0.005232	eval-logloss:0.085793
[348]	train-logloss:0.005227	eval-logloss:0.085835
[349]	train-logloss:0.005221	eval-logloss:0.085734
[350]	train-logloss:0.005216	eval-logloss:0.085658
[351]	train-logloss:0.00521	eval-logloss:0.08573
[352]	train-logloss:0.005205	eval-logloss:0.085807
[353]	train-logloss:0.0052	eval-logloss:0.085706
[354]	train-logloss:0.005195	eval-logloss:0.085659
[355]	train-logloss:0.005189	eval-logloss:0.085701
[356]	train-logloss:0.005184	eval-logloss:0.085628
[357]	train-logloss:0.005179	eval-logloss:0.085529
[358]	train-logloss:0.005174	eval-logloss:0.085604
[359]	train-logloss:0.005169	eval-logloss:0.085676
[360]	train-logloss:0.005164	eval-logloss:0.085579
[361]	train-logloss:0.005159	eval-logloss:0.085601
[362]	train-logloss:0.005153	eval-logloss:0.085643
[363]	train-logloss:0.005149	eval-logloss:0.085713
[364]	train-logloss:0.005144	eval-logloss:0.085787
[365]	train-logloss:0.005139	eval-logloss:0.085689
[366]	train-logloss:0.005134	eval-logloss:0.08573
[367]	train-logloss:0.005129	eval-logloss:0.085684
[368]	train-logloss:0.005124	eval-logloss:0.085589
[369]	train-logloss:0.005119	eval-logloss:0.085516
[370]	train-logloss:0.005114	eval-logloss:0.085588
[371]	train-logloss:0.00511	eval-logloss:0.085495
[372]	train-logloss:0.005105	eval-logloss:0.085564
[373]	train-logloss:0.0051	eval-logloss:0.085605
[374]	train-logloss:0.005096	eval-logloss:0.085626
[375]	train-logloss:0.005091	eval-logloss:0.085535
[376]	train-logloss:0.005086	eval-logloss:0.085606
[377]	train-logloss:0.005082	eval-logloss:0.085674
[378]	train-logloss:0.005077	eval-logloss:0.085714
[379]	train-logloss:0.005073	eval-logloss:0.085624
[380]	train-logloss:0.005068	eval-logloss:0.085579
[381]	train-logloss:0.005064	eval-logloss:0.085618
[382]	train-logloss:0.00506	eval-logloss:0.085639
[383]	train-logloss:0.005055	eval-logloss:0.08555
[384]	train-logloss:0.005051	eval-logloss:0.085617
[385]	train-logloss:0.005047	eval-logloss:0.085621
[386]	train-logloss:0.005042	eval-logloss:0.085551
[387]	train-logloss:0.005038	eval-logloss:0.085463
[388]	train-logloss:0.005034	eval-logloss:0.085502
[389]	train-logloss:0.005029	eval-logloss:0.085459
[390]	train-logloss:0.005025	eval-logloss:0.085321
[391]	train-logloss:0.005021	eval-logloss:0.085389
[392]	train-logloss:0.005017	eval-logloss:0.085303
[393]	train-logloss:0.005013	eval-logloss:0.085369
[394]	train-logloss:0.005009	eval-logloss:0.085301
[395]	train-logloss:0.005005	eval-logloss:0.085368
[396]	train-logloss:0.005	eval-logloss:0.085283
[397]	train-logloss:0.004996	eval-logloss:0.08532
[398]	train-logloss:0.004992	eval-logloss:0.085279
[399]	train-logloss:0.004988	eval-logloss:0.085196

train( ) 을 통해 학습을 수행하면 반복시 train-logloss와 eval-logloss가 지속적으로 감소합니다.
xgboost를 이용해 학습이 완료됐으면 predict() 메서드를 이용해 예측을 수행합니다.
여기서 파이썬 래퍼는 예측 결과를 추정할 수 있는 호가률 값을 반환합니다.(반면 사이킷런 래퍼는 클래스 값을 반환)

In [8]:
pred_probs = xgb_model.predict(dtest)
print('predict() 수행 결과값을 10개만 표시, 예측 확률 값으로 표시됨')
print(np.round(pred_probs[:10], 3))

# 예측 확률이 0.5보다 크면 1, 그렇지 않으면 0으로 예측값 결정해 리스트 객체인 preds에 저장
preds = [ 1 if x > 0.5 else 0 for x in pred_probs]
print('예측값 10개만 표시: ', preds[:10])
predict() 수행 결과값을 10개만 표시, 예측 확률 값으로 표시됨
[0.95  0.003 0.9   0.086 0.993 1.    1.    0.999 0.998 0.   ]
예측값 10개만 표시:  [1, 0, 1, 0, 1, 1, 1, 1, 1, 0]
In [9]:
# 혼동행렬, 정확도, 정밀도, 재현율, F1, AUC 불러오기
def get_clf_eval(y_test, y_pred):
    confusion = confusion_matrix(y_test, y_pred)
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    F1 = f1_score(y_test, y_pred)
    AUC = roc_auc_score(y_test, y_pred)
    print('오차행렬:\n', confusion)
    print('\n정확도: {:.4f}'.format(accuracy))
    print('정밀도: {:.4f}'.format(precision))
    print('재현율: {:.4f}'.format(recall))
    print('F1: {:.4f}'.format(F1))
    print('AUC: {:.4f}'.format(AUC))
In [10]:
get_clf_eval(y_test, preds)
오차행렬:
 [[35  2]
 [ 1 76]]

정확도: 0.9737
정밀도: 0.9744
재현율: 0.9870
F1: 0.9806
AUC: 0.9665

Feature importance를 시각화할 때,
→ 기본 평가 지료로 f1스코어를 기반으로 각 feature의 중요도를 나타냅니다.
→ 사이킷런 래퍼는 estimator 객체의 featureimportances 속성을 이용해 시각화 코드를 직접 작성해야 합니다.
→ 반면, 파이썬 래퍼는 plot_importance()를 이용해 바로 피처 중요 코드를 시각화 할 수 있습니다.

In [11]:
from xgboost import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots(figsize=(10, 12))
plot_importance(xgb_model, ax=ax)
Out[11]:
<matplotlib.axes._subplots.AxesSubplot at 0x1a20bbf7f0>

다만, xgboost 넘파이 기반의 피처 데이터로 학습 시에 피처명을 제대로 알 수 없으므로
피처별로 f자 뒤에 순서를 붙여 X축에 피처들로 나열합니다.(f0는 첫번째 피처, f1는 두번째 피처를 의미)

파이썬 래퍼의 교차 검증 수행 및 최적 파라미터 구하기

: xgboost는 사이킷런의 GridSearchCV와 유사하게 cv( )를 API로 제공합니다.

xgb.cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False,
folds=None, metrics=(),obj=None, feval=None, maximize=False,
early_stopping_rounds=None, fpreproc=None, as_pandas=True,
verbose_eval=None, show_stdv=True, seed=0, callbacks=None, shuffle=True)

  • params(dict): 부스터 파라미터
  • dtrain(DMatrix) : 학습 데이터
  • num_boost_round(int) : 부스팅 반복횟수
  • nfold(int) : CV폴드 개수
  • stratified(bool) : CV수행시 샘플을 균등하게 추출할지 여부
  • metrics(string or list of strings) : CV 수행시 모니터링할 성능 평가 지표
  • early_stopping_rounds(int) : 조기중단을 활성화시킴. 반복횟수 지정

xgv.cv의 반환 값은 데이터프레임 형태입니다.

4. 사이킷런 래퍼 XGBoost의 개요 및 적용

특징

  • 사이킷런의 기본 Estimator를 이용해 만들어 fit()과 predict()만으로 학습과 예측이 가능
  • GridSearchCV,Pipeline 등 사이킷런의 유틸리티를 그대로 사용 가능
  • 분류 : XGBClassifier / 회귀 : XGBRegressor

파이썬 래퍼와 비교시 달라진 파라미터

  • eta → learning_rate
  • sub_sample → subsample
  • lambda → reg_lambda
  • alpha → reg_alpha
  • num_boost_round → n_estimators

위와 동일하게 위스콘신 유방암 데이터를 통한 예측

In [12]:
# max_depth = 3, 학습률은 0.1, 예제가 이진분류이므로 목적함수(objective)는 binary:logistic(이진 로지스틱)
# 부스팅 반복횟수는 400

from xgboost import XGBClassifier

xgb_wrapper = XGBClassifier(n_estimators = 400, learning_rate = 0.1, max_depth = 3)
xgb_wrapper.fit(X_train, y_train)
w_preds = xgb_wrapper.predict(X_test)
In [13]:
# 예측 결과 확인
get_clf_eval(y_test, w_preds)
오차행렬:
 [[35  2]
 [ 1 76]]

정확도: 0.9737
정밀도: 0.9744
재현율: 0.9870
F1: 0.9806
AUC: 0.9665

앞선 파이썬 래퍼 XGBoost와 동일한 결과가 나옵니다.

사이킷런 래퍼 XGBoost에서도 조기 중단 기능을 수행할 수 있는데 fit( )에 해당 파라미터를 입력하면 됩니다.
→ early_stopping_rounds, eval_metrics, eval_set

In [14]:
# max_depth = 3, 학습률은 0.1, 예제가 이진분류이므로 목적함수(objective)는 binary:logistic(이진 로지스틱)
# 오류함수의 평가성능지표는 logloss
# 부스팅 반복횟수는 400
# 조기중단을 위한 최소 반복횟수는 100

# 아래 예제에서는 평가를 위한 데이터 세트로 테스트 데이터 세트를 사용했지만, 바람직하진 않습니다.
# 테스트 데이터 세트는 학습에 완전히 알려지지 않은 데이터 세트를 사용해야 합니다.
# 평가에 테스트 데이터 세트를 사용하면 학습시에 미리 참고가 되어 과적합할 수 있기 때문입니다.

xgb_wrapper = XGBClassifier(n_estimators = 400, learning_rate = 0.1 , max_depth = 3)
evals = [(X_test, y_test)]
xgb_wrapper.fit(X_train, y_train, early_stopping_rounds = 100, 
                eval_metric="logloss", eval_set = evals, verbose=True)
ws100_preds = xgb_wrapper.predict(X_test)
[0]	validation_0-logloss:0.61352
Will train until validation_0-logloss hasn't improved in 100 rounds.
[1]	validation_0-logloss:0.547842
[2]	validation_0-logloss:0.494247
[3]	validation_0-logloss:0.447986
[4]	validation_0-logloss:0.409109
[5]	validation_0-logloss:0.374977
[6]	validation_0-logloss:0.345714
[7]	validation_0-logloss:0.320529
[8]	validation_0-logloss:0.29721
[9]	validation_0-logloss:0.277991
[10]	validation_0-logloss:0.260302
[11]	validation_0-logloss:0.246037
[12]	validation_0-logloss:0.231556
[13]	validation_0-logloss:0.22005
[14]	validation_0-logloss:0.208572
[15]	validation_0-logloss:0.199993
[16]	validation_0-logloss:0.190118
[17]	validation_0-logloss:0.181818
[18]	validation_0-logloss:0.174729
[19]	validation_0-logloss:0.167657
[20]	validation_0-logloss:0.158202
[21]	validation_0-logloss:0.154725
[22]	validation_0-logloss:0.148947
[23]	validation_0-logloss:0.143308
[24]	validation_0-logloss:0.136344
[25]	validation_0-logloss:0.132778
[26]	validation_0-logloss:0.127912
[27]	validation_0-logloss:0.125263
[28]	validation_0-logloss:0.119978
[29]	validation_0-logloss:0.116412
[30]	validation_0-logloss:0.114502
[31]	validation_0-logloss:0.112572
[32]	validation_0-logloss:0.11154
[33]	validation_0-logloss:0.108681
[34]	validation_0-logloss:0.106681
[35]	validation_0-logloss:0.104207
[36]	validation_0-logloss:0.102962
[37]	validation_0-logloss:0.100576
[38]	validation_0-logloss:0.098683
[39]	validation_0-logloss:0.096444
[40]	validation_0-logloss:0.095869
[41]	validation_0-logloss:0.094242
[42]	validation_0-logloss:0.094715
[43]	validation_0-logloss:0.094272
[44]	validation_0-logloss:0.093894
[45]	validation_0-logloss:0.094184
[46]	validation_0-logloss:0.09402
[47]	validation_0-logloss:0.09236
[48]	validation_0-logloss:0.093012
[49]	validation_0-logloss:0.091272
[50]	validation_0-logloss:0.090051
[51]	validation_0-logloss:0.089605
[52]	validation_0-logloss:0.089577
[53]	validation_0-logloss:0.090703
[54]	validation_0-logloss:0.089579
[55]	validation_0-logloss:0.090357
[56]	validation_0-logloss:0.091587
[57]	validation_0-logloss:0.091527
[58]	validation_0-logloss:0.091986
[59]	validation_0-logloss:0.091951
[60]	validation_0-logloss:0.091939
[61]	validation_0-logloss:0.091461
[62]	validation_0-logloss:0.090311
[63]	validation_0-logloss:0.089407
[64]	validation_0-logloss:0.089719
[65]	validation_0-logloss:0.089743
[66]	validation_0-logloss:0.089622
[67]	validation_0-logloss:0.088734
[68]	validation_0-logloss:0.088621
[69]	validation_0-logloss:0.089739
[70]	validation_0-logloss:0.089981
[71]	validation_0-logloss:0.089782
[72]	validation_0-logloss:0.089584
[73]	validation_0-logloss:0.089533
[74]	validation_0-logloss:0.088748
[75]	validation_0-logloss:0.088597
[76]	validation_0-logloss:0.08812
[77]	validation_0-logloss:0.088396
[78]	validation_0-logloss:0.088736
[79]	validation_0-logloss:0.088153
[80]	validation_0-logloss:0.087577
[81]	validation_0-logloss:0.087412
[82]	validation_0-logloss:0.08849
[83]	validation_0-logloss:0.088575
[84]	validation_0-logloss:0.08807
[85]	validation_0-logloss:0.087641
[86]	validation_0-logloss:0.087416
[87]	validation_0-logloss:0.087611
[88]	validation_0-logloss:0.087065
[89]	validation_0-logloss:0.08727
[90]	validation_0-logloss:0.087161
[91]	validation_0-logloss:0.086962
[92]	validation_0-logloss:0.087166
[93]	validation_0-logloss:0.087067
[94]	validation_0-logloss:0.086592
[95]	validation_0-logloss:0.086116
[96]	validation_0-logloss:0.087139
[97]	validation_0-logloss:0.086768
[98]	validation_0-logloss:0.086694
[99]	validation_0-logloss:0.086547
[100]	validation_0-logloss:0.086498
[101]	validation_0-logloss:0.08641
[102]	validation_0-logloss:0.086288
[103]	validation_0-logloss:0.086258
[104]	validation_0-logloss:0.086835
[105]	validation_0-logloss:0.086767
[106]	validation_0-logloss:0.087321
[107]	validation_0-logloss:0.087304
[108]	validation_0-logloss:0.08728
[109]	validation_0-logloss:0.087298
[110]	validation_0-logloss:0.087289
[111]	validation_0-logloss:0.088002
[112]	validation_0-logloss:0.087936
[113]	validation_0-logloss:0.087843
[114]	validation_0-logloss:0.088066
[115]	validation_0-logloss:0.087649
[116]	validation_0-logloss:0.087298
[117]	validation_0-logloss:0.087799
[118]	validation_0-logloss:0.087751
[119]	validation_0-logloss:0.08768
[120]	validation_0-logloss:0.087626
[121]	validation_0-logloss:0.08757
[122]	validation_0-logloss:0.087547
[123]	validation_0-logloss:0.087156
[124]	validation_0-logloss:0.08767
[125]	validation_0-logloss:0.087737
[126]	validation_0-logloss:0.088275
[127]	validation_0-logloss:0.088309
[128]	validation_0-logloss:0.088266
[129]	validation_0-logloss:0.087886
[130]	validation_0-logloss:0.088861
[131]	validation_0-logloss:0.088675
[132]	validation_0-logloss:0.088743
[133]	validation_0-logloss:0.089218
[134]	validation_0-logloss:0.089179
[135]	validation_0-logloss:0.088821
[136]	validation_0-logloss:0.088512
[137]	validation_0-logloss:0.08848
[138]	validation_0-logloss:0.088386
[139]	validation_0-logloss:0.089145
[140]	validation_0-logloss:0.08911
[141]	validation_0-logloss:0.088765
[142]	validation_0-logloss:0.088678
[143]	validation_0-logloss:0.088389
[144]	validation_0-logloss:0.089271
[145]	validation_0-logloss:0.089238
[146]	validation_0-logloss:0.089139
[147]	validation_0-logloss:0.088907
[148]	validation_0-logloss:0.089416
[149]	validation_0-logloss:0.089388
[150]	validation_0-logloss:0.089108
[151]	validation_0-logloss:0.088735
[152]	validation_0-logloss:0.088717
[153]	validation_0-logloss:0.088484
[154]	validation_0-logloss:0.088471
[155]	validation_0-logloss:0.088545
[156]	validation_0-logloss:0.088521
[157]	validation_0-logloss:0.088547
[158]	validation_0-logloss:0.088275
[159]	validation_0-logloss:0.0883
[160]	validation_0-logloss:0.08828
[161]	validation_0-logloss:0.088013
[162]	validation_0-logloss:0.087758
[163]	validation_0-logloss:0.087784
[164]	validation_0-logloss:0.087777
[165]	validation_0-logloss:0.087517
[166]	validation_0-logloss:0.087542
[167]	validation_0-logloss:0.087642
[168]	validation_0-logloss:0.08739
[169]	validation_0-logloss:0.087377
[170]	validation_0-logloss:0.087298
[171]	validation_0-logloss:0.087368
[172]	validation_0-logloss:0.087395
[173]	validation_0-logloss:0.087385
[174]	validation_0-logloss:0.087132
[175]	validation_0-logloss:0.087159
[176]	validation_0-logloss:0.086955
[177]	validation_0-logloss:0.087053
[178]	validation_0-logloss:0.08697
[179]	validation_0-logloss:0.086973
[180]	validation_0-logloss:0.087038
[181]	validation_0-logloss:0.086799
[182]	validation_0-logloss:0.086826
[183]	validation_0-logloss:0.086582
[184]	validation_0-logloss:0.086588
[185]	validation_0-logloss:0.086614
[186]	validation_0-logloss:0.086372
[187]	validation_0-logloss:0.086369
[188]	validation_0-logloss:0.086297
[189]	validation_0-logloss:0.086104
[190]	validation_0-logloss:0.086023
[191]	validation_0-logloss:0.08605
[192]	validation_0-logloss:0.086149
[193]	validation_0-logloss:0.085916
[194]	validation_0-logloss:0.085915
[195]	validation_0-logloss:0.085984
[196]	validation_0-logloss:0.086012
[197]	validation_0-logloss:0.085922
[198]	validation_0-logloss:0.085853
[199]	validation_0-logloss:0.085874
[200]	validation_0-logloss:0.085888
[201]	validation_0-logloss:0.08595
[202]	validation_0-logloss:0.08573
[203]	validation_0-logloss:0.08573
[204]	validation_0-logloss:0.085753
[205]	validation_0-logloss:0.085821
[206]	validation_0-logloss:0.08584
[207]	validation_0-logloss:0.085776
[208]	validation_0-logloss:0.085686
[209]	validation_0-logloss:0.08571
[210]	validation_0-logloss:0.085806
[211]	validation_0-logloss:0.085593
[212]	validation_0-logloss:0.085801
[213]	validation_0-logloss:0.085806
[214]	validation_0-logloss:0.085744
[215]	validation_0-logloss:0.085658
[216]	validation_0-logloss:0.085843
[217]	validation_0-logloss:0.085632
[218]	validation_0-logloss:0.085726
[219]	validation_0-logloss:0.085783
[220]	validation_0-logloss:0.085791
[221]	validation_0-logloss:0.085817
[222]	validation_0-logloss:0.085757
[223]	validation_0-logloss:0.085674
[224]	validation_0-logloss:0.08586
[225]	validation_0-logloss:0.085871
[226]	validation_0-logloss:0.085927
[227]	validation_0-logloss:0.085954
[228]	validation_0-logloss:0.085874
[229]	validation_0-logloss:0.086057
[230]	validation_0-logloss:0.086002
[231]	validation_0-logloss:0.085922
[232]	validation_0-logloss:0.086102
[233]	validation_0-logloss:0.086115
[234]	validation_0-logloss:0.086169
[235]	validation_0-logloss:0.086263
[236]	validation_0-logloss:0.086292
[237]	validation_0-logloss:0.086217
[238]	validation_0-logloss:0.086395
[239]	validation_0-logloss:0.086342
[240]	validation_0-logloss:0.08618
[241]	validation_0-logloss:0.086195
[242]	validation_0-logloss:0.086248
[243]	validation_0-logloss:0.086263
[244]	validation_0-logloss:0.086293
[245]	validation_0-logloss:0.086222
[246]	validation_0-logloss:0.086398
[247]	validation_0-logloss:0.086347
[248]	validation_0-logloss:0.086276
[249]	validation_0-logloss:0.086448
[250]	validation_0-logloss:0.086294
[251]	validation_0-logloss:0.086312
[252]	validation_0-logloss:0.086364
[253]	validation_0-logloss:0.086394
[254]	validation_0-logloss:0.08649
[255]	validation_0-logloss:0.086441
[256]	validation_0-logloss:0.08629
[257]	validation_0-logloss:0.08646
[258]	validation_0-logloss:0.086391
[259]	validation_0-logloss:0.086441
[260]	validation_0-logloss:0.086461
[261]	validation_0-logloss:0.086491
[262]	validation_0-logloss:0.086445
[263]	validation_0-logloss:0.086466
[264]	validation_0-logloss:0.086319
[265]	validation_0-logloss:0.086488
[266]	validation_0-logloss:0.086538
[267]	validation_0-logloss:0.086471
[268]	validation_0-logloss:0.086501
[269]	validation_0-logloss:0.086522
[270]	validation_0-logloss:0.086689
[271]	validation_0-logloss:0.086738
[272]	validation_0-logloss:0.08683
[273]	validation_0-logloss:0.086684
[274]	validation_0-logloss:0.08664
[275]	validation_0-logloss:0.086496
[276]	validation_0-logloss:0.086355
[277]	validation_0-logloss:0.086519
[278]	validation_0-logloss:0.086567
[279]	validation_0-logloss:0.08659
[280]	validation_0-logloss:0.086679
[281]	validation_0-logloss:0.086637
[282]	validation_0-logloss:0.086499
[283]	validation_0-logloss:0.086356
[284]	validation_0-logloss:0.086405
[285]	validation_0-logloss:0.086429
[286]	validation_0-logloss:0.086456
[287]	validation_0-logloss:0.086504
[288]	validation_0-logloss:0.08637
[289]	validation_0-logloss:0.086457
[290]	validation_0-logloss:0.086453
[291]	validation_0-logloss:0.086322
[292]	validation_0-logloss:0.086284
[293]	validation_0-logloss:0.086148
[294]	validation_0-logloss:0.086196
[295]	validation_0-logloss:0.086221
[296]	validation_0-logloss:0.086308
[297]	validation_0-logloss:0.086178
[298]	validation_0-logloss:0.086263
[299]	validation_0-logloss:0.086131
[300]	validation_0-logloss:0.086179
[301]	validation_0-logloss:0.086052
[302]	validation_0-logloss:0.086016
[303]	validation_0-logloss:0.086101
[304]	validation_0-logloss:0.085977
[305]	validation_0-logloss:0.086059
[306]	validation_0-logloss:0.085971
[307]	validation_0-logloss:0.085998
[308]	validation_0-logloss:0.085999
[309]	validation_0-logloss:0.085877
[310]	validation_0-logloss:0.085923
[311]	validation_0-logloss:0.085948
Stopping. Best iteration:
[211]	validation_0-logloss:0.085593

위의 결과에서는 211번 반복시 logloss가 0.085593 이었는데 이후 100번 반복되는 311번째까지
성능평가 지수가 향상되지 않았기 때문에 더 이상 반복하지 않고 멈추게 되었습니다.

In [15]:
get_clf_eval(y_test, ws100_preds)
오차행렬:
 [[34  3]
 [ 1 76]]

정확도: 0.9649
정밀도: 0.9620
재현율: 0.9870
F1: 0.9744
AUC: 0.9530


조기 중단값을 너무 급격하게 줄이면 성능이 향상될 여지가 있음에도 학습을 멈춰 예측 성능이 저하될 수 있습니다.

In [16]:
# early_stopping_rounds = 10 으로 설정하고 재학습
xgb_wrapper.fit(X_train, y_train, early_stopping_rounds = 10, 
                eval_metric='logloss', eval_set=evals , verbose=True)

ws10_preds = xgb_wrapper.predict(X_test)
get_clf_eval(y_test, ws10_preds)
[0]	validation_0-logloss:0.61352
Will train until validation_0-logloss hasn't improved in 10 rounds.
[1]	validation_0-logloss:0.547842
[2]	validation_0-logloss:0.494247
[3]	validation_0-logloss:0.447986
[4]	validation_0-logloss:0.409109
[5]	validation_0-logloss:0.374977
[6]	validation_0-logloss:0.345714
[7]	validation_0-logloss:0.320529
[8]	validation_0-logloss:0.29721
[9]	validation_0-logloss:0.277991
[10]	validation_0-logloss:0.260302
[11]	validation_0-logloss:0.246037
[12]	validation_0-logloss:0.231556
[13]	validation_0-logloss:0.22005
[14]	validation_0-logloss:0.208572
[15]	validation_0-logloss:0.199993
[16]	validation_0-logloss:0.190118
[17]	validation_0-logloss:0.181818
[18]	validation_0-logloss:0.174729
[19]	validation_0-logloss:0.167657
[20]	validation_0-logloss:0.158202
[21]	validation_0-logloss:0.154725
[22]	validation_0-logloss:0.148947
[23]	validation_0-logloss:0.143308
[24]	validation_0-logloss:0.136344
[25]	validation_0-logloss:0.132778
[26]	validation_0-logloss:0.127912
[27]	validation_0-logloss:0.125263
[28]	validation_0-logloss:0.119978
[29]	validation_0-logloss:0.116412
[30]	validation_0-logloss:0.114502
[31]	validation_0-logloss:0.112572
[32]	validation_0-logloss:0.11154
[33]	validation_0-logloss:0.108681
[34]	validation_0-logloss:0.106681
[35]	validation_0-logloss:0.104207
[36]	validation_0-logloss:0.102962
[37]	validation_0-logloss:0.100576
[38]	validation_0-logloss:0.098683
[39]	validation_0-logloss:0.096444
[40]	validation_0-logloss:0.095869
[41]	validation_0-logloss:0.094242
[42]	validation_0-logloss:0.094715
[43]	validation_0-logloss:0.094272
[44]	validation_0-logloss:0.093894
[45]	validation_0-logloss:0.094184
[46]	validation_0-logloss:0.09402
[47]	validation_0-logloss:0.09236
[48]	validation_0-logloss:0.093012
[49]	validation_0-logloss:0.091272
[50]	validation_0-logloss:0.090051
[51]	validation_0-logloss:0.089605
[52]	validation_0-logloss:0.089577
[53]	validation_0-logloss:0.090703
[54]	validation_0-logloss:0.089579
[55]	validation_0-logloss:0.090357
[56]	validation_0-logloss:0.091587
[57]	validation_0-logloss:0.091527
[58]	validation_0-logloss:0.091986
[59]	validation_0-logloss:0.091951
[60]	validation_0-logloss:0.091939
[61]	validation_0-logloss:0.091461
[62]	validation_0-logloss:0.090311
Stopping. Best iteration:
[52]	validation_0-logloss:0.089577

오차행렬:
 [[34  3]
 [ 2 75]]

정확도: 0.9561
정밀도: 0.9615
재현율: 0.9740
F1: 0.9677
AUC: 0.9465

62번째까지만 수행이 되고 종료되었는데, 이렇게 학습된 모델로 예측한 결과, 정확도는 약 0.9561로
ealry_stopping_rounds = 100일 때의 정확도인 0.9649보다 낮게 나왔습니다.

모델 예측 후 피처 중요도를 동일하게 plot_importance() API를 통해 시각화할 수 있습니다.

In [17]:
from xgboost import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots(figsize=(10, 12))

plot_importance(xgb_wrapper, ax=ax)
Out[17]:
<matplotlib.axes._subplots.AxesSubplot at 0x1a20d54f60>

1. Tensors

  • Tensor(텐서)는 데이터의 n차원 배열을 말한다.

위의 그림과 같이 0차원부터 3차원까지 리스트를 이용해서 텐서를 구성할 수 있고,
n차원의 텐서 또한 위와 같이 텐서를 쌓음으로써(stack) 쉽게 구성할 수 있다.

저차원의 텐서를 쌓아 고차원의 텐서를 만들 수 있고, 반대로 고차원의 텐서에서 저차원의 텐서를 뽑아낼 수도 있다.
위의 예시에서 x는 (2,3)의 모양을 가진 2차원 텐서이다.
y는 x를 슬라이싱한 결과로 1차원 텐서가 되었다.

위의 예시처럼 텐서 안에 데이터를 가지고 있으면, 그것을 이용해 텐서를 변형(reshape)할 수 있다.
처음 예시는 (2,3)형태의 2차원 텐서를 (3,2)형태의 2차원 텐서로 변형한 것이고,
두 번째 예시는 첫번째 예시 결과에서 2번째 행만 불러오도록 슬라이싱되어 1차원 텐서가 되었다.

2. Variables

  • Variable(변수) : 프로그램이 실행될 때 값이 초기화되면서 계속 바뀌는 텐서를 말한다.
    변수는 신경망에서 bias 및 weight에 사용된다.

위의 예시 주요 설명

(1) 변수는 get_variable( )을 이용해서 생성한다.

  • tf.variable( )을 사용할 수도 있지만 tf.get_variable( )이 변수의 재사용 및 다양한 환경에 따라 만들기가 용이하기 때문
  • 예시에서는 변수는 weights란 이름을 가지고 있고, shape는 (1,2)
  • 이번 설명에서는 생략했지만 변수는 scope를 설정할 수 있다.
    (여기에서 변수를 재생성하지 않고, 재사용하는 것이 나옴)

(2) 변수를 만들 때 어떻게 초기화할지 설정해야 한다.

  • 신경망에서는 random normal 하게 초기화하는 것이 보통이다.

(3) 그래프를 형성했다면 변수를 다른 텐서처럼 사용해라.

(4) 세션에서는 모든 변수를 초기화 해라

  • global_variables_initializer() 를 호출 전의 그래프의 상태는 각 노드에 값이 아직 없는 상태를 의미
  • 따라서 해당 함수를 사용해주어야 Variable 의 값이 할당 되는 것이고 텐서의 그래프로써의 효력이 발생

(5) 모든 변수가 초기화 되고 나면, 원하는 어떤 텐서든 evaluate할 수 있다.

위의 예시에서는 train_loop를 상수인 텐서 x와 함께 불러왔다.
하지만 이게 현실적인가? 프로그램에 인풋값을 하드코딩하는가?
이 때 placeholder를 사용할 수 있다.

3. Placeholder

  • Placeholder : 텍스트 파일을 읽어들이는 것처럼 값을 그래프에 넣어주는 것

데이터를 입력받는 비어있는 변수라고 생각할 수 있다. 먼저 그래프를 구성하고, 그 그래프가 실행되는 시점에 입력 데이터를 넣어주는 데 사용한다.
(출처: https://excelsior-cjh.tistory.com/151)

tf.placeholder 함수는 입력으로 사용할 데이터의 타입만 지정해주고,
나중에 세션에서 실행될때 실제값은 딕셔너리 형태로 입력해준다.

'구글 머신러닝 스터디잼(중급) > Introduction to TensorFlow' 카테고리의 다른 글

Debugging TensorFlow Programs  (0) 2019.10.29
TensorFlow 실습 1  (0) 2019.10.28
Graph and Session  (0) 2019.10.25
TensorFlow API Hierarchy  (0) 2019.10.25
What is TensorFlow?  (0) 2019.10.25

1. Graph and Session

텐서플로에서의 DAG(Directed Acyclic Graph)는 다른 그래프들과 동일하게 edge와 node로 구성되어 있다.

(결국, 데이터와 데이터에 대한 연산으로 DAG가 이루어져있다.)

- edge : 데이터(즉, 텐서)

- node : 텐서들에 대한 연산작업

텐서플로는 그래프의 처리, 컴파일, DAG의 중간에 삽입, 수신 등을 할 수 있다.

그리고 DAG의 다른 부분을 다른 디바이스(CPU, GPU, TPU 뿐만 아니라 다른 기기까지)에 할당할 수 있다.

(다른 다바이스들간의 커뮤니케이션과 조정을 수행)

session 클래스는 우리가 사용하는 파이썬 프로그램과 C++ 런타임을 연결해주는 역할을 한다.

텐서플로 그래프를 실행하기 위해서는 세션에서 런을 호출해야 한다.
위의 예시에서는 x와 y라는 1차원 텐서를 정의했고,
z라는 텐서는 tf.add(x,y) (x와 y의 합)의 결과를 말한다.

이것을 evaluate하기 위해 z에 대해 session.run을 호출한다.
위의 예시에서는 파이썬의 with 구문을 통해 세션이 완료되었을 때 자동적으로 종료되게끔 한다.

2. Evaluating a Tensor

z를 evaluate 하기 위해서는 session.run(z) 을 기본적으로 사용하지만
z.eval() 을 사용할 수도 있다.

session.run()을 사용할 때 우리는 텐서의 리스트를 사용할 수도 있다.
그리고 tf.add(x,y) , tf.multiply(x,y) 대신 x+y 와 x*y 같이 간단히 적을 수 있다.

텐서플로에서는 원래 만들어진 그대로 lazy evaluation을 사용하는 것이 권장사항이다.
하지만 개발, 디버깅 등의 작업을 할 때는 즉시 결과를 볼 수 있는 eager mode를 사용하는 것이 편하다.

위와 같이 tf.eager를 불러온 다음 eager execution이 가능하게 하면
session.run 을 수행하지 않고도 즉시 결과를 볼 수 있다.
(하지만 강의에선 이런거는 검증이 끝나면 다시 lazy하게 돌아가는 것을 권장함)

3. Visualizing a graph

지금까지는 그래프를 어떻게 작성하고, 실행하는지 보았습니다.
그래프를 시각화한다거나, 데이터가 흘러 들어가서 어떻게 작동하는지 보고 싶을 때,
뉴런 네트워크 구조를 시각화 하고 싶을 때 어떻게 해야하는지 ??

그래프를 시각화하기 위해서는 tf.Summary.FileWriter("원하는 폴더명", sess.graph) 를 사용하면 된다.
그리고 이 때 정의한 텐서에 name을 지정해주어야 한다.
(그렇지 않으면 Ad_7 등과 같이 자동 생성된 이름이 나와서 파악하기가 어렵다.)
하지만 이것만 실행했을 때 폴더 안의 graph는 binary 값으로 되있기 때문에 우리가 읽을 수 있는 형태가 아니다.

우리가 읽기 위해서는 TensorBoard 라는 프로그램을 사용해야 한다.
아래는 텐서보드를 실행하는 코드이다.(google.datalab.ml을 따로 설치해야 실행이 되는것 같음)

위를 실행해서 나오는 페이지로 가면 아래와 같은 그림을 볼 수 있다.

※ 구글 클라우드셸에서 텐서보드 실행하는 법

'구글 머신러닝 스터디잼(중급) > Introduction to TensorFlow' 카테고리의 다른 글

TensorFlow 실습 1  (0) 2019.10.28
Tensor and Variable  (0) 2019.10.25
TensorFlow API Hierarchy  (0) 2019.10.25
What is TensorFlow?  (0) 2019.10.25
Introduction  (0) 2019.10.25

1. TensorFlow API Hierarchy

(1) Hardware 층
- 여러 하드웨어 플랫폼에서의 실행을 위한 것으로 보통 크게 다룰 일이 없음

(2) Core TensorFlow(C++ API)
- 텐서플로 기반으로 커스텀된 앱을 만들 수 있는 계층

(3) Core TensorFlow(Python API)
- 사칙연산, 행렬곱 등 수치처리를 위한 코드
- 변수 및 텐서 생성, 차원 설정 등의 작업이 이 계층에서 가능함

(4) tf.layers / tf.losses / tf.metrics :
- 커스텀 뉴럴네트워크를 만드는데 유용한 계층
→ 많은 경우에는 학습, 평가, 적용이 표준화된 방법으로 적용이 가능하기 때문에 커스텀이 필요하진 않음
- 활성함수를 통한 hidden layer 계층 만들기(tf.layers), CrossEntropy 계산 등의 작업(tf.losses),
RMSE와 같은 평가지표의 계산(tf.metrics)이 가능함

(5) tf.estimator
- 최상위 계층으로 모델을 학습시키고, 평가하고, 저장하고, 적용하는 계층

2. Lazy Evaluation

위의 예시에서처럼 a와 b를 더하라고 지정을 해주어도
이것만으로는 바로 결과가 나타나지 않는다.
이것을 세션(Session)을 통해서 실행(Run)시켜주어야 한다.
이런 특성 때문에 텐서플로는 lazy 하다고 한다.

결국 정리하자면, 텐서플로를 두 단계를 거쳐야 한다.

(1) 그래프를 그린다(정의한다)
(2) 그래프를 실행시킨다

※ tf.eager 를 통해서 lazy하지 않게 만들 수는 있지만 텐서플로에서 잘 사용되지는 않는다.

그러면 왜 이렇게 Lazy하게 만들었을까?
이런 구조를 통해서 Python에서 C++로의 전환이 최소화되고 계산이 효율적으로 수행될 수 있기 때문이다.

'구글 머신러닝 스터디잼(중급) > Introduction to TensorFlow' 카테고리의 다른 글

TensorFlow 실습 1  (0) 2019.10.28
Tensor and Variable  (0) 2019.10.25
Graph and Session  (0) 2019.10.25
What is TensorFlow?  (0) 2019.10.25
Introduction  (0) 2019.10.25

1. What is TensorFlow?

  • 텐서플로는 단순히 머신러닝에 관한 것이 아니라 수치 계산에 관한 것이다. 가령 미분을 하는데도 쓸 수 있음
  • 텐서플로가 작동하는 방식은 유향 비순환 그래프(directed acyclic graph)를 만드는 것과 같다.
    ※ 유향 비순환 그래프(directed acyclic graph, DAG, 유향 비사이클 그래프)란?

    위와 같은 순환그래프는 그래프에서 볼 수 있듯, A-> B-> C->A 의 싸이클이 발생해서 계속적으로 반복될 수 있는 상황이 발생가능

    하지만 DAG에서는 방향 순환이 없이 무한히 수많은 꼭짓점과 간선으로 구성되며 한 방향으로 나아감
    각각의 노드(node)는 덧셈,뺄셈, 곱셈 등의 간단한 연산부터, softmax 행렬곱 등의 복잡한 연산까지 수학적 연산을 나타냄
    각각의 노드를 연결하는 것은 edge라고 함
    edge는 수학적 연산의 input, output으로 데이터의 array를 나타냄

  • scalar : 3과 5같은 하나의 숫자
  • vector : 숫자들의 일차원 배열
  • maxtrix : 숫자들의 이차원 배열
  • 3D, 4D.... tensor : 숫자들의 3차원, 4차원 등등의 배열
  • tensor : 데이터의 n차원 배열(결국 텐서플로 안의 데이터가 tensor)

→ tensor(데이터)들이 directed acyclic graph를 따라 흘러가기 때문에 TensorFlow라 하는 것

 

2. Benefits of a Directed Graph

그렇다면 TensorFlow는 왜 DAG를 계산에 활용할까?
DAC는 모델의 코드에 사용한 언어와는 독립적이다. 때문에 파이썬으로 DAG를 만들어서 C++에서 구현할 수 있다.
이런 점을 이용해 high level language인 파이썬으로 프로그램을 작성해서 TensorFlow 실행엔진에 의해
작동되는 다른 플랫폼에서 실행할 수 있다.


클라우드에서 모델을 학습시킨 다음, 훈련시킨 모델을 스마트폰 등에서 예측 수행이 가능한 것이다.
(ex: 구글 번역기는 오프라인 상태에서도 예측 정도가 좀 떨어지더라도 학습된 모델로 작동이 가능하다.)

이런 장점으로 인해 텐서플로는 다양한 플랫폼에서 활용될 수 있다.

'구글 머신러닝 스터디잼(중급) > Introduction to TensorFlow' 카테고리의 다른 글

TensorFlow 실습 1  (0) 2019.10.28
Tensor and Variable  (0) 2019.10.25
Graph and Session  (0) 2019.10.25
TensorFlow API Hierarchy  (0) 2019.10.25
Introduction  (0) 2019.10.25

앞으로 배울 것

  • 게으른 평가(lazy evaluation)와 명령형(imperative) 프로그램 작성하기
  • 그래프, 세션(Session), 변수 다루기
  • 텐서플로 그래프 시각화
  • 텐서플로 프로그램 디버깅

'구글 머신러닝 스터디잼(중급) > Introduction to TensorFlow' 카테고리의 다른 글

TensorFlow 실습 1  (0) 2019.10.28
Tensor and Variable  (0) 2019.10.25
Graph and Session  (0) 2019.10.25
TensorFlow API Hierarchy  (0) 2019.10.25
What is TensorFlow?  (0) 2019.10.25

+ Recent posts