3
3

顧客離れの予測 AWS SageMaker CLI版(1/2)

Last updated at Posted at 2024-03-08

顧客を失うことは、どのようなビジネスにとってもコストがかかります。不満を抱いている顧客を早期に特定すると、顧客に継続するインセンティブを提供する 機会 が得られます。

顧客を失うことはビジネスでは高く付きます。満足度の低い顧客を早い段階で特定することで、利用継続のインセンティブを与えられる可能性があります。

これでは、ピンとこない人に数字で説明します。

1対5の法則とは、新規顧客に販売するコストは既存顧客に販売するコストの5倍かかるという法則です。

つまり、同額の売り上げを達成する場合でも、新規顧客に販売する方がコストが高くなり、利益率が低くなるということです。

1対5.PNG

新規顧客(ターゲット顧客)に当社の商品やサービスを買ってもらうには、まずは知ってもらわなければいけません。そのためには、広告を出したり、割引率を大きくしたり、無料サンプルを配布したり、無料体験サービスなどを実施する必要があるからです。

AIDMAの法則によると、消費者が物を購入するときには 「注意→興味→欲求→記憶→購入」 の流れを無意識に行っているとされています。

それに対して、既存顧客であれば、割引利率も少なくリピート購入をしていただけるケースが多いですし、「これまでの付き合いもあるのでなんとなく」当社の商品・サービスを選択してもらいやすいといえます。このようなシチュエーションは皆様も心当たりがあると思います。

消費者は、現在使っている商品やサービスから、他社の商品やサービスに乗り換えることは、心理的に大きな負担を感じると言われています。これは 「スイッチングコスト」 と呼ばれるものです。

「スイッチングコスト」 とは、現在利用している製品やサービスから、別の製品やサービスに乗り換える際に負担する金銭的、心理的、手間などのコストのことである。 例えば、携帯電話の通信キャリアを変更する際にかかる費用 もこれに含まれます。

新規顧客を獲得するには、対象となる新規顧客見込み先の 「スイッチングコスト」 のハードルを下げることも必要なために、多大な労力や投資が必要になってくるものです。

525.PNG

具体的な例を挙げますと、売上1,000万円の会社が毎年約半数の新規客を失い、それを補うために新規開拓が必要となります。粗利率が30%で粗利が300万円、営業経費が全て固定費として240万円の場合、営業利益は60万円で売上高の6%になります。 営業利益5%以上は、一つの目安です。

もしも売上の5%すなわち50万円の顧客流出が防止できたら 、それだけ売上が増えることになり、それに伴う増加する粗利は売上50万円×30%=15万円です。営業利益60万円の25%に相当しますね。

この場合の利益は、あまり厳密に考えなくても良いと思われますが、例えば営業利益の 25%程が改善される といったイメージで良いと思います。

これら二つの法則から言えることは、新規顧客の獲得よりも、 既存顧客の維持 に目を向けた方が、より効率的に事業成長を遂げることができるということです。

AWS SageMakerで、解決していきたいと思います。

機械学習 (ML) を用いて満足度の低い顧客を自動的に特定する方法
顧客のチャーン予測 (customer churn prediction) とも呼ばれます を説明します。

我々にとって身近なチャーン、携帯電話事業者を解約する例を用いることにします
不満ならいつでも見つけられそうです。
もし通信会社が自分が解約しようとしていることを知っているなら、一時的なインセンティブ いつでも携帯をアップグレードできるとか、新しい機能が使えるようになるとか を与えて
契約を継続させるでしょう。

インセンティブは通常、失った顧客を再獲得するよりも圧倒的にコスト効率が良いのです。

携帯電話会社は、どの顧客が最終的に解約に至り、どの顧客がサービスを使い続けたかに関する履歴記録を持っています。 この履歴情報を使用して、トレーニングと呼ばれるプロセスを使用して、1 つの携帯電話事業者の解約の ML モデルを構築できます。 モデルをトレーニングした後、任意の顧客のプロファイル情報 (モデルのトレーニングに使用したのと同じプロファイル情報) をモデルに渡し、この顧客が解約するかどうかをモデルに予測させることができます。 もちろん、モデルが間違いを犯すことは予想されます。 結局のところ、未来を予測するのは難しい仕事です。 ただ、予測エラーに対処する方法があるので、それを学びます。

私たちが使用するデータセットは公開されており、ダニエル T. ラローズ著『Discovering Knowledge in Data』で言及されています。 著者は、これはカリフォルニア大学アーバイン校の機械学習データセット リポジトリによるものであると考えています。 今すぐそのデータセットをダウンロードして読み込んでみましょう。

詳細は、下記をご確認下さい。日本語版がなかったので、操作しながら、翻訳してみました。
https://sagemaker-examples.readthedocs.io/en/latest/introduction_to_applying_machine_learning/xgboost_customer_churn/xgboost_customer_churn.html

s3 = boto3.client("s3")
s3.download_file(
    f"sagemaker-example-files-prod-{sess.boto_region_name}",
    "datasets/tabular/synthetic/churn.txt",
    "churn.txt",
)
churn = pd.read_csv("./churn.txt")
pd.set_option("display.max_columns", 500)
churn
	State	Account Length	Area Code	Phone	Int'l Plan	VMail Plan	VMail Message	Day Mins	Day Calls	Day Charge	Eve Mins	Eve Calls	Eve Charge	Night Mins	Night Calls	Night Charge	Intl Mins	Intl Calls	Intl Charge	CustServ Calls	Churn?
0	PA	163	806	403-2562	no	yes	300	8.162204	3	7.579174	3.933035	4	6.508639	4.065759	100	5.111624	4.928160	6	5.673203	3	True.
1	SC	15	836	158-8416	yes	no	0	10.018993	4	4.226289	2.325005	0	9.972592	7.141040	200	6.436188	3.221748	6	2.559749	8	False.
2	MO	131	777	896-6253	no	yes	300	4.708490	3	4.768160	4.537466	3	4.566715	5.363235	100	5.142451	7.139023	2	6.254157	4	False.
3	WY	75	878	817-5729	yes	yes	700	1.268734	3	2.567642	2.528748	5	2.333624	3.773586	450	3.814413	2.245779	6	1.080692	6	False.
4	WY	146	878	450-4942	yes	no	0	2.696177	3	5.908916	6.015337	3	3.670408	3.751673	250	2.796812	6.905545	4	7.134343	6	True.
...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...
4995	NH	4	787	151-3162	yes	yes	800	10.862632	5	7.250969	6.936164	1	8.026482	4.921314	350	6.748489	4.872570	8	2.122530	9	False.
4996	SD	140	836	351-5993	no	no	0	1.581127	8	3.758307	7.377591	7	1.328827	0.939932	300	4.522661	6.938571	2	4.600473	4	False.
4997	SC	32	836	370-3127	no	yes	700	0.163836	5	4.243980	5.841852	3	2.340554	0.939469	450	5.157898	4.388328	7	1.060340	6	False.
4998	MA	142	776	604-2108	yes	yes	600	2.034454	5	3.014859	4.140554	3	3.470372	6.076043	150	4.362780	7.173376	3	4.871900	7	True.
4999	AL	141	657	294-2849	yes	yes	500	1.803907	0	5.125716	8.357508	0	2.109823	2.624299	400	3.713631	5.798783	6	5.485345	7	False.
5000 rows × 21 columns
len(churn.columns)
21

現代の基準からすると、これはわずか 5,000 件のレコードを含む比較的小規模なデータセットであり、各レコードは 21 の属性を使用して米国の未知の携帯電話会社の顧客のプロフィールを記述します。 属性は次のとおりです。

州: 顧客が居住する米国の州。2 文字の略語で示されます。
たとえば、オハイオ州やニュージャージー州など
アカウントの長さ: このアカウントがアクティブであった日数
市外局番: 対応する顧客の電話番号の 3 桁の市外局番
電話番号: 残りの 7 桁の電話番号
国際プラン: 顧客が国際通話プランを持っているかどうか: はい/いいえ
VMail プラン: 顧客がボイスメール機能を持っているかどうか: はい/いいえ
VMail メッセージ: 1 か月あたりのボイスメール メッセージの平均数
Day Mins: 日中に使用された通話の合計分数
日中の通話数: 日中に発信された通話の合計数
Day Charge: 日中の通話料金の請求額
Eve Mins、Eve Calls、Eve Charge: 夜間に発信された通話の請求料金
Night Mins、Night Calls、Night Charge: 夜間にかかった通話の請求料金
Intl Mins、Intl Calls、Intl Charge: 国際電話の請求料金
CustServ Calls: カスタマー サービスにかけられた通話の数
チャーン?: 顧客がサービスをやめたかどうか: true/false
最後の属性、Churn? はターゲット属性として知られており、ML モデルに予測させたい属性です。 ターゲット属性はバイナリであるため、モデルはバイナリ分類とも呼ばれるバイナリ予測を実行します。

データの探索を始めましょう。

# Frequency tables for each categorical feature
for column in churn.select_dtypes(include=["object"]).columns:
    display(pd.crosstab(index=churn[column], columns="% observations", normalize="columns"))

# Histograms for each numeric features
display(churn.describe())
%matplotlib inline
hist = churn.hist(bins=30, sharey=True, figsize=(10, 10))
col_0	% observations
State	
AK	0.0170
AL	0.0200
AR	0.0220
AZ	0.0180
CA	0.0208
CO	0.0182
CT	0.0178
DC	0.0224
DE	0.0182
FL	0.0178
GA	0.0166
HI	0.0190
IA	0.0206
ID	0.0222
IL	0.0198
IN	0.0190
KS	0.0158
KY	0.0182
LA	0.0202
MA	0.0208
MD	0.0226
ME	0.0148
MI	0.0202
MN	0.0220
MO	0.0212
MS	0.0212
MT	0.0180
NC	0.0190
ND	0.0160
NE	0.0218
NH	0.0188
NJ	0.0202
NM	0.0166
NV	0.0198
NY	0.0196
OH	0.0222
OK	0.0186
OR	0.0204
PA	0.0198
RI	0.0240
SC	0.0226
SD	0.0204
TN	0.0194
TX	0.0196
UT	0.0190
VA	0.0198
VT	0.0194
WA	0.0202
WI	0.0170
WV	0.0208
WY	0.0206
col_0	% observations
Phone	
100-2030	0.0002
100-2118	0.0002
100-3505	0.0002
100-5224	0.0002
101-3371	0.0002
...	...
999-3178	0.0002
999-5498	0.0002
999-5816	0.0002
999-8494	0.0002
999-9817	0.0002
4999 rows × 1 columns

col_0	% observations
Int'l Plan	
no	0.5014
yes	0.4986
col_0	% observations
VMail Plan	
no	0.4976
yes	0.5024
col_0	% observations
Churn?	
False.	0.5004
True.	0.4996
Account Length	Area Code	VMail Message	Day Mins	Day Calls	Day Charge	Eve Mins	Eve Calls	Eve Charge	Night Mins	Night Calls	Night Charge	Intl Mins	Intl Calls	Intl Charge	CustServ Calls
count	5000.000000	5000.000000	5000.000000	5000.000000	5000.00000	5000.000000	5000.000000	5000.000000	5000.000000	5000.000000	5000.000000	5000.000000	5000.000000	5000.000000	5000.000000	5000.000000
mean	101.675800	773.791400	226.680000	5.518757	3.50460	5.018902	5.026199	3.140400	5.017557	4.000917	224.790000	5.023490	5.025876	5.475400	4.328242	5.525800
std	57.596762	63.470888	273.998527	3.433485	1.68812	2.195759	2.135487	2.525621	2.127857	1.631001	97.302875	1.748900	1.019302	1.877045	2.440311	2.041217
min	1.000000	657.000000	0.000000	0.000215	0.00000	0.004777	0.004659	0.000000	0.013573	0.008468	0.000000	0.054863	1.648514	0.000000	0.000769	0.000000
25%	52.000000	736.000000	0.000000	2.682384	2.00000	3.470151	3.588466	1.000000	3.529613	2.921998	150.000000	3.873157	4.349726	4.000000	2.468225	4.000000
50%	102.000000	778.000000	0.000000	5.336245	3.00000	4.988291	5.145656	3.000000	5.006860	3.962089	200.000000	5.169154	5.034905	5.000000	4.214058	6.000000
75%	151.000000	806.000000	400.000000	7.936151	5.00000	6.559750	6.552962	5.000000	6.491725	5.100128	300.000000	6.272015	5.716386	7.000000	5.960654	7.000000
max	200.000000	878.000000	1300.000000	16.897529	10.00000	12.731936	13.622097	14.000000	12.352871	10.183378	550.000000	10.407778	8.405644	12.000000	14.212261	13.000000

Histograms for each numeric features.PNG

次のことがすぐにわかります。

状態はかなり均等に分散されているようです。
電話番号はユニークな値を持ちすぎて実用にはなりません。 プレフィックスを解析すると何らかの値が得られる可能性がありますが、プレフィックスがどのように割り当てられるかについての詳細なコンテキストがなければ、その使用を避けるべきです。
数値特徴のほとんどは驚くほどうまく分散されており、多くは鐘のようなガウス分布を示しています。 VMail メッセージは注目すべき例外です (そして市外局番は非数値に変換する必要がある機能として表示されます)。

churn = churn.drop("Phone", axis=1)
churn["Area Code"] = churn["Area Code"].astype(object)

次に、各特徴とターゲット変数の関係を見てみましょう。

for column in churn.select_dtypes(include=["object"]).columns:
    if column != "Churn?":
        display(pd.crosstab(index=churn[column], columns=churn["Churn?"], normalize="columns"))

for column in churn.select_dtypes(exclude=["object"]).columns:
    print(column)
    hist = churn[[column, "Churn?"]].hist(by="Churn?", bins=30)
    plt.show()
Churn?	False.	True.
State		
AK	0.015588	0.018415
AL	0.021583	0.018415
AR	0.022782	0.021217
AZ	0.015588	0.020416
CA	0.020384	0.021217
CO	0.018785	0.017614
CT	0.015588	0.020016
DC	0.022382	0.022418
DE	0.018385	0.018014
FL	0.019984	0.015612
GA	0.017986	0.015212
HI	0.019185	0.018815
IA	0.018385	0.022818
ID	0.021583	0.022818
IL	0.021982	0.017614
IN	0.021583	0.016413
KS	0.014788	0.016813
KY	0.017186	0.019215
LA	0.020783	0.019616
MA	0.021183	0.020416
MD	0.019584	0.025620
ME	0.013589	0.016013
MI	0.018785	0.021617
MN	0.022782	0.021217
MO	0.020783	0.021617
MS	0.019584	0.022818
MT	0.017586	0.018415
NC	0.017186	0.020817
ND	0.017186	0.014812
NE	0.019185	0.024420
NH	0.019984	0.017614
NJ	0.022382	0.018014
NM	0.017186	0.016013
NV	0.023181	0.016413
NY	0.015188	0.024019
OH	0.019185	0.025220
OK	0.021183	0.016013
OR	0.019185	0.021617
PA	0.018785	0.020817
RI	0.024380	0.023619
SC	0.021583	0.023619
SD	0.021583	0.019215
TN	0.021982	0.016813
TX	0.019185	0.020016
UT	0.018385	0.019616
VA	0.021183	0.018415
VT	0.022382	0.016413
WA	0.021982	0.018415
WI	0.018785	0.015212
WV	0.019584	0.022018
WY	0.020783	0.020416
Churn?	False.	True.
Area Code		
657	0.037170	0.036829
658	0.022782	0.021217
659	0.015588	0.020416
676	0.020384	0.021217
677	0.018785	0.017614
678	0.015588	0.020016
686	0.040767	0.040432
707	0.019984	0.015612
716	0.017986	0.015212
727	0.019185	0.018815
736	0.039968	0.045637
737	0.043565	0.034027
758	0.031974	0.036029
766	0.020783	0.019616
776	0.054357	0.062050
777	0.062350	0.064452
778	0.037170	0.041233
786	0.053557	0.060048
787	0.059552	0.051641
788	0.038369	0.040432
797	0.040368	0.041233
798	0.019185	0.021617
806	0.018785	0.020817
827	0.024380	0.023619
836	0.043165	0.042834
847	0.021982	0.016813
848	0.019185	0.020016
858	0.018385	0.019616
866	0.021183	0.018415
868	0.022382	0.016413
876	0.021982	0.018415
877	0.018785	0.015212
878	0.040368	0.042434
Churn?	False.	True.
Int'l Plan		
no	0.5	0.502802
yes	0.5	0.497198
Churn?	False.	True.
VMail Plan		
no	0.496403	0.498799
yes	0.503597	0.501201

Account Length
1.PNG

VMail Message
2.PNG

Day Mins
day.PNG

Day Calls
Day Calls.PNG

Day Charge
Day Charge.PNG

Eve Mins
Eve Mins.PNG

Eve Calls
Eve Calls.PNG

Eve Charge
Eve Charge.PNG

Night Mins
Night Mins.PNG

Night Calls
Night Calls.PNG

Night Charge
Night Charge.PNG

Intl Mins
Intl Mins.PNG

Intl Calls
Intl Calls.PNG

Intl Charge
Intl Charge.PNG

CustServ Calls
CustServ Calls.PNG

display(churn.corr(numeric_only=True))
pd.plotting.scatter_matrix(churn, figsize=(12, 12))
plt.show()

Account Length	VMail Message	Day Mins	Day Calls	Day Charge	Eve Mins	Eve Calls	Eve Charge	Night Mins	Night Calls	Night Charge	Intl Mins	Intl Calls	Intl Charge	CustServ Calls
Account Length	1.000000	-0.009030	-0.015878	0.011659	-0.007468	0.000213	0.026515	-0.012795	0.016400	-0.002383	-0.034925	0.017277	-0.003735	0.028285	-0.036721
VMail Message	-0.009030	1.000000	-0.143272	0.002762	-0.182712	-0.104667	-0.101240	-0.029212	0.061370	0.135042	-0.155475	-0.015162	0.131964	0.010120	0.068657
Day Mins	-0.015878	-0.143272	1.000000	-0.087598	0.667941	0.482641	-0.184939	0.766489	0.188190	-0.445212	0.570508	0.001988	0.236131	0.239331	-0.195322
Day Calls	0.011659	0.002762	-0.087598	1.000000	-0.222556	0.033903	0.185881	-0.052051	-0.085222	-0.083050	0.046641	-0.022548	-0.045671	-0.120064	-0.065518
Day Charge	-0.007468	-0.182712	0.667941	-0.222556	1.000000	0.574697	0.236626	0.371580	0.150700	-0.130722	0.374861	0.010294	0.119584	0.251748	-0.260945
Eve Mins	0.000213	-0.104667	0.482641	0.033903	0.574697	1.000000	-0.067123	0.269980	-0.090515	0.067315	0.317481	-0.015678	0.070456	0.448910	-0.167347
Eve Calls	0.026515	-0.101240	-0.184939	0.185881	0.236626	-0.067123	1.000000	-0.467814	0.221439	0.218149	-0.324936	-0.001593	-0.112062	0.017036	-0.433467
Eve Charge	-0.012795	-0.029212	0.766489	-0.052051	0.371580	0.269980	-0.467814	1.000000	0.184230	-0.454649	0.546137	-0.003569	0.164104	0.243936	-0.011019
Night Mins	0.016400	0.061370	0.188190	-0.085222	0.150700	-0.090515	0.221439	0.184230	1.000000	-0.223023	-0.140482	-0.012781	0.038831	0.271179	-0.332802
Night Calls	-0.002383	0.135042	-0.445212	-0.083050	-0.130722	0.067315	0.218149	-0.454649	-0.223023	1.000000	-0.390333	-0.009821	0.181237	-0.155736	0.110211
Night Charge	-0.034925	-0.155475	0.570508	0.046641	0.374861	0.317481	-0.324936	0.546137	-0.140482	-0.390333	1.000000	0.012585	-0.009720	-0.330772	0.439805
Intl Mins	0.017277	-0.015162	0.001988	-0.022548	0.010294	-0.015678	-0.001593	-0.003569	-0.012781	-0.009821	0.012585	1.000000	-0.007220	-0.010907	-0.008672
Intl Calls	-0.003735	0.131964	0.236131	-0.045671	0.119584	0.070456	-0.112062	0.164104	0.038831	0.181237	-0.009720	-0.007220	1.000000	-0.233809	-0.012260
Intl Charge	0.028285	0.010120	0.239331	-0.120064	0.251748	0.448910	0.017036	0.243936	0.271179	-0.155736	-0.330772	-0.010907	-0.233809	1.000000	-0.661833
CustServ Calls	-0.036721	0.068657	-0.195322	-0.065518	-0.260945	-0.167347	-0.433467	-0.011019	-0.332802	0.110211	0.439805	-0.008672	-0.012260	-0.661833	1.000000

matrix.PNG

本質的に相互に 100% の相関関係を持ついくつかの特徴が見られます。 一部の機械学習アルゴリズムにこれらの機能ペアを含めると、致命的な問題が発生する可能性がありますが、他のアルゴリズムでは、軽度の冗長性とバイアスが導入されるだけです。 相関性の高い各ペアから 1 つの特徴を削除しましょう。日中分とのペアから日中料金、夜間分とのペアから夜間料金、国際分とのペアから国際料金です。

churn = churn.drop(["Day Charge", "Eve Charge", "Night Charge", "Intl Charge"], axis=1)

データセットをクリーンアップしたので、使用するアルゴリズムを決定しましょう。 上で述べたように、高値と低値 (中間ではない) の両方がチャーンを予測する変数がいくつかあるようです。 これを線形回帰のようなアルゴリズムに対応させるには、多項式 (またはバケット化された) 項を生成する必要があります。 代わりに、勾配ブースト ツリーを使用してこの問題をモデル化してみましょう。 Amazon SageMaker は、管理された分散設定でトレーニングし、リアルタイム予測エンドポイントとしてホストするために使用できる XGBoost コンテナを提供します。 XGBoost は、特徴とターゲット変数の間の非線形関係を自然に考慮し、特徴間の複雑な相互作用に対応する勾配ブースト ツリーを使用します。

Amazon SageMaker XGBoost は、CSV または LibSVM 形式のデータでトレーニングできます。 この例では、CSV をそのまま使用します。 それはすべきです

最初の列に予測子変数を含めます
ヘッダー行がない
まず、カテゴリ特徴を数値特徴に変換しましょう。

model_data = pd.get_dummies(churn)
model_data = pd.concat(
    [model_data["Churn?_True."], model_data.drop(["Churn?_False.", "Churn?_True."], axis=1)], axis=1
)
model_data = model_data.astype(float)

次に、データをトレーニング、検証、テスト セットに分割しましょう。 これにより、モデルの過剰適合を防止し、まだ確認されていないデータでモデルの精度をテストできるようになります。

train_data, validation_data, test_data = np.split(
    model_data.sample(frac=1, random_state=1729),
    [int(0.7 * len(model_data)), int(0.9 * len(model_data))],
)
train_data.to_csv("train.csv", header=False, index=False)
validation_data.to_csv("validation.csv", header=False, index=False)
len(train_data.columns)
100

次に、これらのファイルを S3 にアップロードします。

boto3.Session().resource("s3").Bucket(bucket).Object(
    os.path.join(prefix, "train/train.csv")
).upload_file("train.csv")
boto3.Session().resource("s3").Bucket(bucket).Object(
    os.path.join(prefix, "validation/validation.csv")
).upload_file("validation.csv")

トレーニングに進むと、まず XGBoost アルゴリズム コンテナの場所を指定する必要があります。

container = sagemaker.image_uris.retrieve("xgboost", sess.boto_region_name, "1.7-1")
display(container)

次に、CSV ファイル形式でトレーニングしているため、トレーニング関数が S3 内のファイルへのポインターとして使用できる TrainingInput を作成します。

s3_input_train = TrainingInput(
    s3_data="s3://{}/{}/train".format(bucket, prefix), content_type="csv"
)
s3_input_validation = TrainingInput(
    s3_data="s3://{}/{}/validation/".format(bucket, prefix), content_type="csv"
)

これで、使用するトレーニング インスタンスの種類と数などのいくつかのパラメーターと、XGBoost ハイパーパラメーターを指定できるようになりました。 いくつかの主要なハイパーパラメータは次のとおりです。

max_ Depth は、アルゴリズム内の各ツリーを構築できる深さを制御します。 ツリーが深くなると適合性が向上しますが、計算コストが高くなり、過剰適合につながる可能性があります。 通常、モデルのパフォーマンスにはトレードオフがあり、多数の浅いツリーと少数の深いツリーの間で検討する必要があります。
subsample はトレーニング データのサンプリングを制御します。 この手法は過剰適合を減らすのに役立ちますが、設定が低すぎるとデータ モデルが不足する可能性もあります。
num_round はブースティング ラウンドの数を制御します。 これは本質的に、前の反復の残差を使用してトレーニングされる後続のモデルです。 繰り返しになりますが、ラウンド数を増やすと、トレーニング データにより適切な適合が得られますが、計算コストが高くなったり、過剰適合につながる可能性があります。
eta はブースティングの各ラウンドの攻撃性を制御します。 値を大きくすると、ブーストがより控えめになります。
ガンマは、木がどれだけ積極的に成長するかを制御します。 値が大きいほど、より保守的なモデルになります。
XGBoost のハイパーパラメータの詳細については、GitHub ページを参照してください。

sess = sagemaker.Session()

xgb = sagemaker.estimator.Estimator(
    container,
    role,
    instance_count=1,
    instance_type="ml.m4.xlarge",
    output_path="s3://{}/{}/output".format(bucket, prefix),
    sagemaker_session=sess,
)
xgb.set_hyperparameters(
    max_depth=5,
    eta=0.2,
    gamma=4,
    min_child_weight=6,
    subsample=0.8,
    verbosity=0,
    objective="binary:logistic",
    num_round=100,
)

xgb.fit({"train": s3_input_train, "validation": s3_input_validation})

アルゴリズムをトレーニングしたので、モデルを作成して、ホストされているエンドポイントにデプロイしましょう。

xgb_predictor = xgb.deploy(
    initial_instance_count=1, instance_type="ml.m4.xlarge", serializer=CSVSerializer()
)

評価する
ホストされたエンドポイントが実行されているので、http POST リクエストを行うだけで、モデルからリアルタイムの予測を非常に簡単に行うことができます。 ただし、最初に、test_data NumPy 配列をエンドポイントの背後にあるモデルに渡すためのシリアライザーとデシリアライザーをセットアップする必要があります。

ここで、簡単な関数を使用して次のことを行います。

テスト データセットをループします
行のミニバッチに分割します
それらのミニバッチを CSV 文字列ペイロードに変換します
XGBoost エンドポイントを呼び出してミニバッチ予測を取得する
予測を収集し、モデルが提供する CSV 出力から NumPy 配列に変換します。

def predict(data, rows=500):
    split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))
    predictions = ""
    for array in split_array:
        predictions = "".join([predictions, xgb_predictor.predict(array).decode("utf-8")])

    return predictions.split("\n")[:-1]


predictions = predict(test_data.to_numpy()[:, 1:])
predictions = np.array([float(num) for num in predictions])
print(predictions)
[1.45354465e-01 9.78284001e-01 2.29995721e-03 4.43856046e-03
 4.18119103e-01 9.61098611e-01 9.88697410e-01 7.09247828e-01
 9.00699615e-01 9.90865409e-01 9.16577816e-01 9.16562043e-03
 1.32644475e-02 9.32690680e-01 9.90951777e-01 9.90174174e-01
 9.90583062e-01 6.71710148e-02 9.68500853e-01 9.88211036e-01
 9.49407101e-01 3.49131855e-03 6.24201074e-03 9.03135717e-01
 9.15755212e-01 3.02289635e-01 9.91930068e-01 1.00098653e-02
 9.89490330e-01 2.89308596e-02 3.35139036e-02 4.30867486e-02
 9.57395673e-01 7.48302042e-03 4.56276210e-03 2.22346629e-03
 6.68902099e-01 3.98493201e-01 9.89661694e-01 9.71789002e-01
 8.21430683e-01 9.90990996e-01 4.96863037e-01 9.68719840e-01
 2.86707142e-03 9.22252357e-01 8.58608517e-04 1.73726588e-01
 9.83624935e-01 4.02040966e-03 1.40238460e-02 9.71329868e-01
 9.67991769e-01 1.32331580e-01 1.07120641e-01 9.92310762e-01
 2.77149454e-02 4.27350076e-03 9.49290674e-03 2.99536940e-02
 3.51974726e-01 9.51662481e-01 9.29118633e-01 3.48184891e-02
 1.18576223e-02 9.74126458e-01 9.93539035e-01 1.40150879e-02
 9.77024496e-01 9.83173251e-01 9.95684385e-01 9.78204787e-01
 9.88796592e-01 8.59204680e-03 1.41497612e-01 2.97577064e-02
 1.25241457e-02 9.63524282e-01 8.11392162e-03 2.85406202e-01
 1.65301427e-01 9.58968878e-01 4.39115660e-03 9.10739958e-01
 9.48271096e-01 4.49727356e-01 7.44590536e-02 1.07207075e-01
 2.51774257e-03 2.84799263e-02 1.14294756e-02 4.73506562e-03
 6.37701713e-03 1.22722927e-02 1.83815416e-02 8.98610801e-03
 6.26715362e-01 9.20443892e-01 6.18711160e-03 7.90525019e-01
 9.76670146e-01 1.01614125e-01 5.69215775e-01 2.14107353e-02
 8.74110013e-02 6.63182139e-01 7.77046561e-01 4.70338296e-03
 9.15105343e-01 7.95063853e-01 2.42506526e-02 2.71020848e-02
 8.79000366e-01 9.80753362e-01 7.46096671e-01 9.89202201e-01
 7.03168929e-01 9.90700841e-01 1.61304278e-03 4.45735082e-03
 5.88556342e-02 2.29934603e-02 9.86055076e-01 6.79414034e-01
 7.88709149e-03 9.69462633e-01 1.56352744e-02 2.56143987e-01
 3.67543101e-02 6.26953959e-01 6.61722012e-03 9.94859099e-01
 3.82713467e-01 5.97921491e-01 3.44944117e-03 3.09076309e-02
 1.86218007e-03 4.09094803e-03 8.22253644e-01 9.86833751e-01
 2.58615706e-03 1.62289676e-03 9.95795965e-01 9.90770459e-01
 2.96862692e-01 7.15770602e-01 9.34843600e-01 5.03771473e-03
 6.77827187e-03 2.05048500e-03 6.56762570e-02 9.79595363e-01
 1.09374665e-01 9.62667227e-01 7.79664963e-02 9.73859191e-01
 9.93983448e-01 1.42154479e-02 1.61886781e-01 9.97698247e-01
 8.77197206e-01 8.69093637e-04 9.92125094e-01 2.23713927e-02
 9.88184571e-01 8.84966776e-02 2.16101762e-03 9.95631695e-01
 4.23407555e-03 9.90727663e-01 1.02155237e-03 5.47027737e-02
 7.75001466e-01 1.79885551e-02 1.10104810e-02 9.07850027e-01
 7.23688025e-03 4.51596221e-03 3.75032648e-02 9.93271172e-01
 9.73003685e-01 4.83639119e-03 9.65361834e-01 6.54217741e-03
 9.95753646e-01 6.49030745e-01 5.40636480e-02 9.92799640e-01
 6.01739824e-01 4.85836603e-02 3.84727806e-01 4.02415404e-03
 4.33045113e-03 8.95203292e-01 5.23443460e-01 9.80668366e-01
 6.31612718e-01 9.90977764e-01 5.78983724e-01 3.31457867e-03
 9.74618316e-01 2.30999030e-02 9.54991281e-01 1.45205874e-02
 9.97269809e-01 1.89454347e-01 8.36716473e-01 9.74477470e-01
 1.55935399e-02 5.11737943e-01 2.02344228e-02 1.94717404e-02
 1.00200446e-02 9.00413930e-01 5.37246287e-01 9.03650284e-01
 3.99313904e-02 9.74407196e-01 4.92372066e-01 9.13743556e-01
 8.71497393e-01 5.27827116e-03 9.07855630e-01 9.94372368e-01
 2.45616809e-01 8.43104303e-01 1.54190627e-03 4.28327173e-01
 9.61117983e-01 9.89492893e-01 5.14664412e-01 1.10597778e-02
 9.63337958e-01 1.25640079e-01 9.98162851e-03 6.83929352e-03
 7.87677709e-03 9.96537089e-01 5.43673988e-03 9.78044689e-01
 9.81553078e-01 1.77942531e-03 7.96547998e-03 3.91911775e-01
 9.94523644e-01 2.79172093e-01 9.77499545e-01 6.28175866e-03
 1.17029762e-03 9.94890809e-01 5.88768184e-01 9.42278087e-01
 9.96033132e-01 9.92791295e-01 9.86207962e-01 2.58573472e-01
 4.06831736e-03 7.12859809e-01 9.94767785e-01 6.28266215e-01
 2.41812738e-03 9.63064671e-01 8.39839101e-01 8.68149102e-03
 8.24116636e-03 5.56686580e-01 9.15104931e-04 6.50778413e-01
 1.01858182e-02 3.05671901e-01 5.02485549e-03 1.24349468e-03
 7.90653348e-01 9.89975989e-01 8.76368940e-01 9.71927285e-01
 6.83342457e-01 9.93650496e-01 9.19665396e-02 9.93050933e-01
 9.97511625e-01 2.76980898e-03 8.08949113e-01 4.89457250e-01
 6.13950416e-02 2.85638473e-03 4.52913754e-02 9.97025192e-01
 9.90609288e-01 6.60308823e-02 8.47459018e-01 9.58468556e-01
 8.90724003e-01 8.45401347e-01 9.65758562e-01 1.99787281e-02
 2.73887128e-01 9.94841278e-01 8.89521003e-01 4.60449234e-03
 9.87804651e-01 9.91844594e-01 9.96751666e-01 9.58904624e-01
 2.66824901e-01 9.87034202e-01 9.96843815e-01 5.47693390e-03
 1.18411712e-01 7.17594405e-04 9.90793645e-01 9.98004377e-01
 9.73549426e-01 4.81882459e-03 3.80819798e-01 2.75687903e-01
 9.07744348e-01 9.94196653e-01 5.89165986e-01 5.11417210e-01
 3.49242985e-01 7.47146130e-01 7.37948064e-03 9.95374501e-01
 9.48148370e-01 1.39009638e-03 5.12049496e-02 9.09888983e-01
 1.35909393e-03 9.98756051e-01 8.90588939e-01 4.58958978e-03
 9.51509297e-01 4.76172613e-03 9.45550144e-01 9.36734378e-01
 1.23661561e-02 9.85040426e-01 1.84003323e-01 6.54030621e-01
 9.97584462e-01 9.71761107e-01 9.68084395e-01 8.06119323e-01
 5.35942376e-01 4.46534157e-03 6.75057769e-01 9.83872950e-01
 5.27089760e-02 9.83510137e-01 9.88768339e-01 8.59659731e-01
 5.03594995e-01 8.32698822e-01 3.37436348e-02 4.00324836e-02
 8.42627704e-01 7.54910568e-03 5.10879867e-02 6.13431692e-01
 7.89988399e-01 4.42159688e-03 9.67317998e-01 9.52267289e-01
 9.95635331e-01 2.11556442e-02 9.70325112e-01 6.66326901e-04
 5.52287605e-03 2.13798732e-02 4.34392065e-01 1.06131583e-01
 8.87469828e-01 2.03434541e-03 9.97525871e-01 7.66908526e-01
 9.04627562e-01 7.00816629e-04 2.62071360e-02 9.36227620e-01
 9.74594295e-01 2.62008756e-01 2.16106735e-02 4.35848087e-01
 9.95881677e-01 8.77979770e-03 9.78187323e-01 4.58735507e-03
 1.72593091e-02 9.75049555e-01 9.62046981e-01 3.63443582e-03
 9.25317407e-03 6.51314855e-01 5.12395948e-02 9.66752350e-01
 1.06241509e-01 1.13275245e-01 9.94940758e-01 4.00280021e-03
 4.71546799e-01 9.86425400e-01 8.42836797e-01 9.89564657e-01
 4.30389354e-03 9.96762514e-01 1.49545923e-01 3.36292312e-02
 7.80505538e-01 1.24553464e-01 9.93947148e-01 9.28793669e-01
 6.44920915e-02 9.74430621e-01 4.60103853e-03 5.73051989e-01
 6.59529030e-01 3.84716779e-01 3.15497786e-01 9.96967375e-01
 9.17054176e-01 6.42081425e-02 1.78040806e-02 4.56508622e-02
 8.51092339e-01 8.18399191e-01 9.51399922e-01 9.10877134e-04
 3.33398860e-03 2.79733628e-01 9.49600458e-01 3.71026754e-01
 9.87830222e-01 3.94266914e-04 5.92266172e-02 9.48842585e-01
 9.18808818e-01 9.43920016e-03 8.98346126e-01 3.77703644e-02
 6.16019487e-01 5.42564318e-02 3.75937134e-01 9.51113522e-01
 9.20304239e-01 6.49452135e-02 8.95597681e-04 8.86587501e-01
 4.57955524e-02 7.92529225e-01 4.57344055e-02 1.38586555e-02
 8.97114873e-01 5.45595912e-03 9.68935549e-01 9.57905650e-01
 1.33015960e-01 4.85805841e-03 4.04586375e-01 9.96664703e-01
 1.75451767e-02 6.70278668e-02 9.84654307e-01 9.91094708e-01
 4.70215781e-03 6.05097376e-02 4.07888740e-03 9.27925885e-01
 8.11935402e-03 9.87158597e-01 1.82459988e-02 7.46906340e-01
 4.22766712e-03 3.05428356e-03 4.20122407e-02 1.78867783e-02
 2.15065870e-02 4.90936115e-02 2.81943604e-02 9.27526414e-01
 2.83649494e-03 7.00003561e-03 2.43664929e-03 9.69419181e-01
 9.74273980e-01 9.84279633e-01 8.29754412e-01 9.63483989e-01
 9.26144779e-01 3.10315769e-02 9.82711852e-01 2.47274399e-01
 5.10260239e-02 9.78781641e-01 4.42271121e-02 8.71273756e-01
 9.96401072e-01 9.71360147e-01 9.72122014e-01 9.76645887e-01]

機械学習モデルのパフォーマンスを比較するにはさまざまな方法がありますが、実際の値と予測値を単純に比較することから始めましょう。 この場合、顧客が解約したか (1) か否か (0) を単純に予測しており、混同行列が生成されます。

pd.crosstab(
    index=test_data.iloc[:, 0],
    columns=np.round(predictions),
    rownames=["actual"],
    colnames=["predictions"],
)
predictions	0.0	1.0
actual		
0.0	235	18
1.0	11	236

48 人の解約者のうち、39 人 (真陽性) を正しく予測しました。 また、4 人の顧客が解約するだろうと誤って予測しましたが、結局解約しませんでした (誤検知)。 また、解約しないと予測していたが、最終的に解約した顧客も 9 名います (偽陰性)。

ここで重要な点は、上記の np.round() 関数により、単純なしきい値 (またはカットオフ) 0.5 を使用していることです。 xgboost からの予測では 0 から 1 までの連続値が得られ、それらを最初のバイナリ クラスに強制的に入れます。 ただし、顧客が離脱すると、離脱する可能性があると思われる顧客を積極的に引き留めようとするよりも企業にとってコストが高くなることが予想されるため、このカットオフ値を下げることを検討する必要があります。 これにより、ほぼ確実に偽陽性の数が増加しますが、真陽性の数が増加し、偽陰性の数が減少することも期待できます。

ここで大まかな直感を得るために、予測の連続値を見てみましょう。

plt.hist(predictions)
plt.xlabel("Predicted churn probability")
plt.ylabel("Number of customers")
plt.show()

predict.PNG

私たちのモデルから得られる連続値の予測は 0 または 1 に偏る傾向がありますが、0.1 と 0.9 の間には十分な質量があり、カットオフを調整することで実際に多くの顧客の予測が変化するはずです。 例えば...

pd.crosstab(index=test_data.iloc[:, 0], columns=np.where(predictions > 0.3, 1, 0))
col_0	0	1
Churn?_True.		
0.0	221	32
1.0	2	245

カットオフを 0.5 から 0.3 に下げると、真陽性が 1 件増え、偽陽性が 3 件増え、偽陰性が 1 件減ることがわかります。 ここでの数字は全体としては小さいですが、カットオフの変更により移行している顧客は全体の 6 ~ 10% に相当します。 これは正しい決断でしたか? 最終的にはさらに 3 人の顧客を維持することになるかもしれませんが、どうせ留まるであろうさらに 5 人の顧客にも不必要にインセンティブを与えました。 最適なカットオフを決定することは、機械学習を現実の環境に適切に適用するための重要なステップです。 これについてより広範に議論してから、現在の問題に具体的な仮説的な解決策を適用してみましょう。

エラーの相対コスト
実際のバイナリ分類問題では、同様に敏感なカットオフが生成される可能性があります。 それ自体は問題ありません。 結局のところ、2 つのクラスのスコアが本当に簡単に分離できる場合、問題はおそらく最初はそれほど難しくなく、ML の代わりに決定論的なルールを使用して解決できる可能性さえあります。

さらに重要なのは、ML モデルを運用環境に導入すると、モデルが誤って偽陽性と偽陰性を割り当てることに関連したコストが発生することです。 また、真陽性と真陰性の正確な予測に関連する同様のコストも検討する必要があります。 カットオフの選択はこれら 4 つの統計すべてに影響するため、予測ごとにこれら 4 つの結果ごとにビジネスに対する相対コストを考慮する必要があります。

コストの割り当て
携帯電話会社の解約という問題の代償はどれくらいでしょうか? もちろん、コストは企業が行う特定のアクションによって異なります。 ここでいくつかの仮定を立ててみましょう。

まず、真の陰性者にコスト 0 を割り当てます。 この場合、私たちのモデルは基本的に満足している顧客を正しく識別したため、何もする必要はありません。

偽陰性が最も問題となるのは、離反した顧客が継続すると誤って予測するためです。 私たちは顧客を失い、代わりの顧客を獲得するためのすべてのコストを支払わなければなりません。これには、取り残された収益、広告費、管理費、POS コスト、そしておそらく電話ハードウェアの補助金も含まれます。 インターネットで簡単に検索すると、そのような費用は通常数百ドルかかることがわかります。そのため、この例では 500 ドルと仮定します。 これは偽陰性のコストです。

最後に、モデルが離脱していると識別した顧客に対して、100 ドルの維持インセンティブが与えられると仮定します。 プロバイダーが顧客にそのような譲歩を提供した場合、顧客は離れる前によく考えるかもしれません。 これは、真陽性結果と偽陽性結果の両方のコストです。 誤検知の場合(顧客は満足しているが、モデルが誤ってチャーンを予測した場合)、100 ドルの譲歩を「無駄」にします。 おそらくその 100 ドルをもっと効果的に使えたはずですが、すでに忠実な顧客の忠誠心を高めた可能性があるので、それほど悪くはありません。

最適なカットオフを見つける
偽陰性は偽陽性よりも大幅にコストがかかることは明らかです。 顧客数に基づいて誤差を最適化するのではなく、次のようなコスト関数を最小化する必要があります。

$500 * FN(C) + $0 * TN(C) + $100 * FP(C) + $100 * TP(C)
FN(C) は、偽陰性パーセンテージがカットオフ C の関数であることを意味し、TN、FP、TP についても同様です。 式の結果が最小となるカットオフ C を見つける必要があります。

これを行う簡単な方法は、多数の考えられるカットオフでシミュレーションを実行することです。 以下の for ループで 100 個の可能な値をテストします。

cutoffs = np.arange(0.01, 1, 0.01)
costs = []
for c in cutoffs:
    costs.append(
        np.sum(
            np.sum(
                np.array([[0, 100], [500, 100]])
                * pd.crosstab(index=test_data.iloc[:, 0], columns=np.where(predictions > c, 1, 0))
            )
        )
    )

costs = np.array(costs)
plt.plot(cutoffs, costs)
plt.xlabel("Cutoff")
plt.ylabel("Cost")
plt.show()

cutoff.PNG

print(
    "Cost is minimized near a cutoff of:",
    cutoffs[np.argmin(costs)],
    "for a cost of:",
    np.min(costs),
)

コストはカットオフ付近で最小化されます: 0.32、コスト: 28400
上のグラフは、すべての顧客に維持インセンティブが与えられるため、しきい値の選択が低すぎるとコストが急増することを示しています。 一方、しきい値を高く設定しすぎると、多くの顧客を失うことになり、最終的にはほぼ同じコストがかかるようになります。 カットオフを 0.46 に設定すると、全体のコストを 8,400 ドルに最小限に抑えることができます。これは、何もアクションを起こさない場合に損失が予想される 20,000 ドル以上よりも大幅に優れています。

拡張機能
このノートブックでは、顧客が離脱する可能性があるかどうかを予測するモデルを構築する方法と、真陽性、偽陽性、および偽陰性のコストを考慮したしきい値を最適に設定する方法を紹介しました。 これを拡張するには、次のようないくつかの方法があります。

維持インセンティブを受け取った顧客の中には、依然として離脱する人もいます。 インセンティブを受け取ったにもかかわらず解約する確率をコスト関数に含めることで、リテンション プログラムの ROI が向上します。
低価格プランに切り替える顧客、または有料機能を無効にする顧客は、個別にモデル化できるさまざまな種類の顧客離れを表しています。
顧客の行動の進化をモデル化します。 使用量が減少し、カスタマー サービスへの通話数が増加している場合、傾向が逆の場合よりもチャーンが発生する可能性が高くなります。 顧客プロファイルには行動傾向を組み込む必要があります。
実際のトレーニング データと金銭的コストの割り当ては、より複雑になる可能性があります。
チャーンのタイプごとに複数のモデルが必要になる場合があります。
さらなる複雑さに関係なく、このノートブックで説明されている同様の原則が適用される可能性があります。

顧客離れの予測 AWS SageMaker GUI版(2/2)
https://qiita.com/kimuni-i/items/05d9525042eac6774b2c

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3