Skip to content

二値分類器を作る

このセクションでは、線形分類不可能なデータセットであるmoons(詳細はこちら)を分類してみます。

分類器として以下4つを試して、各手法ごとの比較をします。

  1. ロジスティック回帰
  2. ランダムフォレスト
  3. RBFカーネルのSVM
  4. 線形カーネルのSVM

moonsは線形分離不可能な二値分類問題ですから、手法としては2と3が有利そうです。 これを実験により確認します。

config.py を作成した後にtrain,inspectを実行してみます。

import copy
train_config_template = {
    'dataset_config': {
        'loader_config': {
            'name': 'binary_classifier_sample_moon',
            'kwargs': {
                'n_samples': 1000,
                'noise': 0.1,
                'random_state': 0,
            },
        },
    },
    'model_config': {
        'evaluate_kwargs': {
            'cross_val_iterator': 'KFold@sklearn.model_selection',
            'cross_val_iterator_kwargs': {
                'shuffle': True,
                'random_state': 0,
                'n_splits': 10, 
            },  
        },
        'pos_index': 1,
    },
    'evaluate_enabled': True,
    'fit_model_enabled': True,
    'dump_result_enabled': True,
}

train_config_1 = copy.deepcopy(train_config_template)
train_config_2 = copy.deepcopy(train_config_template)
train_config_3 = copy.deepcopy(train_config_template)
train_config_4 = copy.deepcopy(train_config_template)
train_config_1['model_config']['name'] = 'SklearnLogisticRegression'
train_config_1['model_config']['init_kwargs'] = {'random_state': 0}
train_config_2['model_config']['name'] = 'SklearnRandomForestClassifier'
train_config_2['model_config']['init_kwargs'] = {'random_state': 0}
train_config_3['model_config']['name'] = 'SklearnSVC'
train_config_3['model_config']['init_kwargs'] = {'kernel': 'rbf', 'probability': True, 'random_state': 0}
train_config_4['model_config']['name'] = 'SklearnSVC'
train_config_4['model_config']['init_kwargs'] = {'kernel': 'linear', 'probability': True, 'random_state': 0}

train_config = [train_config_1, train_config_2, train_config_3, train_config_4]
$ akebono train
$ akebono inspect

### 実行結果
### trainのログは省略

=== scenario summary .. tag: default ===

------------------------------------------------------------
train_id: 0

accuracy  f1_score  log_loss  precision recall    roc_auc  
0.87400   0.87192   4.35194   0.87180   0.87369   0.96103  

------------------------------------------------------------
train_id: 1

accuracy  f1_score  log_loss  precision recall    roc_auc  
0.99400   0.99382   0.20723   0.99583   0.99194   0.99885  

------------------------------------------------------------
train_id: 2

accuracy  f1_score  log_loss  precision recall    roc_auc  
0.99200   0.99168   0.27632   0.98803   0.99555   0.99992  

------------------------------------------------------------
train_id: 3

accuracy  f1_score  log_loss  precision recall    roc_auc  
0.87500   0.87274   4.31740   0.87163   0.87510   0.96118

このように、概ね予想通りであることが確認できました。