Introduction
I tried CatBoost, a machine learning library developed by Yandex (often called "the Google of Russia"), from R.
Most of what follows is based on the official website.
Environment
Windows 10 64-bit
R-3.4.2
Installation
Run the following commands in R. (For the latest release file, check the official GitHub repository.)
install.packages('devtools')
devtools::install_url('https://github.com/catboost/catboost/releases/download/v0.8.1/catboost-R-Windows-0.8.1.tgz', args = c("--no-multiarch"))
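The URL above is pinned to v0.8.1. If you want a newer release, the hypothetical helper below rebuilds the URL from a version string. It **assumes** that later releases keep the same file-naming pattern as v0.8.1 (`catboost-R-Windows-<version>.tgz`). That is not guaranteed, so compare the result against the actual GitHub releases page before installing.

```r
# Hypothetical helper: builds the Windows release URL for a given CatBoost version,
# ASSUMING later releases keep the same file-naming pattern as v0.8.1.
catboost_windows_url <- function(version) {
  sprintf(
    "https://github.com/catboost/catboost/releases/download/v%s/catboost-R-Windows-%s.tgz",
    version, version
  )
}

url <- catboost_windows_url("0.8.1")
print(url)

# Then install the same way as above:
# devtools::install_url(url, args = c("--no-multiarch"))
```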
Test run
Using the Adult Data Set that ships with the CatBoost R package, I went through training a model and applying it.
catBoostQuickStart.R
library(catboost)
# Load the dataset
pool_path <- system.file("extdata",
                         "adult_train.1000",
                         package = "catboost")
cd_path <- system.file("extdata",
                       "adult.cd",
                       package = "catboost")
pool <- catboost.load_pool(pool_path, column_description = cd_path)
# Train the model
fit_params <- list(iterations = 100,
                   thread_count = 10,
                   loss_function = 'Logloss')
# The second argument is the evaluation (test) pool. Reusing the training pool here
# means the "test" values in the log are not a holdout score.
model <- catboost.train(pool, pool, fit_params)
# Apply the model (returns raw scores by default)
prediction <- catboost.predict(model, pool)
# Inspect the results
head(prediction)
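A note on the last step: by default `catboost.predict` returns raw model scores (`prediction_type = 'RawFormulaVal'`). With `Logloss` these are log-odds, not probabilities. You can request probabilities directly with `prediction_type = 'Probability'`. The base-R sketch below shows the equivalent logistic (sigmoid) conversion. The scores are made-up values for illustration, not output from the script above.

```r
# Logistic (sigmoid) function: maps raw log-odds scores to probabilities in (0, 1).
sigmoid <- function(x) 1 / (1 + exp(-x))

# Made-up raw scores for illustration (not actual output of catBoostQuickStart.R).
raw_scores <- c(-2.0, -0.5, 0.0, 1.5)

prob <- sigmoid(raw_scores)
print(round(prob, 3))

# Threshold at 0.5 to get class labels (equivalent to raw score > 0).
pred_class <- as.integer(prob > 0.5)
print(pred_class)

# With the model from the script, the probabilities should be available via:
# catboost.predict(model, pool, prediction_type = 'Probability')
```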
Output (training log of catBoostQuickStart.R)
0: learn: 0.6694262 test: 0.6693275 best: 0.6693275 (0) total: 74.9ms remaining: 7.42s
1: learn: 0.6546126 test: 0.6543231 best: 0.6543231 (1) total: 84ms remaining: 4.12s
2: learn: 0.6353726 test: 0.6350129 best: 0.6350129 (2) total: 103ms remaining: 3.33s
3: learn: 0.6198166 test: 0.6192209 best: 0.6192209 (3) total: 117ms remaining: 2.81s
4: learn: 0.6026045 test: 0.6013479 best: 0.6013479 (4) total: 142ms remaining: 2.69s
5: learn: 0.5873222 test: 0.5857090 best: 0.5857090 (5) total: 164ms remaining: 2.56s
6: learn: 0.5752527 test: 0.5736026 best: 0.5736026 (6) total: 185ms remaining: 2.46s
7: learn: 0.5651169 test: 0.5627786 best: 0.5627786 (7) total: 207ms remaining: 2.38s
8: learn: 0.5549792 test: 0.5526276 best: 0.5526276 (8) total: 222ms remaining: 2.24s
9: learn: 0.5436931 test: 0.5413466 best: 0.5413466 (9) total: 243ms remaining: 2.19s
10: learn: 0.5317787 test: 0.5295916 best: 0.5295916 (10) total: 264ms remaining: 2.14s
11: learn: 0.5216896 test: 0.5194055 best: 0.5194055 (11) total: 286ms remaining: 2.1s
12: learn: 0.5119869 test: 0.5094181 best: 0.5094181 (12) total: 305ms remaining: 2.04s
13: learn: 0.5038046 test: 0.5013013 best: 0.5013013 (13) total: 327ms remaining: 2.01s
14: learn: 0.4966602 test: 0.4942511 best: 0.4942511 (14) total: 348ms remaining: 1.97s
15: learn: 0.4889777 test: 0.4859823 best: 0.4859823 (15) total: 368ms remaining: 1.93s
16: learn: 0.4820448 test: 0.4789994 best: 0.4789994 (16) total: 385ms remaining: 1.88s
17: learn: 0.4757739 test: 0.4723728 best: 0.4723728 (17) total: 405ms remaining: 1.85s
18: learn: 0.4685700 test: 0.4647793 best: 0.4647793 (18) total: 422ms remaining: 1.8s
19: learn: 0.4628205 test: 0.4590794 best: 0.4590794 (19) total: 444ms remaining: 1.78s
20: learn: 0.4577653 test: 0.4538699 best: 0.4538699 (20) total: 465ms remaining: 1.75s
21: learn: 0.4538636 test: 0.4497996 best: 0.4497996 (21) total: 485ms remaining: 1.72s
22: learn: 0.4494997 test: 0.4452677 best: 0.4452677 (22) total: 504ms remaining: 1.69s
23: learn: 0.4458840 test: 0.4414506 best: 0.4414506 (23) total: 523ms remaining: 1.66s
24: learn: 0.4404981 test: 0.4361580 best: 0.4361580 (24) total: 543ms remaining: 1.63s
25: learn: 0.4360685 test: 0.4317012 best: 0.4317012 (25) total: 561ms remaining: 1.6s
26: learn: 0.4315717 test: 0.4271391 best: 0.4271391 (26) total: 578ms remaining: 1.56s
27: learn: 0.4269167 test: 0.4226632 best: 0.4226632 (27) total: 597ms remaining: 1.53s
28: learn: 0.4226198 test: 0.4185094 best: 0.4185094 (28) total: 618ms remaining: 1.51s
29: learn: 0.4197275 test: 0.4155232 best: 0.4155232 (29) total: 639ms remaining: 1.49s
30: learn: 0.4166405 test: 0.4124730 best: 0.4124730 (30) total: 652ms remaining: 1.45s
31: learn: 0.4136576 test: 0.4096500 best: 0.4096500 (31) total: 670ms remaining: 1.42s
32: learn: 0.4105338 test: 0.4066999 best: 0.4066999 (32) total: 690ms remaining: 1.4s
33: learn: 0.4073046 test: 0.4037666 best: 0.4037666 (33) total: 711ms remaining: 1.38s
34: learn: 0.4042087 test: 0.4010269 best: 0.4010269 (34) total: 728ms remaining: 1.35s
35: learn: 0.4018034 test: 0.3985344 best: 0.3985344 (35) total: 749ms remaining: 1.33s
36: learn: 0.3996649 test: 0.3967903 best: 0.3967903 (36) total: 767ms remaining: 1.31s
37: learn: 0.3978337 test: 0.3952699 best: 0.3952699 (37) total: 785ms remaining: 1.28s
38: learn: 0.3963014 test: 0.3937572 best: 0.3937572 (38) total: 802ms remaining: 1.25s
39: learn: 0.3932189 test: 0.3906877 best: 0.3906877 (39) total: 823ms remaining: 1.23s
40: learn: 0.3916126 test: 0.3890991 best: 0.3890991 (40) total: 844ms remaining: 1.21s
41: learn: 0.3890241 test: 0.3865875 best: 0.3865875 (41) total: 863ms remaining: 1.19s
42: learn: 0.3868832 test: 0.3845794 best: 0.3845794 (42) total: 885ms remaining: 1.17s
43: learn: 0.3845722 test: 0.3822688 best: 0.3822688 (43) total: 905ms remaining: 1.15s
44: learn: 0.3827924 test: 0.3805462 best: 0.3805462 (44) total: 919ms remaining: 1.12s
45: learn: 0.3815782 test: 0.3793895 best: 0.3793895 (45) total: 935ms remaining: 1.1s
46: learn: 0.3798771 test: 0.3779259 best: 0.3779259 (46) total: 953ms remaining: 1.07s
47: learn: 0.3784449 test: 0.3764749 best: 0.3764749 (47) total: 974ms remaining: 1.05s
48: learn: 0.3779165 test: 0.3759414 best: 0.3759414 (48) total: 981ms remaining: 1.02s
49: learn: 0.3772907 test: 0.3753137 best: 0.3753137 (49) total: 990ms remaining: 990ms
50: learn: 0.3758383 test: 0.3738934 best: 0.3738934 (50) total: 1.01s remaining: 971ms
51: learn: 0.3749740 test: 0.3730539 best: 0.3730539 (51) total: 1.02s remaining: 944ms
52: learn: 0.3737392 test: 0.3718864 best: 0.3718864 (52) total: 1.04s remaining: 921ms
53: learn: 0.3733902 test: 0.3714978 best: 0.3714978 (53) total: 1.05s remaining: 894ms
54: learn: 0.3720477 test: 0.3700849 best: 0.3700849 (54) total: 1.07s remaining: 875ms
55: learn: 0.3704381 test: 0.3686721 best: 0.3686721 (55) total: 1.09s remaining: 858ms
56: learn: 0.3689810 test: 0.3673272 best: 0.3673272 (56) total: 1.11s remaining: 840ms
57: learn: 0.3677426 test: 0.3661167 best: 0.3661167 (57) total: 1.13s remaining: 818ms
58: learn: 0.3661773 test: 0.3646030 best: 0.3646030 (58) total: 1.15s remaining: 800ms
59: learn: 0.3650990 test: 0.3635791 best: 0.3635791 (59) total: 1.17s remaining: 777ms
60: learn: 0.3639894 test: 0.3626673 best: 0.3626673 (60) total: 1.19s remaining: 758ms
61: learn: 0.3633682 test: 0.3621720 best: 0.3621720 (61) total: 1.2s remaining: 737ms
62: learn: 0.3629123 test: 0.3616917 best: 0.3616917 (62) total: 1.21s remaining: 713ms
63: learn: 0.3617605 test: 0.3606652 best: 0.3606652 (63) total: 1.24s remaining: 695ms
64: learn: 0.3614998 test: 0.3604091 best: 0.3604091 (64) total: 1.24s remaining: 670ms
65: learn: 0.3608548 test: 0.3597880 best: 0.3597880 (65) total: 1.25s remaining: 646ms
66: learn: 0.3603772 test: 0.3593986 best: 0.3593986 (66) total: 1.27s remaining: 625ms
67: learn: 0.3595835 test: 0.3586500 best: 0.3586500 (67) total: 1.29s remaining: 608ms
68: learn: 0.3583888 test: 0.3573487 best: 0.3573487 (68) total: 1.31s remaining: 590ms
69: learn: 0.3574504 test: 0.3564566 best: 0.3564566 (69) total: 1.33s remaining: 570ms
70: learn: 0.3558268 test: 0.3548918 best: 0.3548918 (70) total: 1.35s remaining: 553ms
71: learn: 0.3548097 test: 0.3538881 best: 0.3538881 (71) total: 1.37s remaining: 534ms
72: learn: 0.3532903 test: 0.3523605 best: 0.3523605 (72) total: 1.39s remaining: 515ms
73: learn: 0.3520095 test: 0.3512613 best: 0.3512613 (73) total: 1.42s remaining: 497ms
74: learn: 0.3512087 test: 0.3504542 best: 0.3504542 (74) total: 1.43s remaining: 478ms
75: learn: 0.3503032 test: 0.3496401 best: 0.3496401 (75) total: 1.45s remaining: 459ms
76: learn: 0.3490388 test: 0.3485768 best: 0.3485768 (76) total: 1.47s remaining: 440ms
77: learn: 0.3481170 test: 0.3476758 best: 0.3476758 (77) total: 1.49s remaining: 421ms
78: learn: 0.3473318 test: 0.3469230 best: 0.3469230 (78) total: 1.51s remaining: 402ms
79: learn: 0.3470634 test: 0.3466603 best: 0.3466603 (79) total: 1.52s remaining: 380ms
80: learn: 0.3449355 test: 0.3453018 best: 0.3453018 (80) total: 1.54s remaining: 361ms
81: learn: 0.3437443 test: 0.3446793 best: 0.3446793 (81) total: 1.56s remaining: 343ms
82: learn: 0.3431200 test: 0.3445518 best: 0.3445518 (82) total: 1.58s remaining: 324ms
83: learn: 0.3420508 test: 0.3437707 best: 0.3437707 (83) total: 1.6s remaining: 306ms
84: learn: 0.3418638 test: 0.3435859 best: 0.3435859 (84) total: 1.61s remaining: 285ms
85: learn: 0.3410036 test: 0.3431157 best: 0.3431157 (85) total: 1.63s remaining: 266ms
86: learn: 0.3401323 test: 0.3423150 best: 0.3423150 (86) total: 1.66s remaining: 247ms
87: learn: 0.3392601 test: 0.3414012 best: 0.3414012 (87) total: 1.67s remaining: 228ms
88: learn: 0.3383800 test: 0.3405838 best: 0.3405838 (88) total: 1.69s remaining: 209ms
89: learn: 0.3373359 test: 0.3395540 best: 0.3395540 (89) total: 1.71s remaining: 190ms
90: learn: 0.3363756 test: 0.3387289 best: 0.3387289 (90) total: 1.73s remaining: 171ms
91: learn: 0.3358901 test: 0.3383092 best: 0.3383092 (91) total: 1.75s remaining: 152ms
92: learn: 0.3357097 test: 0.3381380 best: 0.3381380 (92) total: 1.76s remaining: 132ms
93: learn: 0.3354169 test: 0.3378953 best: 0.3378953 (93) total: 1.78s remaining: 114ms
94: learn: 0.3347047 test: 0.3369818 best: 0.3369818 (94) total: 1.8s remaining: 94.7ms
95: learn: 0.3339866 test: 0.3362815 best: 0.3362815 (95) total: 1.82s remaining: 75.9ms
96: learn: 0.3328715 test: 0.3351374 best: 0.3351374 (96) total: 1.84s remaining: 56.8ms
97: learn: 0.3320110 test: 0.3343723 best: 0.3343723 (97) total: 1.86s remaining: 37.9ms
98: learn: 0.3313415 test: 0.3340605 best: 0.3340605 (98) total: 1.88s remaining: 19ms
99: learn: 0.3309613 test: 0.3336766 best: 0.3336766 (99) total: 1.9s remaining: 0us
bestTest = 0.3336765569
bestIteration = 99
Shrink model to first 100 iterations.
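How to read the log: the `learn`/`test` columns are values of the `Logloss` metric chosen in `fit_params`. This is the mean binary cross-entropy, -(1/N) Σ [y log p + (1 − y) log(1 − p)], where p is the sigmoid of the raw score. Lower is better. A constant prediction of p = 0.5 scores log 2 ≈ 0.693, and the log starts just below that. `bestIteration` is 0-based, so `bestIteration = 99` means the best test score came at the 100th tree. As far as I understand, when an evaluation set is given CatBoost keeps the model only up to the best iteration by default, which is where the "Shrink model to first 100 iterations" line comes from. The sketch below illustrates both points in base R. The labels and scores are made up; only the last three test values are copied from the log above.

```r
# Binary Logloss (cross-entropy), computed on probabilities p = sigmoid(raw score).
sigmoid <- function(x) 1 / (1 + exp(-x))

logloss <- function(y, raw_score) {
  p <- sigmoid(raw_score)
  -mean(y * log(p) + (1 - y) * log(1 - p))
}

# Made-up labels for illustration (not taken from the Adult data).
y <- c(1, 0, 1, 0)

# A model that always outputs raw score 0 (p = 0.5) scores log(2), about 0.693.
print(logloss(y, c(0, 0, 0, 0)))

# Scores pointing in the right direction lower the loss.
print(logloss(y, c(2, -2, 1, -1)))

# bestIteration is 0-based. Test values for iterations 97, 98, 99 from the log:
test_loss <- c(0.3343723, 0.3340605, 0.3336766)
best_iteration <- 97 + which.min(test_loss) - 1
print(best_iteration)      # 99
print(best_iteration + 1)  # trees kept after shrinking: 100
```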