6
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

「Elixirで機械学習に初挑戦」をやってみた(後編)

Last updated at Posted at 2023-04-03

はじめに

@piacerex さんの「Elixirで機械学習に初挑戦」シリーズを自分なりにやってみるシリーズです

これまでの記事はこちら

  • 前編:

  • 中編:

今回も Livebook を使います

Elixirで機械学習に初挑戦⑤:データ処理に強いElixirでKaggle挑戦(後編)…「統計」と「EDA」で Kaggleに挑む

今回は「統計」と「EDA(探索的データ分析:Exploratory Data Analysis)」を使って前処理を改善し、Kaggleの順位を上げていきます

実装したノートブックはこちら

セットアップ

前回の Explorer 版を土台として改良を加えていきます

インストールするモジュールは同じです

Mix.install([
  {:exla, "~> 0.5"},
  {:axon, "~> 0.5"},
  {:kino, "~> 0.9"},
  {:kino_vega_lite, "~> 0.1"},
  {:explorer, "~> 0.5"}
])

エイリアスをつけておきます

alias Explorer.DataFrame
alias Explorer.Series
require Explorer.DataFrame

データのアップロード

前回 Kaggle からダウンロードしたタイタニックのデータを Livebook にアップロードして使います

Kino.Input.file を使ってファイル選択の UI を作ります

train_data_input = Kino.Input.file("train data")

ファイル選択で「train.csv」を選択し、アップロードします

同じく、「test.csv」をアップロードします

test_data_input = Kino.Input.file("test data")

学習データを読み込みます

train_data =
  train_data_input
  |> Kino.Input.read()
  |> Map.get(:file_ref)
  |> Kino.Input.file_path()
  |> DataFrame.from_csv!()

Kino.DataTable.new(train_data)

スクリーンショット 2023-04-02 7.41.37.png

評価データを読み込みます

test_data =
  test_data_input
  |> Kino.Input.read()
  |> Map.get(:file_ref)
  |> Kino.Input.file_path()
  |> DataFrame.from_csv!()

Kino.DataTable.new(test_data)

スクリーンショット 2023-04-02 7.44.12.png

全体を見渡したいので、学習データと評価データを合わせた全体データを作ります

「Survived」は評価データに存在しないので取り除きます

full_data =
  train_data
  |> DataFrame.discard("Survived")
  |> DataFrame.concat_rows(test_data)

Kino.DataTable.new(full_data)

スクリーンショット 2023-04-02 7.44.48.png

欠損値補完

前回は適当な値や平均値で補完していましたが、改めてどういう値で補完するのがいいか検討してみましょう

年齢

まず何件欠損しているのか確認します

Series.nil_count(full_data["Age"])

結果は 263 でした

それなりの割合で欠損しているので、この値の補完は重要そうです

どういう値で補完すべきか検討するため、年齢に関する平均値、中央値、最頻値を取得してみます

最頻値はまだ Explorer に実装されていない([現在 Pull Request がマージされていない状態]
(https://github.com/elixir-nx/explorer/pull/453))なので、自前で実装します

get_mode = fn series ->
  series
  |> Series.frequencies()
  |> DataFrame.filter(is_not_nil(values))
  |> then(& &1["values"])
  |> Series.first()
end

今後も何回か使うので、平均値、中央値、最頻値を取得して表示する関数を用意しておきます

get_statistics = fn series ->
  %{
    "平均値": Series.mean(series),
    "中央値": Series.median(series),
    "最頻値": get_mode.(series)
  }
end

では年齢について適用してみましょう

get_statistics.(full_data["Age"])

結果は以下のようになります

%{中央値: 28.0, 平均値: 29.881137667304014, 最頻値: 24.0}

中央値も平均値もおよそ同じですが、ヒストグラム(値の範囲毎に何件データがあるかを表すグラフ)を見て、もう少し検討してみましょう

表示対象がカテゴリーの場合は単なる棒グラフにしています

histgram = fn df, colname ->
  value_list = Series.to_list(df[colname])

  unique_count = value_list |> Enum.uniq() |> Enum.count()

  {x_type, bin} =
    if unique_count > 50 do
      {:quantitative, %{maxbins: 50}}
    else
      {:nominal, nil}
    end

  VegaLite.new(width: 600, height: 300)
  |> VegaLite.data_from_values(value: value_list)
  |> VegaLite.mark(:bar, tooltip: true)
  |> VegaLite.encode_field(:x, "value", type: x_type, bin: bin, title: colname)
  |> VegaLite.encode_field(:y, "value", type: :quantitative, aggregate: :count)
end

年齢のヒストグラムは以下のようになります

histgram.(full_data, "Age")

age_hist.png

これを見ると、明らからに頂点が2つあることが分かります

つまり、大人たちと、子どもたちです

単純に全体の平均値や中央値を使うのではなく、それぞれグループ毎の平均値や中央値を使った方が精度は高くなりそうです

では10歳未満のデータを見てみましょう

full_data
|> DataFrame.filter(col("Age") < 10)
|> Kino.DataTable.new()

スクリーンショット 2023-04-02 8.23.43.png

これをパッとみてわかることが2つあります

  • 敬称が「Master」か「Miss」しかない
  • Parch のほとんどが 1 以上

大人の場合は「Mr.」や「Mrs.」がありますが、子どもの敬称は「Master」と「Miss」の2種類だけなので、これで子どもかどうか判断できそうです

また、子どもで1人で乗船しているケースは極稀なので、Parch(親、もしくは子ども)はほとんど 1 以上になります

この仮説を元に、乗客を「多分子ども」と「多分大人」に分けます

full_data =
  full_data
  |> DataFrame.mutate(
    prob_child:
      col("Name") |> contains("Master")
      or col("Name") |> contains("Miss")
      and col("Parch") > 0
  )
  |> DataFrame.mutate(prob_adult: not prob_child)

Kino.DataTable.new(full_data)

スクリーンショット 2023-04-02 8.29.27.png

各グループの年齢を見てみましょう

full_data
|> DataFrame.filter(prob_child)
|> DataFrame.pull("Age")
|> get_statistics.()

「多分子ども」の値は以下のようになりました

%{中央値: 7.0, 平均値: 9.427674418604651, 最頻値: 2.0}

明らかに低いですね

逆に「多分大人」を見てみましょう

full_data
|> DataFrame.filter(prob_adult)
|> DataFrame.pull("Age")
|> get_statistics.()
%{中央値: 30.0, 平均値: 32.75845147219193, 最頻値: 24.0}

子どもたちが分離され、平均値や中央値が上がっていますね

ヒストグラムも見てみましょう

full_data
|> DataFrame.filter(prob_child)
|> histgram.("Age")

child_hist.png

full_data
|> DataFrame.filter(prob_adult)
|> histgram.("Age")

adult_hist.png

それぞれ、子どもではない人や大人ではない人が少し入っていますが、おおよそは問題なさそうです

どちらも左側(若い方)に偏りが見られるため、平均値よりは中央値を使った方が良さそうですね

料金

料金の欠損を見てみましょう

Series.nil_count(full_data["Fare"])

結果は「1」で、補完しても1件しか改善できないことが分かります

ともあれ統計量とヒストグラムを見てみましょう

get_statistics.(full_data["Fare"])

統計量は以下のようになります

%{中央値: 14.4542, 平均値: 33.29547928134557, 最頻値: 8.05}

中央値と平均値、最頻値がかなり違いますね

histgram.(full_data, "Fare")

fare_hist.png

ヒストグラムを見ると、ほとんどの乗客は安い料金で、極一部にものすごく高い料金の乗客がいる、というのが分かります

というわけで中央値で補完しましょう

搭乗港

Series.nil_count(full_data["Embarked"])

欠損数は 2 で、これもあまり影響しなさそうです

搭乗港は文字列なので、単純に棒グラフで見てみます

histgram.(full_data, "Embarked")

embarked_hist.png

圧倒的に「S」が多いので、「S」で補完してしまいましょう

客室番号

Series.nil_count(full_data["Cabin"])

欠損数が 1014 でほとんど欠損しています

これは使えそうにないので、補完せず、予測にも使わないことにします

生存との相関

どの項目が生存に強く関係しているか見てみましょう

生存率

そもそも全体での生存率はどの程度でしょうか

histgram.(train_data, "Survived")

surviced_hist.png

半数以上生き残れなかったようです

survived_counts =
  train_data["Survived"]
  |> Series.frequencies()
  |> DataFrame.arrange(values)
  |> DataFrame.to_columns()
  |> Map.get("counts")

survived_rate = Enum.at(survived_counts, 1) / Enum.sum(survived_counts)

生存率は 0.3838383838383838 でした

チケット階級

チケット階級が上であるほど生存率は高そうな気がします

積み上げ棒グラフで確認してみましょう

color_histgram = fn df, colname, color_colname ->
  value_list = Series.to_list(df[colname])
  color_list = Series.to_list(df[color_colname])

  unique_count = value_list |> Enum.uniq() |> Enum.count()

  {x_type, bin} =
    if unique_count > 20 do
      {:quantitative, %{maxbins: 20}}
    else
      {:nominal, nil}
    end

  VegaLite.new(width: 600, height: 300)
  |> VegaLite.data_from_values(value: value_list, color: color_list)
  |> VegaLite.mark(:bar, tooltip: true)
  |> VegaLite.encode_field(:x, "value", type: x_type, bin: bin, title: colname)
  |> VegaLite.encode_field(:y, "value", type: :quantitative, aggregate: :count)
  |> VegaLite.encode_field(:color, "color", type: :nominal)
end
color_histgram.(train_data, "Pclass", "Survived")

pclass_bar.png

直観的にチケット階級が生存率を左右することが分かりますね

チケット階級と生存の組み合わせ毎に何件なのか、表にしてみます

cross_table =
  train_data
  |> DataFrame.group_by(["Pclass", "Survived"])
  |> DataFrame.summarise(count: count(col("Survived")))
  |> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
  |> DataFrame.arrange(col("Pclass"))

Kino.DataTable.new(cross_table)

スクリーンショット 2023-04-02 8.59.34.png

このような2つの項目の組み合わせ毎の件数表をクロス集計表と言います

1 と 2 は生存者の方が多く、 3 は死者の方が多いです

もっと分かりやすく生存率をつけてみます

cross_table
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") +  col("Survived_1")))
|> Kino.DataTable.new()

スクリーンショット 2023-04-02 9.01.06.png

思った以上に明白な差が出ていますね

チケット階級は生存率に大きく影響しています

料金

料金も同じように影響していそうです

color_histgram.(train_data, "Fare", "Survived")

fare_suv_hist.png

明確に料金が低い人の生存率が低いです

料金を 50 ドル単位でグループ化してクロス集計表を見てみましょう

train_data
|> DataFrame.filter(is_not_nil(col("Fare")))
|> DataFrame.mutate(fare_group: col("Fare") / 50 |> floor() |> cast(:integer))
|> DataFrame.group_by([:fare_group, "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.arrange(fare_group)
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") +  col("Survived_1")))
|> Kino.DataTable.new()

スクリーンショット 2023-04-02 10.07.51.png

料金を 50 ドル単位にグループ分けして学習してみましょう

性別

color_histgram.(train_data, "Sex", "Survived")

sex_suv_hist.png

性別も明確に差が出ますね

train_data
|> DataFrame.group_by(["Sex", "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") +  col("Survived_1")))
|> Kino.DataTable.new()

スクリーンショット 2023-04-02 10.14.46.png

男女で生存率に圧倒的な差があるので、これは使えそうです

年齢

年齢は 10 歳毎の年齢層で見てみましょう

train_data
|> DataFrame.filter(is_not_nil(col("Age")))
|> DataFrame.mutate(age_group: col("Age") / 10 |> floor() |> cast(:integer))
|> color_histgram.("age_group", "Survived")

age_suv_hist.png

子どもの生存率が高く、お年寄りの生存率は低そうです

train_data
|> DataFrame.filter(is_not_nil(col("Age")))
|> DataFrame.mutate(age_group: col("Age") / 10 |> floor() |> cast(:integer))
|> DataFrame.group_by([:age_group, "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.arrange(age_group)
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") +  col("Survived_1")))
|> Kino.DataTable.new()

スクリーンショット 2023-04-02 10.19.21.png

年齢層を使ってみましょう

搭乗港

どの港から乗ってきたか、なんてことが生存率に関係しているのでしょうか

color_histgram.(train_data, "Embarked", "Survived")

emb_suv_hist.png

結構相関がありそうです

train_data
|> DataFrame.group_by(["Embarked", "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") +  col("Survived_1")))
|> Kino.DataTable.new()

スクリーンショット 2023-04-02 23.13.09.png

シェルブールからの乗客は生存率が高いようです

逆にサウサンプトンは全体の生存率よりも低くなっています

なぜこのような差が出るのでしょうか

搭乗港毎の平均料金を見てみましょう

train_data
|> DataFrame.group_by("Embarked")
|> DataFrame.summarise(mean: col("Fare") |> mean())
|> Kino.DataTable.new()

スクリーンショット 2023-04-02 23.15.51.png

シェルブールの料金が高くなっています

チケット階級はどうでしょう

color_histgram.(full_data, "Embarked", "Pclass")

emb_pcl_hist.png

シェルブールは半分以上1級で、クイーンズタウンはほとんど3級です

これだけ見ると、シェルブールの生存率が高いのはお金持ちが多かったため、と推測できます

しかし、クイーンズタウンとサウサンプトンの生存率は料金やチケット階級から逆転しています

イングランド、アイルランド、フランスの国民性なども影響しているのでしょうか

ともあれ搭乗港も予測に使いましょう

家族、同乗者

同乗していた兄弟姉妹、配偶者の数を見てみます

color_histgram.(train_data, "SibSp", "Survived")

sib_suv_hist.png

兄弟姉妹、配偶者がいない人よりは1人いる人の方が生存率が高いようです

しかし、2人以上いる場合は生存率が低くなっています

親子の場合も見てみましょう

color_histgram.(train_data, "Parch", "Survived")

par_suv_hist.png

こちらも同じような傾向です

親がいる子どもの方が生存率が高いし、兄弟で乗船しているケースの多くは子どもだからでしょう

ただし、子どもが多すぎると全員を助けようとして上手くいかないケースが出てくるのでしょうか

しかし、もう少し進めて考えてみると、家族以外で一緒に乗船しているケースも考えられます

同じチケット番号の乗客は同乗者として集計してみます

followers_df =
  full_data["Ticket"]
  |> Series.frequencies()
  |> DataFrame.rename(["Ticket", "followers"])
  |> DataFrame.mutate(followers: followers - 1)

Kino.DataTable.new(followers_df)

スクリーンショット 2023-04-03 0.10.13.png

チケット番号「LINE」というのは正しい番号ではなさそうなので、これは除いて結合します

train_data =
  train_data
  |> DataFrame.join(DataFrame.filter(followers_df, col("Ticket") != "LINE"), how: :left)
  |> then(&DataFrame.put(&1, :followers, Series.fill_missing(&1["followers"], 0)))

Kino.DataTable.new(train_data)

スクリーンショット 2023-04-03 0.13.20.png

これで同乗者数が分かるようになりました

家族が同じ船にいるのに違うチケットで乗っている人はいるでしょうか

train_data
|> DataFrame.filter(followers == 0 and col("SibSp") > 0 and col("Parch") > 0)
|> Kino.DataTable.new()

スクリーンショット 2023-04-03 0.16.15.png

5人いましたが、うち4人は死亡しています

年齢を見ると、10代後半から20代前半です

ある程度成長して、船でも家族とは別の個室を使っていたのでしょう

では同乗者数と生存の関係を見てみます

color_histgram.(train_data, "followers", "Survived")

fol_suv_hist.png

train_data
|> DataFrame.group_by(["followers", "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.arrange(col("followers"))
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") +  col("Survived_1")))
|> Kino.DataTable.new()

スクリーンショット 2023-04-03 0.22.39.png

同乗者数 0 は生存率が低く、 1 から 3 は生存率が高く、 4 以上でまた下がっています

家族ではなく同乗者数の方を学習してみましょう

前処理モジュール定義

ここまでの検討を元にして前処理モジュールを定義します

defmodule PreProcess do
  def load_csv(kino_input) do
    kino_input
    |> Kino.Input.read()
    |> Map.get(:file_ref)
    |> Kino.Input.file_path()
    |> DataFrame.from_csv!()
  end

  def fill_empty(data, fill_map) do
    fill_map
    |> Enum.reduce(data, fn {column_name, fill_value}, acc ->

      fill_value =
        if fill_value == :median do
          Series.median(data[column_name])
        else
          fill_value
        end

      DataFrame.put(
        acc,
        column_name,
        Series.fill_missing(data[column_name], fill_value)
      )
    end)
  end

  def replace_dummy(data, columns_names) do
    data
    |> DataFrame.dummies(columns_names)
    |> DataFrame.concat_columns(DataFrame.discard(data, columns_names))
  end

  def to_tensor(data) do
    data
    |> DataFrame.to_columns()
    |> Map.values()
    |> Nx.tensor(backend: EXLA.Backend)
    |> Nx.transpose()
    |> Nx.to_batched(1)
    |> Enum.to_list()
  end

  def process(kino_input, id_key, label_key, followers_df) do
    data_org = load_csv(kino_input)

    id_list = Series.to_list(data_org[id_key])

    has_label_key =
      data_org
      |> DataFrame.names()
      |> Enum.member?(label_key)

    label_list =
      if has_label_key do
        data_org[label_key]
        |> Series.to_tensor(backend: EXLA.Backend)
        |> Nx.as_type(:f32)
        |> Nx.new_axis(1)
        |> Nx.to_batched(1)
        |> Enum.to_list()
      else
        nil
      end

    inputs =
      if has_label_key do
        DataFrame.discard(data_org, [id_key, label_key])
      else
        DataFrame.discard(data_org, [id_key])
      end
      |> DataFrame.mutate(
        prob_child:
          col("Name") |> contains("Master")
          or col("Name") |> contains("Miss")
          and col("Parch") > 0
      )

    filled_age =
      [
        Series.to_list(inputs["Age"]),
        Series.to_list(inputs["prob_child"]),
      ]
      |> Enum.zip()
      |> Enum.map(fn
        {nil, true} ->
          9
        {nil, false} ->
          30
        {age, _prob_child} ->
          age
      end)
      |> Series.from_list()

    inputs =
      inputs
      |> DataFrame.put("Age", filled_age)
      |> DataFrame.join(followers_df, how: :left)
      |> fill_empty(%{"followers" => 0, "Embarked" => "S", "Fare" => :median})
      |> replace_dummy(["Embarked", "Pclass"])
      |> DataFrame.mutate(is_man: col("Sex") == "male")
      |> DataFrame.mutate(fare_group: col("Fare") / 50 |> floor())
      |> DataFrame.mutate(age_group: col("Age") / 10 |> floor())
      |> DataFrame.discard(["Cabin", "Name", "Ticket", "Sex", "Fare", "Age", "SibSp", "Parch"])
      |> to_tensor()

    {id_list, label_list, inputs}
  end
end

学習

あとは前回と同じです

{
  train_id_list,
  train_label_list,
  train_inputs
} = PreProcess.process(train_data_input, "PassengerId", "Survived", followers_df)
{
  test_id_list,
  test_label_list,
  test_inputs
} = PreProcess.process(test_data_input, "PassengerId", "Survived", followers_df)
model =
  Axon.input("input", shape: {nil, 11})
  |> Axon.dense(48, activation: :tanh)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(48, activation: :tanh)
  |> Axon.dense(1, activation: :sigmoid)
train_data = Enum.zip(train_inputs, train_label_list)
loss_plot =
  VegaLite.new(width: 300)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "step", type: :quantitative)
  |> VegaLite.encode_field(:y, "loss", type: :quantitative)
  |> Kino.VegaLite.new()

acc_plot =
  VegaLite.new(width: 300)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "step", type: :quantitative)
  |> VegaLite.encode_field(:y, "accuracy", type: :quantitative)
  |> Kino.VegaLite.new()

Kino.Layout.grid([loss_plot, acc_plot], columns: 2)
trained_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adam(0.0005))
  |> Axon.Loop.metric(:accuracy, "accuracy")
  |> Axon.Loop.kino_vega_lite_plot(loss_plot, "loss", event: :epoch_completed)
  |> Axon.Loop.kino_vega_lite_plot(acc_plot, "accuracy", event: :epoch_completed)
  |> Axon.Loop.run(train_data, %{}, epochs: 50, compiler: EXLA)

予測結果

results =
  test_inputs
  |> Nx.concatenate()
  |> then(&Axon.predict(model, trained_state, &1))
  |> Nx.to_flat_list()
  |> Enum.map(&round(&1))
  |> then(
    &%{
      "PassengerId" => test_id_list,
      "Survived" => &1
    }
  )
  |> DataFrame.new()

Kino.DataTable.new(results)
results
|> DataFrame.dump_csv!()
|> then(&Kino.Download.new(fn -> &1 end, filename: "result.csv"))

まとめ

色々試してみて、最高は 78.947 でした

順位は 1,065 位で、15,385 チームの上位 7% に入りました

スクリーンショット 2023-04-03 9.48.37.png

まだまだ分析できそうなことはあると思いますが、とりあえずここまで

その他、 Elixir AI・ML で何ができるのか、なぜ Elixir を使うのかについては @piacerex さんの記事を是非参照してください

6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?