0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

コンパイル時に機械学習でFizzBuzzを実現する

Posted at

元ネタ

「機械学習でFizzBuzzを実現する」
https://zenn.dev/tokoroten/articles/c311cf6e3fc8ac

便乗しました.

注意点

元ネタの記事ではFizzBuzzを4クラス分類(Fizz, Buzz, FizzBuzz, 数値)としてやっていたので,本記事もそうします.
また,教師データの作成はコンパイル時にはやりません.

データの準備

import numpy as np


def encode_label(index):
    e = np.zeros((4,))
    e[index] = 1.0
    return e


def encode_fizzbuzz(num):
    if num % 15 == 0:
        return encode_label(0)
    if num % 5 == 0:
        return encode_label(1)
    if num % 3 == 0:
        return encode_label(2)
    return encode_label(3)


def preprocess_num_to_array(num):
    return np.array([num % i for i in range(1,20)])


def load_fizzbuzz():
    x_train = np.random.randint(10000, size=1000)
    y_train = np.frompyfunc(encode_fizzbuzz, 1, 1)(x_train)
    x_test = np.random.randint(10000, size=30)
    y_test = np.frompyfunc(encode_fizzbuzz, 1, 1)(x_test)
    original_x_train = np.copy(x_train)
    original_x_test = np.copy(x_test)
    x_train = np.frompyfunc(preprocess_num_to_array, 1, 1)(x_train)
    x_test = np.frompyfunc(preprocess_num_to_array, 1, 1)(x_test)

    return (x_train, y_train), (x_test, y_test), (original_x_train, original_x_test)


def write_mini_data(data, name, size):
    mini_data = [(x, y) for x, y in zip(data[0], data[1])]
    with open(f'x_{name}_fizzbuzz', 'w') as f, open(f'y_{name}_fizzbuzz', 'w') as g:
        for i, (x, y) in enumerate(mini_data[:size]):
            f.write('{')
            for j_x, x_e in enumerate(x):
                f.write(str(x_e))
                if j_x + 1 != len(x):
                    f.write(',')
            f.write('}')
            if i + 1 != size:
                f.write(',')
            g.write('{')
            for j_y, y_e in enumerate(y):
                g.write(str(y_e))
                if j_y + 1 != len(y):
                    g.write(',')
            g.write('}')
            if i + 1 != size:
                g.write(',')


def write_orig_data(data, name):
    with open(f'{name}_fizzbuzz_orig', 'w') as f:
        for j_X, x_e in enumerate(data):
            f.write(str(x_e))
            if j_X + 1 != len(data):
                f.write(',')


def get_mini_fizzbuzz():
    train, test, orignal = load_fizzbuzz()
    write_mini_data(train, 'train', len(train[0]))
    write_mini_data(test, 'test', len(test[0]))
    write_orig_data(orignal[0], 'train')
    write_orig_data(orignal[1], 'test')


if __name__ == "__main__":
    get_mini_fizzbuzz()

こんな感じでtrain/testデータを作成します.特徴量は元ネタの記事と同じく1~19による剰余です.
trainデータが1000個,testデータが30個と物足りないですが,その辺はマシンとコンパイラにご相談してください.

学習

コンパイル時計算NNが簡単に書けるSCENNを使います.
リポジトリ:https://github.com/Catminusminus/scenn

データを読み込んで

#ifndef SCENN_LOAD_MINI_FIZZBUZZ_HPP
#define SCENN_LOAD_MINI_FIZZBUZZ_HPP

#include <scenn/dataset.hpp>
#include <scenn/util.hpp>
#include <sprout/array.hpp>

namespace scenn {
template <class NumType>
SCENN_CONSTEXPR auto load_mini_fizzbuzz_data() {
  SCENN_STATIC NumType x_train[1000][19] = {
#include <tools/x_train_fizzbuzz>
  };
  SCENN_STATIC NumType x_test[30][19] = {
#include <tools/x_test_fizzbuzz>
  };
  SCENN_STATIC NumType y_train[1000][4] = {
#include <tools/y_train_fizzbuzz>
  };
  SCENN_STATIC NumType y_test[30][4] = {
#include <tools/y_test_fizzbuzz>
  };
  sprout::array<NumType, 1000> x_train_orig = {
#include <tools/train_fizzbuzz_orig>
  };
  sprout::array<NumType, 30> x_test_orig = {
#include <tools/test_fizzbuzz_orig>
  };
  return std::make_tuple(Dataset(make_matrix_from_array(std::move(x_train)),
                                make_matrix_from_array(std::move(y_train))),
                        Dataset(make_matrix_from_array(std::move(x_test)),
                                make_matrix_from_array(std::move(y_test))),
                        std::pair(x_train_orig, x_test_orig)
                        );
}
}  // namespace scenn

#endif

学習させます.

#include <iostream>
#include <scenn/load/mini_fizzbuzz.hpp>
#include <scenn/scenn.hpp>
#include <string>

template <std::size_t Size, class Model, class Data>
SCENN_CONSTEXPR auto predict(const Model& model, const Data& data) {
  std::array<int, Size> arr = {};
  auto i = 0;
  for (auto&& [x, y] : data.get_data()) {
    arr[i] = model.single_forward(x.transposed()).transposed().argmax();
    ++i;
  }
  return arr;
}

SCENN_CONSTEXPR auto mini_fizzbuzz_test() {
  using namespace scenn;
  auto [train_data, test_data, orig_data] = load_mini_fizzbuzz_data<double>();
  auto trained_model =
      SequentialNetwork(CrossEntropy(), DenseLayer<19, 32, double>(),
                        ActivationLayer<32, double>(ReLU()),
                        DenseLayer<32, 4, double>(),
                        ActivationLayer<4, double>(Softmax()))
          .train<100>(std::move(train_data), 300, 0.05);
  auto evaluation = trained_model.evaluate(test_data);
  return std::make_tuple(orig_data, predict<30>(trained_model, test_data),
                         evaluation);
}

int main() {
  // コンパイル時学習
  SCENN_CONSTEXPR auto predictions = mini_fizzbuzz_test();

  // 出力は実行時
  const auto transform_fizzbuzz = [](auto prediction, auto value) {
    using namespace std::literals::string_literals;
    if (prediction == 0) {
      return "FizzBuzz"s;
    }
    if (prediction == 1) {
      return "Buzz"s;
    }
    if (prediction == 2) {
      return "Fizz"s;
    }
    return std::to_string(value);
  };
  auto [orig_data, predict_values, evaluation] = predictions;
  auto [x_train, x_test] = orig_data;
  for (auto i = 0; i < 30; ++i) {
    std::cout << x_test[i] << ": "
              << transform_fizzbuzz(predict_values[i], x_test[i]) << std::endl;
  }
  std::cout << "Acc: " << static_cast<float>(evaluation) / 30 << std::endl;
}

私が試したところAccは0.9でした.

$ llvm/llvm-project/build/bin/clang++ ./scenn/examples/fizzbuzz.cpp -Wall -Wextra -I$SPROUT_PATH -I$SCENN_PATH -std=gnu++2a -fconstexpr-steps=-1
$ ./a.out
2781: Fizz
7339: 7339.000000
2756: Fizz
6996: Fizz
9109: Fizz
6927: Fizz
2258: 2258.000000
323: 323.000000
4185: 4185.000000
6821: 6821.000000
1981: 1981.000000
4673: 4673.000000
6335: Buzz
4892: 4892.000000
8061: Fizz
7821: Fizz
1644: Fizz
6870: FizzBuzz
2837: 2837.000000
578: 578.000000
4585: Buzz
9054: Fizz
6252: Fizz
8291: 8291.000000
9460: Buzz
9715: Buzz
2829: Fizz
9656: 9656.000000
6702: Fizz
6138: Fizz
Acc: 0.9

これで顧客が「コンパイル時にディープラーニングで予測したとプレスリリースを書きたいので、ディープラーニングを使って作ってください!!」と要望してきても以下略

コンパイル時計算の知見

実行時計算にしてデバッグする

なんか精度が出ない,こんな時にコンパイル時計算のままデバッグするのはキツイです.
実行時計算にしてしまえば簡単にデバッグできるのでした方がよいと思います.

マクロによりコンパイル時計算と実行時計算の切り替えを楽にする

上記のように,コンパイル時計算と実行時計算を切り替えるとき,マクロで一行コメントアウトするだけ,みたいにしておくと楽です.

おわりに

gg使うと爆速コンパイル時計算できるんじゃないかと思っているのですが誰かやってください.

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?