Help us understand the problem. What is going on with this article?

言語処理100本ノックで MATLAB 入門!第6章: 機械学習 55-59

はじめに

もしかして 「え?MATLAB で言語処理やるの??」と思いました・・?(5回目)

言語処理 100 本ノック 2020 で MATLAB の練習をするシリーズ。今回は 第6章: 機械学習 50-54 の続きです。

一部都合よく問題文を読み替えていますがご容赦ください。気になるところあれば是非コメントください。

実行環境

  • MATLAB R2020a (Windows 10)
  • Text Analytics Toolbox
  • Statistics and Machine Learning Toolbox

Livescript 版(MATLAB)は GitHub: NLP100-MATLAB1 に置いてあります。そしてノックを一緒にやってくれる MATLAB 芸人は引き続き募集中です!詳細は GitHub の方で。

他章へのリンク

第6章: 機械学習

本章では,Fabio Gasparetti氏が公開しているNews Aggregator Data Setを用い,ニュース記事の見出しを「ビジネス」「科学技術」「エンターテイメント」「健康」のカテゴリに分類するタスク(カテゴリ分類)に取り組む

ここでの結果は問題70でも使うとのことなので、再現性があるように乱数シードを固定しておきます。

Code
clear
rng(0)

55. 混同行列の作成

52で学習したロジスティック回帰モデルの混同行列(confusion matrix)を,学習データおよび評価データ上で作成せよ.

まずは学習データの混同行列

Code
confusionchart(YTrain, YPredTrain,...
    'ColumnSummary',"column-normalized",...
    'RowSummary',"row-normalized");

attach:cat

そして評価データの混同行列(出力を cm で確保しておき後ほど使います)

Code
cm = confusionchart(YTest, YPredTest,...
    'ColumnSummary',"column-normalized",...
    'RowSummary',"row-normalized");

attach:cat

56. 適合率,再現率,F1スコアの計測

52で学習したロジスティック回帰モデルの適合率,再現率,F1スコアを,評価データ上で計測せよ.カテゴリごとに適合率,再現率,F1スコアを求め,カテゴリごとの性能をマイクロ平均(micro-average)とマクロ平均(macro-average)で統合せよ.

それぞれの意味は下記がかなり丁寧にまとめて頂いているのでどうぞ。

【入門者向け】機械学習の分類問題評価指標解説(正解率・適合率・再現率など)

上で求めた混同行列に書き込むと以下の通り:

attach:cat

式で書くと:

  • Precision (適合率) = True Positive / (True Positive + False Positive) = 正予測の正答率
  • Recall (再現率) = True Positive / (True Positive + False Negative) = 正に対する正答率
  • F1スコア = 2 x (Precision x Recall) / (Precision + Recall)

上で確保した cm から混同行列の数値を取って、実際に計算してみます。

Code
confMat = cm.NormalizedValues;

precision = zeros(4,1);
recall = zeros(4,1);

% Precision (適合率) = True Positive / (True Positive + False Positive)
for ii = 1:size(confMat,1)
    precision(ii)=confMat(ii,ii)/sum(confMat(ii,:));
end

% Recall(再現率) = True Positive / (True Positive + False Negative)
for ii = 1:size(confMat,2)
    recall(ii)=confMat(ii,ii)/sum(confMat(:,ii));
end

% F1スコア = 2 x (Precision x Recall) / (Precision + Recall)
F1 = 2*(precision.*recall)./(precision+recall);

% table 型にまとめて表示
scores = table(precision,recall,F1);
scores.Row = ["Business", "Entertainment", "Health", "Technology"]
precision recall F1
1 Business 0.8255 0.8990 0.8607
2 Entertainment 0.9225 0.8370 0.8777
3 Health 0.7176 0.5648 0.6321
4 Technology 0.6353 0.7826 0.7013

マイクロ平均(micro-average)とマクロ平均(macro-average)

上で紹介した記事にも記載がありますが、

  • マクロ平均は単純な算術平均(相加平均):各クラスごとの正解率、適合率等の値を平均するだけ
  • マイクロ平均は混合行列全体からTP等の値集計をして算出であり

ということで、マクロ平均は

Code
macroAve = varfun(@mean,scores)
mean_precision mean_recall mean_F1
1 0.7752 0.7709 0.7679

そしてマイクロ平均は適合率,再現率,F1スコアともに同じ値になります。

Code
trace(confMat)/sum(confMat(:))
Output
ans = 0.8328

57. 特徴量の重みの確認

52で学習したロジスティック回帰モデルの中で,重みの高い特徴量トップ10と,重みの低い特徴量トップ10を確認せよ.

各クラスとそれ以外を分類する 4 つの回帰モデルの重みを確認して、各1つのクラスを他クラスから分類する特徴量トップ10を確認します。

順番は

Code
mdl.ClassNames
Output
ans = 4x1 categorical    
Business         
Entertainment    
Health           
Technology       

で確認できます。以下、それぞれそれっぽい単語が確認できます。

Business vs others

Code
coef = mdl.BinaryLearners{1}.Beta;
[~,idx] = sort(coef,'descend');
[string(coef(idx(1:10))),bag.Vocabulary(idx(1:10))']
Output
ans = 10x2 string    
"0.70628"    "low"        
"0.70023"    "bank"       
"0.69599"    "fed"        
"0.66601"    "euro"       
"0.66279"    "ecb"        
"0.66182"    "china"      
"0.65203"    "update"     
"0.5804"     "rise"       
"0.5773"     "profit"     
"0.55126"    "high"       

Entertainment vs others

Code
coef = mdl.BinaryLearners{2}.Beta;
[~,idx] = sort(coef,'descend');
[string(coef(idx(1:10))),bag.Vocabulary(idx(1:10))']
Output
ans = 10x2 string    
"0.92388"    "kardashian"    
"0.73854"    "kim"           
"0.60877"    "star"          
"0.54974"    "cyrus"         
"0.54974"    "miley"         
"0.5419"     "chris"         
"0.426"      "justin"        
"0.42167"    "bieber"        
"0.41351"    "film"          
"0.41033"    "movie"         

Health vs others

Code
coef = mdl.BinaryLearners{3}.Beta;
[~,idx] = sort(coef,'descend');
[string(coef(idx(1:10))),bag.Vocabulary(idx(1:10))']
Output
ans = 10x2 string    
"1.2932"     "ebola"       
"1.2065"     "drug"        
"1.1688"     "study"       
"1.0042"     "cancer"      
"0.8444"     "fda"         
"0.73315"    "health"      
"0.73187"    "mers"        
"0.65464"    "heart"       
"0.6027"     "outbreak"    
"0.58848"    "virus"       

Technology vs others

Code
coef = mdl.BinaryLearners{4}.Beta;
[coefsort,idx] = sort(coef,'descend');
[string(coef(idx(1:10))),bag.Vocabulary(idx(1:10))']
Output
ans = 10x2 string    
"1.377"      "google"       
"1.0362"     "apple"        
"0.94431"    "facebook"     
"0.76967"    "climate"      
"0.6199"     "microsoft"    
"0.56628"    "recall"       
"0.47101"    "car"          
"0.45785"    "tesla"        
"0.4374"     "space"        
"0.41975"    "fcc"          

58. 正則化パラメータの変更

ロジスティック回帰モデルを学習するとき,正則化パラメータを調整することで,学習時の過学習(overfitting)の度合いを制御できる.異なる正則化パラメータでロジスティック回帰モデルを学習し,学習データ,検証データ,および評価データ上の正解率を求めよ.実験の結果は,正則化パラメータを横軸,正解率を縦軸としたグラフにまとめよ.

正則化パラメータは 'Lambda' で指定するんですが、ドキュメンテーションページを確認すると他にもいろいろ選択肢がある。

引用元:templateLinear: 線形分類学習器テンプレート

'Lambda' 以外を特に何も指定しない場合のデフォルトの挙動を確認しておきます。ペナルティは 'lasso''ridge' かを 'Regularization' で指定できるけども、

Solver'sparsa' の場合、Regularization の既定値は 'lasso' になります。それ以外の場合は、既定値は 'ridge' です。

とのこと。では Solver の既定値はなんだろう・・とみると

  • 予測子データセットに 100 個以下の予測子変数が格納されている場合にリッジ ペナルティ (Regularization を参照) を指定すると、既定のソルバーは 'bfgs' になります。
  • 予測子データセットに 100 個より多い予測子変数が格納されている場合に SVM モデル (Learner を参照) とリッジ ペナルティを指定すると、既定のソルバーは 'dual' になります。
  • 予測子データセットに 100 個以下の予測子変数が格納されている場合に LASSO ペナルティを指定すると、既定のソルバーは 'sparsa' になります。

それ以外の場合、既定のソルバーは 'sgd' になります。

だそうな。

今回の場合は、変数は 5000 個ほどあり、'Learner''logistic' を使っているので、Solver は 'sgd' (確率的勾配降下法)のはず。ということで、何も指定しなければペナルティは 'ridge'。ここではあえて 'Lasso' を指定してやってみます。

Code
Nmodels = 20; % 試す正則化パラメータの数
Lambda = logspace(-8,-3,Nmodels); % 1e-8 から 1e-3 までの値で計算させてみます。

XTrain = bag.Counts;
YTrain = dataTrain2.Category;
t = templateLinear('Learner','logistic','Lambda',Lambda,'Regularization','lasso');
mdl = fitcecoc(XTrain,YTrain,'Learners',t,'Prior',"uniform",'Coding','onevsall');

それぞれのデータセットでの正解率を確認します。

Code
% 学習データでの正解率
YPred = predict(mdl,XTrain);
accTrain = sum(YPred == YTrain)/numel(YTrain);

% 検証データでの正解率
documentsValid = preprocessText(dataValid.Title);
XValid = encode(bag,documentsValid);
YValid = dataValid.Category;
YPred = predict(mdl,XValid);
accValid = sum(YPred == YValid)/numel(YValid);

% 評価データでの正解率
documentsTest = preprocessText(dataTest.Title);
XTest = encode(bag,documentsTest);
YTest = dataTest.Category;
YPred = predict(mdl,XTest);
accTest = sum(YPred == YTest)/numel(YTest);

精度をプロットします。

正則化パラメータが大きいと過学習(学習データの正解率だけが高い状態)が避けられては、、いるのかな。

Code
semilogx(Lambda,[accTrain;accValid;accTest]);
legend(["学習","検証","評価"])

attach:cat

59. ハイパーパラメータの探索

学習アルゴリズムや学習パラメータを変えながら,カテゴリ分類モデルを学習せよ.検証データ上の正解率が最も高くなる学習アルゴリズム・パラメータを求めよ.また,その学習アルゴリズム・パラメータを用いたときの評価データ上の正解率を求めよ.

ここでは引き続き線形分類モデル(ロジスティック or SVM)を使用して、fitcecoc 関数実行時に 'OptimizeHyperparameters''all' とすることでハイパーパラメータの調整を行います。

注意:ここではまず簡単にハイパーパラメータの探索をすることを目的として、学習用データを使って交差検定での正解率が最も高くなるパラメータを求めています。検証用のデータを別途与える方法はまた別途。

この設定では

  • Lambda:範囲 [1e-5/NumObservations,1e5/NumObservations] の対数スケールの正の値
  • Learner'svm' および 'logistic'
  • Regularization'ridge' および 'lasso'

に加えて上で触れた onevsall / onevsone の設定も加えたの中で最適なパラメータをベイズ最適化で求めます。(既定では反復は30回で終了)

詳細:ハイパーパラメーターの最適化

Code
XTrain = bag.Counts;
YTrain = dataTrain2.Category;
t = templateLinear();
[mdl,hyperParamResults] = fitcecoc(XTrain,YTrain,'Learners',t,...
    'Prior','uniform','OptimizeHyperparameters','all');
Output
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |       Coding |       Lambda |      Learner | Regularizati-|
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |              | on           |
|===================================================================================================================================|
|    1 | Best   |     0.22918 |     0.67376 |     0.22918 |     0.22918 |     onevsone |   2.7232e-09 |     logistic |        lasso |
|    2 | Accept |     0.66684 |     0.73746 |     0.22918 |     0.25711 |     onevsone |       2.0151 |     logistic |        ridge |
|    3 | Accept |     0.74997 |     0.49351 |     0.22918 |     0.25752 |     onevsall |      0.69029 |          svm |        lasso |
|    4 | Best   |     0.20683 |     0.64425 |     0.20683 |     0.23018 |     onevsall |   1.1286e-05 |     logistic |        ridge |
|    5 | Accept |     0.20954 |     0.67717 |     0.20683 |      0.2082 |     onevsall |   4.1207e-06 |     logistic |        ridge |
|    6 | Accept |     0.75018 |     0.57807 |     0.20683 |      0.2082 |     onevsall |       9.3614 |     logistic |        lasso |
|    7 | Accept |      0.2135 |     0.47338 |     0.20683 |      0.2082 |     onevsall |       9.3392 |          svm |        ridge |
|    8 | Accept |      0.2135 |     0.55999 |     0.20683 |      0.2082 |     onevsone |       9.3307 |          svm |        ridge |
|    9 | Accept |     0.74964 |     0.66744 |     0.20683 |      0.2082 |     onevsone |       9.3677 |          svm |        lasso |
|   10 | Accept |     0.71118 |     0.75657 |     0.20683 |     0.20693 |     onevsall |       9.3602 |     logistic |        ridge |
|   11 | Best   |     0.16669 |     0.94512 |     0.16669 |     0.16672 |     onevsall |     0.026835 |          svm |        ridge |
|   12 | Accept |     0.17545 |     0.69829 |     0.16669 |      0.1667 |     onevsall |   7.7028e-06 |          svm |        ridge |
|   13 | Accept |     0.17528 |     0.80189 |     0.16669 |      0.1668 |     onevsone |   0.00025478 |          svm |        ridge |
|   14 | Accept |      0.1897 |     0.68357 |     0.16669 |     0.16762 |     onevsone |   9.3861e-10 |          svm |        ridge |
|   15 | Accept |     0.21486 |     0.68052 |     0.16669 |     0.16734 |     onevsall |   9.4569e-10 |     logistic |        ridge |
|   16 | Accept |     0.75003 |      1.8664 |     0.16669 |     0.16698 |     onevsone |     0.068546 |     logistic |        lasso |
|   17 | Accept |     0.17159 |     0.62222 |     0.16669 |     0.16811 |     onevsall |   9.5997e-10 |          svm |        ridge |
|   18 | Accept |     0.18688 |     0.72309 |     0.16669 |     0.16835 |     onevsone |    4.424e-07 |          svm |        ridge |
|   19 | Best   |     0.13224 |     0.68209 |     0.13224 |     0.14572 |     onevsall |   0.00080424 |          svm |        ridge |
|   20 | Accept |     0.13247 |     0.67497 |     0.13224 |     0.13249 |     onevsall |   0.00070709 |          svm |        ridge |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |       Coding |       Lambda |      Learner | Regularizati-|
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |              | on           |
|===================================================================================================================================|
|   21 | Accept |     0.22322 |     0.71171 |     0.13224 |     0.13246 |     onevsone |    9.376e-10 |     logistic |        ridge |
|   22 | Accept |     0.22831 |      0.5771 |     0.13224 |     0.13244 |     onevsone |    9.379e-10 |          svm |        lasso |
|   23 | Accept |     0.22266 |     0.55386 |     0.13224 |     0.13241 |     onevsall |   9.4698e-10 |     logistic |        lasso |
|   24 | Accept |     0.16779 |     0.48329 |     0.13224 |     0.13239 |     onevsall |   9.4362e-10 |          svm |        lasso |
|   25 | Accept |     0.28698 |     0.70422 |     0.13224 |     0.13236 |     onevsone |     0.042247 |          svm |        ridge |
|   26 | Accept |     0.16689 |     0.47344 |     0.13224 |     0.13236 |     onevsall |   8.0688e-08 |          svm |        lasso |
|   27 | Accept |     0.21894 |     0.54948 |     0.13224 |     0.13236 |     onevsall |   6.6158e-07 |     logistic |        lasso |
|   28 | Accept |     0.22618 |     0.58532 |     0.13224 |     0.13236 |     onevsone |   7.2612e-07 |          svm |        lasso |
|   29 | Accept |     0.23204 |     0.76241 |     0.13224 |     0.13236 |     onevsone |   6.1776e-07 |     logistic |        ridge |
|   30 | Accept |     0.17709 |     0.64137 |     0.13224 |      0.1324 |     onevsall |   3.3226e-08 |          svm |        ridge |

attach:cat

Output
__________________________________________________________
最適化が完了しました。
MaxObjectiveEvaluations の 30 に達しました。
関数の評価回数の合計: 30
経過時間の合計: 45.3979 秒。
目的関数の評価時間の合計: 20.6819

最適な観測実行可能点:
     Coding       Lambda      Learner    Regularization
    ________    __________    _______    ______________

    onevsall    0.00080424      svm          ridge     

観測された目的関数値 = 0.13224
推定される目的関数値 = 0.13244
関数の評価時間 = 0.68209

最適な推定実行可能点 (モデルに基づく):
     Coding       Lambda      Learner    Regularization
    ________    __________    _______    ______________

    onevsall    0.00070709      svm          ridge     

推定される目的関数値 = 0.1324
推定される関数評価時間 = 0.7132

ということで、以下の最適化パラメータに落ち着きました。

Code
hyperParamResults.XAtMinEstimatedObjective
Coding Lambda Learner Regularization
1 onevsall 7.0709e-04 svm ridge

改めて学習データでの精度確認しておきます。

Code
YPredTrain = predict(mdl,XTrain);
confusionchart(YTrain, YPredTrain,...
    'ColumnSummary',"column-normalized",...
    'RowSummary',"row-normalized");

attach:cat

評価用データはこちら。

Code
YTest = dataTest.Category;
documents = preprocessText(dataTest.Title);
XTest = encode(bag,documents);
YPredTest = predict(mdl,XTest);
confusionchart(YTest, YPredTest,...
    'ColumnSummary',"column-normalized",...
    'RowSummary',"row-normalized");

attach:cat

さすがに正解率は高いですね。

興味深いのは、Precision (適合率) と Recall(再現率)の差が大きい 'Health'。適合率、すなわち 'Health' 記事であれば 'Health' と高い確率で予想するが、それ以外のものを 'Heath' と(比較的)勘違いしやすい状態。学習データ数の偏りに対して 'Uniform' と指定して学習させたことが影響していそう。

ヘルパー関数

文章の前処理用関数。以下の処理をしています。

  1. tokenizedDocument で入力文をトークン化(単語に分割)
  2. removeStopWords でand, of, the などのストップワードを削除
  3. normalizeWords で各単語を原型(?)にそろえる
  4. erasePunctuation でコンマやピリオドなど削除
  5. removeShortWords で 2文字以下の単語を削除
  6. removeLongWords で 15 文字以上の単語を削除
Code
function documents = preprocessText(textData)

    % Tokenize the text.
    documents = tokenizedDocument(textData);

    % Remove a list of stop words then lemmatize the words. To improve
    % lemmatization, first use addPartOfSpeechDetails.
    documents = addPartOfSpeechDetails(documents);
    documents = removeStopWords(documents);
    documents = normalizeWords(documents,'Style','lemma');

    % Erase punctuation.
    documents = erasePunctuation(documents);

    % Remove words with 2 or fewer characters, and words with 15 or more
    % characters.
    documents = removeShortWords(documents,2);
    documents = removeLongWords(documents,15);

end

文章を入力して用意したモデルで予測結果(スコア)を返す関数

Code
function [class, score] = predictClassFromTitle(mdl,bag,title)
    documents = preprocessText(title);
    X = encode(bag,documents);
    [class,~,~,posterior] = predict(mdl,X);
    % class は予想されるカテゴリ
    % posterior は各クラスに分類される確率
    % なので最大値が "class" に予測される確率
    score = max(posterior,[],2); 
end

  1. Livescript から markdown への変換は livescript2markdown​: MATLAB's live scripts to markdown を使っています。 

eigs
MATLAB の中の人. 公式ブログも書いています. All comments and opinions expressed are mine alone and do not necessarily reflect those of my employers, past or present.
https://blogs.mathworks.com/japan-community/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした