2-2. 사이킷런의 기반 프레임워크

1. Estimator 클래스 및 fit(), predict() 메서드

기본 Estimator 클래스 : Classifier와 Regressor로 나뉨

각각의 Estimator는 내부에서 fit()과 predict()를 내부에서 구현

  • fit() : 모델 학습
  • predict() : 학습된 모델의 예측
  • transform() : 입력된 데이터의 형태에 맞추어 데이터를 변환

cross_val_score() (evaluation 함수)나 GridSearchCV (하이퍼파라미터 튜닝) 같은 클래스의 경우 Estimator를 인자로 받고 cross_val_score(), GridSearchCV.fit() 함수 내에서 fit과 predict를 호출해서 평가하거나 튜닝을 수행

2. 사이킷런의 주요 모듈

In [1]:
import pandas as pd
pd.read_csv('scikit-learn의 모듈.csv')
Out[1]:
분류 모듈명 설명
0 예제 데이터 sklearn.datasets 사이킷런 내장 예제 데이터 세트
1 피처 처리 sklearn.preprocessing 데이터 전처리 필요한 다양한 가공기능(인코딩, 정규화, 스케일링 등)
2 피처 처리 sklearn.feature_selection 알고리즘에 큰 영향을 미치는 피처를 우선순위대로 셀렉션하는 기능 제공
3 피처 처리 sklearn.feature_extraction 텍스트, 이미지 데이터의 벡터화된 피처를 추출하는데 사용
4 차원축소 sklearn.decomposition 차원 축소와 관련된 알고리즘을 지원(PCA, NMF, Truncated SVD 등)
5 모델 선택 sklearn.model_selection 훈련, 테스트 데이터 분리, 그리드 서치 등의 기능 제공
6 평가 sklearn.metrics 다양한 모델의 성능평가 측정방법 제공(Accuracy, ROC-AUC, RMSE 등
7 알고리즘 sklearn.ensemble 앙상블 알고리즘 제공(RandomForest, AdaBoost, Gradient B...
8 알고리즘 sklearn.linear_model 회귀 관련 알고리즘 제공(linear, Ridge, Lasso, Logistic 등)
9 알고리즘 sklearn.naive_bayes 나이브 베이즈 알고리즘 제공(Gaussian NB emd)
10 알고리즘 sklearn.neighbors 최근접이웃 알고리즘 제공(K-NN 등)
11 알고리즘 sklearn.svm 서포트 벡터 머신 알고리즘
12 알고리즘 sklearn.tree 의사결정 트리 알고리즘 제공
13 알고리즘 sklearn.cluster 비지도 클러스터링 알고리즘 제공
14 유틸리티 sklearn.pipeline 피처처리 등의 변환과 ML 알고리즘 학습, 예측 등을 함께 묶어서 실행하는 유틸리티 제공

3. 내장된 예제 데이터 세트 살펴보기

In [2]:
from sklearn.datasets import load_iris

iris_data = load_iris()
print(type(iris_data))
<class 'sklearn.utils.Bunch'>
In [3]:
keys = iris_data.keys()
print('붓꽃 데이터 세트의 키들: ', keys)
붓꽃 데이터 세트의 키들:  dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
In [4]:
print('feature_names의 type: ', type(iris_data.feature_names))
print('feature_names의 shape: ', len(iris_data.feature_names))
print('feature_names: ', iris_data.feature_names)

print('\ntarget_names의 type: ', type(iris_data.target_names))
print('target_names의 shape: ', len(iris_data.target_names))
print('target_names: ', iris_data.target_names)

print('\ndata의 type: ', type(iris_data.data))
print('data의 shape: ', iris_data.data.shape)
print('data: \n', iris_data['data'])

print('\ntarget의 type: ', type(iris_data.target))
print('target의 shape: ', iris_data.target.shape)
print('target: \n', iris_data['target'])
feature_names의 type:  <class 'list'>
feature_names의 shape:  4
feature_names:  ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

target_names의 type:  <class 'numpy.ndarray'>
target_names의 shape:  3
target_names:  ['setosa' 'versicolor' 'virginica']

data의 type:  <class 'numpy.ndarray'>
data의 shape:  (150, 4)
data: 
 [[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]
 [5.4 3.9 1.7 0.4]
 [4.6 3.4 1.4 0.3]
 [5.  3.4 1.5 0.2]
 [4.4 2.9 1.4 0.2]
 [4.9 3.1 1.5 0.1]
 [5.4 3.7 1.5 0.2]
 [4.8 3.4 1.6 0.2]
 [4.8 3.  1.4 0.1]
 [4.3 3.  1.1 0.1]
 [5.8 4.  1.2 0.2]
 [5.7 4.4 1.5 0.4]
 [5.4 3.9 1.3 0.4]
 [5.1 3.5 1.4 0.3]
 [5.7 3.8 1.7 0.3]
 [5.1 3.8 1.5 0.3]
 [5.4 3.4 1.7 0.2]
 [5.1 3.7 1.5 0.4]
 [4.6 3.6 1.  0.2]
 [5.1 3.3 1.7 0.5]
 [4.8 3.4 1.9 0.2]
 [5.  3.  1.6 0.2]
 [5.  3.4 1.6 0.4]
 [5.2 3.5 1.5 0.2]
 [5.2 3.4 1.4 0.2]
 [4.7 3.2 1.6 0.2]
 [4.8 3.1 1.6 0.2]
 [5.4 3.4 1.5 0.4]
 [5.2 4.1 1.5 0.1]
 [5.5 4.2 1.4 0.2]
 [4.9 3.1 1.5 0.2]
 [5.  3.2 1.2 0.2]
 [5.5 3.5 1.3 0.2]
 [4.9 3.6 1.4 0.1]
 [4.4 3.  1.3 0.2]
 [5.1 3.4 1.5 0.2]
 [5.  3.5 1.3 0.3]
 [4.5 2.3 1.3 0.3]
 [4.4 3.2 1.3 0.2]
 [5.  3.5 1.6 0.6]
 [5.1 3.8 1.9 0.4]
 [4.8 3.  1.4 0.3]
 [5.1 3.8 1.6 0.2]
 [4.6 3.2 1.4 0.2]
 [5.3 3.7 1.5 0.2]
 [5.  3.3 1.4 0.2]
 [7.  3.2 4.7 1.4]
 [6.4 3.2 4.5 1.5]
 [6.9 3.1 4.9 1.5]
 [5.5 2.3 4.  1.3]
 [6.5 2.8 4.6 1.5]
 [5.7 2.8 4.5 1.3]
 [6.3 3.3 4.7 1.6]
 [4.9 2.4 3.3 1. ]
 [6.6 2.9 4.6 1.3]
 [5.2 2.7 3.9 1.4]
 [5.  2.  3.5 1. ]
 [5.9 3.  4.2 1.5]
 [6.  2.2 4.  1. ]
 [6.1 2.9 4.7 1.4]
 [5.6 2.9 3.6 1.3]
 [6.7 3.1 4.4 1.4]
 [5.6 3.  4.5 1.5]
 [5.8 2.7 4.1 1. ]
 [6.2 2.2 4.5 1.5]
 [5.6 2.5 3.9 1.1]
 [5.9 3.2 4.8 1.8]
 [6.1 2.8 4.  1.3]
 [6.3 2.5 4.9 1.5]
 [6.1 2.8 4.7 1.2]
 [6.4 2.9 4.3 1.3]
 [6.6 3.  4.4 1.4]
 [6.8 2.8 4.8 1.4]
 [6.7 3.  5.  1.7]
 [6.  2.9 4.5 1.5]
 [5.7 2.6 3.5 1. ]
 [5.5 2.4 3.8 1.1]
 [5.5 2.4 3.7 1. ]
 [5.8 2.7 3.9 1.2]
 [6.  2.7 5.1 1.6]
 [5.4 3.  4.5 1.5]
 [6.  3.4 4.5 1.6]
 [6.7 3.1 4.7 1.5]
 [6.3 2.3 4.4 1.3]
 [5.6 3.  4.1 1.3]
 [5.5 2.5 4.  1.3]
 [5.5 2.6 4.4 1.2]
 [6.1 3.  4.6 1.4]
 [5.8 2.6 4.  1.2]
 [5.  2.3 3.3 1. ]
 [5.6 2.7 4.2 1.3]
 [5.7 3.  4.2 1.2]
 [5.7 2.9 4.2 1.3]
 [6.2 2.9 4.3 1.3]
 [5.1 2.5 3.  1.1]
 [5.7 2.8 4.1 1.3]
 [6.3 3.3 6.  2.5]
 [5.8 2.7 5.1 1.9]
 [7.1 3.  5.9 2.1]
 [6.3 2.9 5.6 1.8]
 [6.5 3.  5.8 2.2]
 [7.6 3.  6.6 2.1]
 [4.9 2.5 4.5 1.7]
 [7.3 2.9 6.3 1.8]
 [6.7 2.5 5.8 1.8]
 [7.2 3.6 6.1 2.5]
 [6.5 3.2 5.1 2. ]
 [6.4 2.7 5.3 1.9]
 [6.8 3.  5.5 2.1]
 [5.7 2.5 5.  2. ]
 [5.8 2.8 5.1 2.4]
 [6.4 3.2 5.3 2.3]
 [6.5 3.  5.5 1.8]
 [7.7 3.8 6.7 2.2]
 [7.7 2.6 6.9 2.3]
 [6.  2.2 5.  1.5]
 [6.9 3.2 5.7 2.3]
 [5.6 2.8 4.9 2. ]
 [7.7 2.8 6.7 2. ]
 [6.3 2.7 4.9 1.8]
 [6.7 3.3 5.7 2.1]
 [7.2 3.2 6.  1.8]
 [6.2 2.8 4.8 1.8]
 [6.1 3.  4.9 1.8]
 [6.4 2.8 5.6 2.1]
 [7.2 3.  5.8 1.6]
 [7.4 2.8 6.1 1.9]
 [7.9 3.8 6.4 2. ]
 [6.4 2.8 5.6 2.2]
 [6.3 2.8 5.1 1.5]
 [6.1 2.6 5.6 1.4]
 [7.7 3.  6.1 2.3]
 [6.3 3.4 5.6 2.4]
 [6.4 3.1 5.5 1.8]
 [6.  3.  4.8 1.8]
 [6.9 3.1 5.4 2.1]
 [6.7 3.1 5.6 2.4]
 [6.9 3.1 5.1 2.3]
 [5.8 2.7 5.1 1.9]
 [6.8 3.2 5.9 2.3]
 [6.7 3.3 5.7 2.5]
 [6.7 3.  5.2 2.3]
 [6.3 2.5 5.  1.9]
 [6.5 3.  5.2 2. ]
 [6.2 3.4 5.4 2.3]
 [5.9 3.  5.1 1.8]]

target의 type:  <class 'numpy.ndarray'>
target의 shape:  (150,)
target: 
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]

+ Recent posts