1
0

More than 3 years have passed since last update.

分類木の動作確認用サンプルデータセットを生成するpythonスクリプト

Last updated at Posted at 2021-01-04

やりたいこと

  • 機械学習で使うダミーのデータセット(顧客リストのサンプル)を作成したい。
  • 顧客リストはCSVファイルに出力したい。
  • 顧客リストの特徴を任意に設定したい。
  • 顧客リストの特徴とは「年配の男性会社役員が買っている」「関西の既婚女性が買っている」といったデータの傾向。
  • そのためのスクリプトをpythonで実装したい。

生成したデータセットの利用例

決定木(分類木)を実装して、そのモデルがデータセットを正しく分類できるか?
を確かめたい場合、モデルに読み込ませるCSVファイルとして利用する。
CSV作成時に特徴を任意に設定できるので、モデルがその特徴を分類できたなら、
モデルが正しく動作していると判断できる。

生成したいデータセットの仕様

カラム

  • 性別: 顧客の性別(1:男性 / 0:女性)
  • 年齢: 顧客の年齢(数値)
  • 婚姻: 顧客の婚姻状況(0:未婚 / 1:既婚)
  • 都道府県: 顧客が住んでいる都道府県(47都道府県のいずれか)
  • 職業: 顧客の職業(1~8のいずれか。項目は後述)
  • 趣味: 顧客の趣味(1~7のいずれか。項目は後述)
  • y: 目的変数。たとえば「ある商品を購入したか、していないか」('yes:購入した' / 'no:購入していない')

職業

1:会社員
2:会社役員
3:自営業
4:フリーランス
5:学生
6:専業主婦
7:無職
8:その他

趣味

  • 1:読書
  • 2:映画
  • 3:音楽
  • 4:アウトドア
  • 5:自動車
  • 6:オートバイ
  • 7:テレビゲーム

データの例

たとえば下記のようなデータセットを生成したい。

image.png

この例では、
1行目の顧客は、
男性で、55歳で、既婚者で、兵庫県に住んでいて、会社員で、オートバイが趣味であり、商品を購入した顧客である。

2行目の顧客は、
女性で、31歳で、未婚で、和歌山県に住んでいて、会社員で、自動車が趣味であり、商品を購入していない顧客である。

左右の表は、いずれも同じデータセットである。
左の表は、各値を日本語で表記したものである。右の表は、各値を数値で表記したものである。

このような顧客リストがあるとして、
「商品を買った人の特徴(傾向)を知りたい」というケース。
男女別では、どちらが多く買っているのか?
住んでいる地域によって違いはあるのか?
年齢や職業は関係あるのか?
など「どんな人が買ったのか?」の傾向が分かれば、マーケティングの施策に役立つ。
たとえば、新規でダイレクトメールを打つ場合、反応率を最大化するためには、
送付先の地域を限定すべきか?年齢や性別で絞るべきか?などが判断しやすくなる。

実装

import random
import csv

##################################################
# マスタデータを定義する。
##################################################
def getMstHobby():
    mst = []
    mst.append('趣味')
    mst.append('読書')
    mst.append('映画')
    mst.append('音楽')
    mst.append('アウトドア')
    mst.append('自動車')
    mst.append('オートバイ')
    mst.append('テレビゲーム')
    return mst

def getMstJob():
    mst = []
    mst.append('職業')
    mst.append('会社員')
    mst.append('会社役員')
    mst.append('自営業')
    mst.append('フリーランス')
    mst.append('学生')
    mst.append('専業主婦')
    mst.append('無職')
    mst.append('その他')
    return mst

def getMstGender():
    mst = []
    mst.append('女性')
    mst.append('男性')
    return mst

def getMstMarriage():
    mst = []
    mst.append('未婚')
    mst.append('既婚')
    return mst

def getMstPref():
    mst = []
    mst.append('都道府県')
    mst.append('北海道')
    mst.append('青森県')
    mst.append('岩手県')
    mst.append('宮城県')
    mst.append('秋田県')
    mst.append('山形県')
    mst.append('福島県')
    mst.append('茨城県')
    mst.append('栃木県')
    mst.append('群馬県')
    mst.append('埼玉県')
    mst.append('千葉県')
    mst.append('東京都')
    mst.append('神奈川県')
    mst.append('新潟県')
    mst.append('富山県')
    mst.append('石川県')
    mst.append('福井県')
    mst.append('山梨県')
    mst.append('長野県')
    mst.append('岐阜県')
    mst.append('静岡県')
    mst.append('愛知県')
    mst.append('三重県')
    mst.append('滋賀県')
    mst.append('京都府')
    mst.append('大阪府')
    mst.append('兵庫県')
    mst.append('奈良県')
    mst.append('和歌山県')
    mst.append('鳥取県')
    mst.append('島根県')
    mst.append('岡山県')
    mst.append('広島県')
    mst.append('山口県')
    mst.append('徳島県')
    mst.append('香川県')
    mst.append('愛媛県')
    mst.append('高知県')
    mst.append('福岡県')
    mst.append('佐賀県')
    mst.append('長崎県')
    mst.append('熊本県')
    mst.append('大分県')
    mst.append('宮崎県')
    mst.append('鹿児島県')
    mst.append('沖縄県')
    return mst

##################################################
# リストを受け取りcsvファイルに出力する。
##################################################
def outputCsv(fileName, listData):
    f = open(fileName, 'w')
    writer = csv.writer(f, lineterminator='\n')
    writer.writerows(listData)
    f.close()    

##################################################
# データセットの特徴を設定する。
##################################################
def setFeatures(gender, age, marriage, pref, job, hobby):

    if gender == '男性' and age >= 40 and job == '会社員':
        return 'yes'

    if gender == '女性' and age <= 29:
        return 'yes'

    if marriage == '既婚' and job == '会社員':
        return 'yes'

    return 'no'

##################################################
# 処理開始。
##################################################

mst_gender = getMstGender()
mst_pref = getMstPref()
mst_job = getMstJob()
mst_hobby = getMstHobby()
mst_marriage = getMstMarriage()

users_label = []
users_int = []

# CSVファイルのヘッダ。
csv_header_en = ['gender', 'age', 'marriage', 'pref', 'job', 'hobby', 'y']
csv_header_jp = ['性別', '年齢', '婚姻', '都道府県', '職業', '趣味', 'y']
users_int.append(csv_header_en)
users_label.append(csv_header_jp)

# 作成するCSVのレコード件数。
recordNum = 50000

# CSVレコードを生成する。
for num in range(recordNum):
    ageInt = random.randint(20, 80) # 最少年齢と最大年齢
    prefInt = random.randint(1, 47) # 都道府県番号。北海道なら1, 東京都なら13
    prefLabel = mst_pref[prefInt]
    genderInt = random.randint(0, 1) # 0:女性, 1:男性
    genderLabel = mst_gender[genderInt]
    marriageInt = random.randint(0, 1) # 0:未婚, 1:既婚
    marriageLabel = mst_marriage[marriageInt] 
    jobInt = random.randint(1, 7) # 「その他」を除く職業
    jobLabel = mst_job[jobInt]
    hobbyInt = random.randint(1, 7)
    hobbyLabel = mst_hobby[hobbyInt]
    y = setFeatures(genderLabel, ageInt, marriageLabel, prefLabel, jobLabel, hobbyLabel)
    dataInt = [genderInt, ageInt, marriageInt, prefInt, jobInt, hobbyInt, y]
    dataLabel = [genderLabel, ageInt, marriageLabel, prefLabel, jobLabel, hobbyLabel, y]
    users_int.append(dataInt)
    users_label.append(dataLabel)

# CSVファイルに出力する。
outputCsv('out_int.csv', users_int)
outputCsv('out_label.csv', users_label)


解説

まずランダムに顧客を生成して、
目的変数y(買ったか、買わなかったか)は、生成された各カラムの値に応じて、
条件文で分岐させることで「どんな顧客が買ったのか?」の特徴を設定している。

関数setFeatures

で特徴を設定している。
この関数を自由に修正すれば、データセットに任意の傾向を持たせることができる。

この実装例では、

  • 40歳以上の男性会社員が買った。
  • 29歳以下の女性(職業問わず)が買った。
  • 既婚の会社員が買った。

という特徴を持たせている。
現実のデータセットは、このような綺麗な分類にはならないが、
上記の条件に合致しなくても、ランダムにyesを返す処理を加えれば、
「条件に合致していないけど買った顧客」を混在させるこも可能。

出力結果の例

out_int.csv

gender,age,marriage,pref,job,hobby,y
0,22,1,45,7,7,yes
1,20,1,4,3,6,no
0,64,1,18,6,7,no
0,44,0,29,1,4,no
0,69,0,18,2,5,no
0,49,1,20,7,1,no
0,40,1,41,7,4,no

out_label.csv

性別,年齢,婚姻,都道府県,職業,趣味,y
女性,22,既婚,宮崎県,無職,テレビゲーム,yes
男性,20,既婚,宮城県,自営業,オートバイ,no
女性,64,既婚,福井県,専業主婦,テレビゲーム,no
女性,44,未婚,奈良県,会社員,アウトドア,no
女性,69,未婚,福井県,会社役員,自動車,no
女性,49,既婚,長野県,無職,読書,no
女性,40,既婚,佐賀県,無職,アウトドア,no

ためしに決定木で分類させてみると

import pandas as pd
import numpy as np
import pydotplus
from sklearn.tree import export_graphviz
from sklearn.tree import DecisionTreeClassifier as DT
from IPython.display import Image
train = pd.read_csv("csv/out_label.csv",delimiter=",")
y = train["y"]
trainx = train.iloc[:,0:6]
trainxd = pd.get_dummies(trainx)
clf3 = DT(max_depth=20, min_samples_leaf=500)
clf3.fit(trainxd,y)
export_graphviz(clf3, out_file="tree_clf3.dot", feature_names=trainxd.columns, class_names=["0","1"], filled=True, rounded=True)
graph = pydotplus.graphviz.graph_from_dot_file('tree_clf3.dot')
Image(graph.create_png())

image.png

まとめ

モデルの動作を確認する際は、
モデルに読み込ませるデータセットの特徴(答え)をあらかじめ把握しておき、
「その答えの通りに分類できたか?」を確認すると分かりやすい。
分類できていれば、そのモデルの実装は正しいと判断できるし、
分類できていなければ、実装が悪いのか、パラメータが適切ではないのか、など原因を探っていく。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0