Multi-class classification with a Transformer (using Keras)

This is code that performs multi-class classification on the IRIS dataset.
There was little information on Keras-based code in Japanese-language articles, so I'm leaving this here as a note to myself.

When writing the code I referred to the keras.io sample below, but I couldn't manage it entirely on my own, so I used ChatGPT. ChatGPT really is handy...
https://keras.io/examples/nlp/text_classification_with_transformer/

iris.py
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
from tensorflow.keras import layers, Model
from tensorflow.keras.utils import to_categorical


# Define the MultiHeadAttention layer for the Transformer
class MultiHeadAttention(layers.Layer):
    def __init__(self, embed_dim, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0

        self.query_dense = layers.Dense(embed_dim)
        self.key_dense = layers.Dense(embed_dim)
        self.value_dense = layers.Dense(embed_dim)
        self.combine_heads = layers.Dense(embed_dim)

    def attention(self, query, key, value):
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, value)
        return output, weights
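
    # The attention() method above implements scaled dot-product attention:
    #     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    # where d_k is the key dimension; dividing by sqrt(d_k) keeps the logits
    # at a scale where the softmax does not saturate as d_k grows.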

    def separate_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.embed_dim // self.num_heads))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]

        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)

        query = self.separate_heads(query, batch_size)
        key = self.separate_heads(key, batch_size)
        value = self.separate_heads(value, batch_size)

        attention, weights = self.attention(query, key, value)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention, (batch_size, -1, self.embed_dim))
        output = self.combine_heads(concat_attention)
        return output
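
# Note: TensorFlow 2.4+ ships a built-in layers.MultiHeadAttention that could
# replace the hand-rolled class above. A sketch of the swap (not used below):
#     self.att = layers.MultiHeadAttention(num_heads=num_heads,
#                                          key_dim=embed_dim // num_heads)
#     attn_output = self.att(inputs, inputs)  # self-attention: query = value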

# Define the Transformer block
class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):  # rate: dropout rate
        super(TransformerBlock, self).__init__()
        self.att = MultiHeadAttention(embed_dim, num_heads)
        self.ffn = tf.keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=False):
        attn_output = self.att(inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
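
# A quick shape check for the block above (a sketch; run it separately if desired):
#     block = TransformerBlock(embed_dim=64, num_heads=8, ff_dim=64)
#     x = tf.random.normal((2, 4, 64))        # (batch, seq_len, embed_dim)
#     print(block(x, training=False).shape)   # -> (2, 4, 64)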

# Build the model
class TransformerClassifier(Model):
    def __init__(self, num_classes, embed_dim, num_heads, ff_dim, num_layers=2, rate=0.1):
        super(TransformerClassifier, self).__init__()
        self.embedding = layers.Dense(embed_dim)  # applied per feature "token"
        self.transformer_blocks = [TransformerBlock(embed_dim, num_heads, ff_dim, rate) for _ in range(num_layers)]
        self.pool = layers.GlobalAveragePooling1D()
        self.dropout = layers.Dropout(rate)
        self.classifier = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=False):
        # Treat the 4 features as a length-4 sequence of scalar tokens so the
        # attention layers receive the 3-D (batch, seq_len, embed_dim) tensor
        # they expect; a plain 2-D input would make the residual additions
        # broadcast across the batch dimension.
        x = tf.expand_dims(inputs, axis=-1)  # (batch, 4, 1)
        x = self.embedding(x)                # (batch, 4, embed_dim)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training)
        x = self.pool(x)
        x = self.dropout(x, training=training)
        return self.classifier(x)

# Define a custom callback
class CustomModelCheckpoint(tf.keras.callbacks.Callback):
    def __init__(self, save_freq):
        super(CustomModelCheckpoint, self).__init__()
        self.save_freq = save_freq

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.save_freq == 0:
            model_name = f"transformer_iris_epoch_{epoch + 1}"
            self.model.save(model_name, save_format="tf")  # save in TensorFlow SavedModel format
            print(f"\nModel saved as {model_name}\n")


def main():

    # Load the dataset
    iris = load_iris()
    data = iris.data
    target = to_categorical(iris.target)

    # Preprocess the data
    scaler = StandardScaler()
    data = scaler.fit_transform(data)

    # Split the data into training and test sets
    train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.2, random_state=42)


    # Hyperparameters
    embed_dim = 64     # embedding dimension; must be divisible by num_heads
    num_heads = 8      # number of heads in the multi-head attention
    ff_dim = 64        # hidden units in the feed-forward network
    num_classes = 3    # number of target classes
    num_layers = 6     # number of stacked Transformer blocks

    # Specify how often (in epochs) to save the model
    save_freq = 100
    custom_checkpoint = CustomModelCheckpoint(save_freq)

    # Instantiate the model
    model = TransformerClassifier(num_classes, embed_dim, num_heads, ff_dim, num_layers=num_layers)

    # Compile the model
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='categorical_crossentropy', metrics=['accuracy'])


    '''
    # Load a saved model (to resume from a checkpoint instead of training from scratch)
    model_name = "transformer_iris_epoch_100"  # change this to the appropriate model name
    model = tf.keras.models.load_model(model_name)

    # Display the model summary
    model.summary()

    # Compile the loaded model
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                     loss='categorical_crossentropy',
                     metrics=["accuracy"])
    '''

    # Train the model
    history = model.fit(train_x, train_y, batch_size=12, epochs=100, validation_split=0.2, callbacks=[custom_checkpoint])

    # Evaluate the model
    loss, accuracy = model.evaluate(test_x, test_y)
    print('Test loss:', loss)
    print('Test accuracy:', accuracy)

if __name__ == '__main__':
    main()
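
As a quick sanity check, here is a minimal inference sketch of mine (not part of the original script). It assumes import numpy as np is added to the imports at the top and that these lines are appended to the end of main(), where the fitted scaler and trained model are in scope:

    # Classify one new measurement; it must be standardized with the same
    # StandardScaler that was fit on the training data.
    sample = scaler.transform(np.array([[5.1, 3.5, 1.4, 0.2]]))
    probs = model.predict(sample)
    print("Predicted class:", iris.target_names[np.argmax(probs)])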
