0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

TensorFlowで複数のモデルを繋げてそれぞれから出力を取る話

Last updated at Posted at 2023-04-20

1. はじめに

※この記事は2年前にnoteに投稿したものと同じ内容です。

この記事を書くにあたっては、kerasの開発者であるFrançois Chollet氏がTwitterに投稿しているコードを大いに参考にしました。
https://twitter.com/fchollet

さて、GANを構築するときのように、1つのモデルの出力を別のモデルの入力に繋げて、出力結果もそれぞれで取りたい状況というのはしばしば発生します。このようなモデルを記述したいときには、全体のモデルをクラスとして宣言して、その中のdefに複数のモデルを記述する方法があります。

2. データセットの作成

まずは、今回使用するデータセットを作ることから始めます。必要なパッケージも最初にimportしておきましょう。都合上ディレクトリに保存するようにしますが、出力画像のプレビューも載せていくので、ディレクトリ汚したくない方はそちらを見ていっていただけたらと思います。

    from PIL import Image, ImageDraw, ImageFont
    import string
    import pathlib
    import tensorflow as tf
    import numpy as np

    d = []

    def noise(x):
       y = np.where(x < 100, x + 100, x) -np.random.randint(0, 100, 64*64).reshape(64,64)
       return y

    for seq in range(100000):
       font = ImageFont.truetype('times.ttf', 20)
       im_str = Image.new("L",(64, 64),0)
       draw_str = ImageDraw.Draw(im_str)

       c = np.random.randint(1,5)
   
       for i in range(c):
           sym = np.random.choice(list(string.ascii_letters))
           xy = np.random.randint(0, 20, 2)
           draw_str.text(xy, sym, font=font, fill=255)
           im_str = Image.fromarray(np.rot90(im_str))
           draw_str = ImageDraw.Draw(im_str)
           if i==(c-1):
               d.append(i)
       im_str.save('sep_2/label/label/'+'{:0=10}'.format(seq)+'.png')
       im_str = np.uint8(noise(np.array(im_str)))
       Image.fromarray(im_str).save('sep_1/im/im/'+'{:0=10}'.format(seq)+'.png')
    np.savetxt('sep_3/label/label/label.csv', np.array(d), fmt='%d')

このプログラムで出力されるのは、1~4文字のアルファベットが書かれた黒背景の画像、それにランダムなノイズを乗せた画像、そして画像の文字数を記録した配列(csvファイルに出力)の3つが10万枚分です。実際のプレビューは以下のようになります。数値は実際の文字数から1を引いた数になっています。

nc_seg2.PNG
nc_seg3.PNG

3. ジェネレータの作成

データセットを作成しましたが、10万枚分のリストを全てメモリに乗せたくはないので、ジェネレータとして取り出していきます。以下のようにすれば、上の3つの要素がタプル化されたジェネレータを作ることができます。

    class DatasetGenerator(tf.keras.utils.Sequence):
       def __init__(self, image_path, label1_path, label2):
           self.image = image_path
           self.label1 = label1_path
           self.label2 = label2
           self.indices = np.arange(100000)

           self.length = 100000
           self.batch_size = 64

       def __getitem__(self, idx):
           idx_shuffle = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]

           image_batch_path = []
           label1_batch_path = []
           label2_batch = []
           for i in idx_shuffle:
               image_batch_path.append(self.image[i])
               label1_batch_path.append(self.label1[i])
               label2_batch.append(self.label2[i])

           image_batch=[]
           for path in image_batch_path:
               image_batch.append(np.array(Image.open(path)))

           label1_batch=[]
           for path in label1_batch_path:
               label1_batch.append(np.array(Image.open(path)))

           image_batch = np.reshape(np.array(image_batch)/255.0,(self.batch_size, 64, 64, 1))
           label1_batch = np.reshape(np.array(label1_batch)/255.0,(self.batch_size, 64, 64, 1))
           label2_batch = np.reshape(tf.keras.utils.to_categorical(label2_batch), (self.batch_size, 4))

           return image_batch, label1_batch, label2_batch

       def __len__(self):
           return self.length//self.batch_size

       def on_epoch_end(self):
           np.random.shuffle(self.indices)

imageとlabel1が画像を読み込むときの処理で、label2がnumpy配列を読み込むときの処理です。これに対してパスを割り当ててやります。ちなみに、tf.keras.utils.to_categorical()は1→[0,1,0,0]や2→[0,0,1,0]のような処理を行います。

    image_root = pathlib.Path('sep_1/im/')
    all_image_paths = list(image_root.glob('*/*'))
    all_image_paths = [str(path) for path in all_image_paths]

    label1_root = pathlib.Path('sep_2/label/')
    all_label1_paths = list(label1_root.glob('*/*'))
    all_label1_paths = [str(path) for path in all_label1_paths]

    label2_path = np.loadtxt('sep_3/label/label/label.csv')

これでジェネレータの準備が終わりました。

4. モデルの構築

今回行いたいのは、ノイズ有りの画像からノイズを取り除き、書かれた文字数をカウントするという処理です。そのため、画像からノイズを取り除くモデルと、書かれた文字をカウントするモデルの2つを用意したいです。
実際のところ、これらのモデルを連結するよりも、片方ずつ学習させたほうが効率が良いのですが、あくまで例として捉えていただきたいです。

では、コードから見ていきましょう。まず、__init__で2つのモデルを宣言しておきます。そして、compileをクラスの中に入れ込むことで、モデルの数だけoptimizerとロス関数を入れ込んでcompileすることができます。compile()を入れ込むことによってfit()で学習できるようになるため、fit()の便利な機能を活用することができるようになります。
metricsプロパティは正解率を正しく呼ぶために必要です。

そして、以下でモデルを記述することになります。今回は、ノイズを除去するnoise_clean()と文字数をカウントするsegmentation()の2つのモデルを入れ込みます。ここで注意すべき点は、inputレイヤーが必要であることと、返り値としてtf.keras.models.Model(inputs, outputs)を設定しなければならないことです。

train_stepでも注意点があって、tf.GradientTape()で微分を行う場合には、withの中身で微分する全ての値を与える必要があるということです。この場合、1つ目のモデルの出力が2つ目のモデルの入力になるため、1つ目のモデルの出力を2つ目のモデルに対応するtf.GradientTape()の中で出さないといけないということです。また、返り値としてロスや正解率を書き込むのもここになります。

test_stepはtrain_stepから微分を取り除いたものです。ここでもロスと正解率を返り値として書き込みます。

tf.keras.Modelクラスでは、必ずcallを書き込む必要があります。ここで予測に使う出力を取ることができるため、返り値のリストとして[self.noise_clean(), self.segmentation()]と書き込むことで、それぞれの予測結果をリスト形式で取得することができます。

    class segmentation_noisy_image(tf.keras.Model):
       def __init__(self, **kwargs):
           super().__init__(**kwargs)
           self.noise_clean = self.create_noise_clean()
           self.segmentation = self.create_segmentation()
           self.accuracy_nc = tf.keras.metrics.BinaryAccuracy(name='nc_loss')
           self.accuracy_seg = tf.keras.metrics.CategoricalAccuracy(name='seg_loss')

       def compile(self, nc_optimizer, seg_optimizer):
           super().compile()
           self.nc_optimizer = nc_optimizer
           self.seg_optimizer = seg_optimizer

       @property
       def metrics(self):
           return [self.accuracy_nc, self.accuracy_seg]

       def create_noise_clean(self):
           im_input1 = tf.keras.layers.Input(shape=(64, 64, 1))
           conv1 = tf.keras.layers.Conv2D(16, (5, 5), padding='same')(im_input1)
           conv1 = tf.keras.layers.BatchNormalization()(conv1)
           conv1 = tf.keras.layers.Activation('relu')(conv1)

           conv2 = tf.keras.layers.Conv2D(32, (5, 5),padding='same')(conv1)
           conv2 = tf.keras.layers.BatchNormalization()(conv2)
           conv2 = tf.keras.layers.Activation('relu')(conv2)

           train_out1 = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', padding='same')(conv2)
           return tf.keras.models.Model(im_input1, train_out1)

       def create_segmentation(self):
           im_input2 = tf.keras.layers.Input(shape=(64, 64, 1))
           conv4 = tf.keras.layers.Conv2D(16, (5, 5), padding='same')(im_input2)
           conv4 = tf.keras.layers.BatchNormalization()(conv4)
           conv4 = tf.keras.layers.Activation('relu')(conv4)
       
           dense = tf.keras.layers.Dropout(0.25)(conv4)
           dense = tf.keras.layers.Flatten()(dense)
           dense = tf.keras.layers.Dense(256, activation='relu')(dense)
           dense = tf.keras.layers.Dropout(0.5)(dense)
           train_out2 = tf.keras.layers.Dense(4, activation='softmax')(dense)
           return tf.keras.models.Model(im_input2, train_out2)

       def train_step(self, gen_image_label):
           images, labels1, labels2 = gen_image_label
       
           with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
               predictions_nc = self.noise_clean(images)
               loss_nc = tf.keras.losses.binary_crossentropy(labels1, predictions_nc, from_logits=False)
           
               gradients_nc = tape1.gradient(loss_nc, self.noise_clean.trainable_variables)
               self.nc_optimizer.apply_gradients(zip(gradients_nc, self.noise_clean.trainable_variables))
               self.accuracy_nc.update_state(labels1, predictions_nc)
           
               predictions_seg = self.segmentation(predictions_nc)
               loss_seg = tf.keras.losses.categorical_crossentropy(labels2, predictions_seg, from_logits=False)

               gradients_seg = tape2.gradient(loss_seg, self.segmentation.trainable_variables)
               self.seg_optimizer.apply_gradients(zip(gradients_seg, self.segmentation.trainable_variables))
               self.accuracy_seg.update_state(labels2, predictions_seg)

           return {'nc_loss': loss_nc, 'nc_accuracy': self.accuracy_nc.result(),'seg_loss': loss_seg, 'seg_accuracy': self.accuracy_seg.result()}

       def test_step(self, gen_image_label):
           images, labels1, labels2 = gen_image_label
       
           predictions_nc = self.noise_clean(images)
           loss_nc = tf.keras.losses.binary_crossentropy(labels1, predictions_nc, from_logits=False)
           self.accuracy_nc.update_state(labels1, predictions_nc)
       
           predictions_seg = self.segmentation(predictions_nc)
           loss_seg = tf.keras.losses.categorical_crossentropy(labels2, predictions_seg, from_logits=False)
           self.accuracy_seg.update_state(labels2, predictions_seg)

           return {'nc_loss': loss_nc, 'nc_accuracy': self.accuracy_nc.result(),'seg_loss': loss_seg, 'seg_accuracy': self.accuracy_seg.result()}

       def call(self, z):
           w = self.noise_clean(z)
           return [self.noise_clean(z), self.segmentation(w)]

このモデルをコンパイルしてfitします。

    segmentation_noisy_image = segmentation_noisy_image()
    segmentation_noisy_image.compile(
       nc_optimizer = tf.keras.optimizers.Adam(lr=0.0001, epsilon=1e-06),
       seg_optimizer = tf.keras.optimizers.Adam(lr=0.0001, epsilon=1e-06))

    segmentation_noisy_image.fit(
       DatasetGenerator(all_image_paths, all_label1_paths, label2_path), epochs=5,
       validation_data=DatasetGenerator(all_image_paths, all_label1_paths, label2_path),
       shuffle=False, steps_per_epoch=90000//64, validation_steps=10000//64)

5. 学習結果と予測

このモデルをCPUのみで5エポック学習させた結果が次になります。

    Epoch 1/5
    1406/1406 [==============================] - 399s 283ms/step 
    - nc_loss: 0.0298 - nc_accuracy: 0.9470 - seg_loss: 0.0040 - seg_accuracy: 0.5281 
    - val_nc_loss: 0.0298 - val_nc_accuracy: 0.9647 - val_seg_loss: 0.0043 - val_seg_accuracy: 0.9984
    Epoch 2/5
    1406/1406 [==============================] - 390s 278ms/step 
    - nc_loss: 0.0246 - nc_accuracy: 0.9649 - seg_loss: 0.0090 - seg_accuracy: 0.9986 
    - val_nc_loss: 0.0250 - val_nc_accuracy: 0.9651 - val_seg_loss: 0.0027 - val_seg_accuracy: 0.9990
    Epoch 3/5
    1406/1406 [==============================] - 396s 282ms/step 
    - nc_loss: 0.0260 - nc_accuracy: 0.9651 - seg_loss: 4.0780e-04 - seg_accuracy: 0.9995 
    - val_nc_loss: 0.0235 - val_nc_accuracy: 0.9651 - val_seg_loss: 2.4709e-04 - val_seg_accuracy: 0.9998
    Epoch 4/5
    1406/1406 [==============================] - 382s 271ms/step 
    - nc_loss: 0.0246 - nc_accuracy: 0.9651 - seg_loss: 6.4611e-05 - seg_accuracy: 0.9995 
    - val_nc_loss: 0.0247 - val_nc_accuracy: 0.9651 - val_seg_loss: 8.6831e-05 - val_seg_accuracy: 0.9994
    Epoch 5/5
    1406/1406 [==============================] - 373s 265ms/step 
    - nc_loss: 0.0223 - nc_accuracy: 0.9651 - seg_loss: 5.9787e-05 - seg_accuracy: 0.9998 
    - val_nc_loss: 0.0203 - val_nc_accuracy: 0.9655 - val_seg_loss: 1.5029e-05 - val_seg_accuracy: 0.9999

正しくlossが減少していることがわかりますね。実際に画像で見てみましょう。
nc_seg4.PNG
予測結果がちゃんとリストで返り、ノイズ除去の結果がpred_im[0]に、文字数カウントの結果がpred_im[1]に格納されていることが確認できました。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?