@git-tacet

Are you sure you want to delete the question?

Leaving a resolved question undeleted may help others!

MusicVAE modelのTypeErrorについて

Colab上で、MagentaのMusicVAEを利用して、ある一定の傾向を持っているMIDIファイルをもとにモデルを生成し、MIDIファイルを生成したいと思っています。

当方ほんのちょっぴりかじった程度の知識しかなく、ChatGPTやQiitaを参考に見様見真似で記述を進めていたのですが、行き詰ってしまったので、助言をいただきたいです。

発生している問題・エラー

TypeError                                 Traceback (most recent call last)
<ipython-input-16-f07dbcf8372f> in <module>
     23 
     24 # Train the MusicVAE model
---> 25 model = MusicVAE(input_shape=train_data.shape[1:], z_size=256)
     26 
     27 

TypeError: __init__() got an unexpected keyword argument 'input_shape'

該当するソースコード

from google.colab import drive
drive.mount('/content/drive')

!pip install magenta
!pip install --upgrade tensorflow
!pip install --upgrade music21
!pip install --upgrade pretty_midi
!pip install --upgrade keras

import magenta
import pretty_midi
import numpy as np
import os
from magenta.models.music_vae import MusicVAE
import tensorflow as tf

def preprocess_midi(midi_path):
  midi_data = pretty_midi.PrettyMIDI(midi_path)
  piano_roll = midi_data.get_piano_roll(fs=50)
  piano_roll = (piano_roll - piano_roll.mean()) / piano_roll.std()
  return piano_roll

midi_data = []
max_len = 0
midi_data_path = '/content/drive/MyDrive/twelve_tone'
midi_files = os.listdir(midi_data_path)
for file in midi_files:
  midi = preprocess_midi(os.path.join(midi_data_path, file))
  max_len = max(max_len, midi.shape[1])
midi_data = np.zeros((len(midi_files), 128, max_len))
for i, file in enumerate(midi_files):
  midi = preprocess_midi(os.path.join(midi_data_path, file))
  midi_data[i, :, :midi.shape[1]] = midi

train_data = midi_data[:int(0.8 * len(midi_data))]
val_data = midi_data[int(0.8 * len(midi_data)):]

model = MusicVAE(input_shape=train_data.shape[1:], z_size=256)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001))
model.fit(train_data, epochs=10, validation_data=val_data)

z = np.random.randn(1, 256)
generated_sequence = model.decode(z, length=100)

generated_midi = pretty_midi.PrettyMIDI()
piano_program = pretty_midi.instrument_name_to_program('Acoustic Grand Piano')
piano = pretty_midi.Instrument(program=piano_program)
piano.notes = pretty_midi.util.piano_roll_to_pretty_midi(generated_sequence[0], fs=100)
generated_midi.instruments.append(piano)
generated_midi.write('generated.mid')

自分で試したこと

ChatGPTさんに「input_dim」と「imput_shape」という名前の引数を提案してもらったのですが、特に突破できることはありませんでした。

「input_dim」と「imput_shape」は存在しない引数らしいので、存在するそれっぽい引数をひたすらコピペしましたが思うような結果は得られませんでした。

もう一度ChatGPTさんに聞いたところ、
『MagentaのMusicVAEクラスには、'input_dim'というキーワード引数は存在しません。代わりに、'input_shape'というキーワード引数を使用する必要があります。』
『「input_dim」や「input_shape」といった引数が「MusicVAE」に存在しない場合、代わりに「input_shape」引数を使用することができます。例えば、以下のようになります。』
などと言われて、すこし萎えそうになりました。

ChatGPTがだめならQiitaだ!やっぱり生身の人間だよ、、!と思いQiitaで検索をかけたところ、望ましい投稿は確認できませんでした。

それではダメもとで聞いてみよう。。と思い、Qiitaの質問として投稿することにしました。←イマココ

0 likes

1Answer

TypeError: __init__() got an unexpected keyword argument 'input_shape'

そも,MusicVAEにそのキーワード引数が無いのでエラーになっております.ChatGPTの推薦は誤字の可能性を鑑みた似たような引数の提案でした.

上記リンクから一部抜粋
class MusicVAE(object):
  """Music Variational Autoencoder."""

  def __init__(self, encoder, decoder):
    """Initializer for a MusicVAE model.
    Args:
      encoder: A BaseEncoder implementation class to use.
      decoder: A BaseDecoder implementation class to use.
    """
    self._encoder = encoder
    self._decoder = decoder

しかし,実装を参照すると,shape周りを要求するように引数が作られていません.エンコーダとデコーダを指定するように書く必要があります.

GANとなれば普通のtf.keras.models.Modelとは異なる実装が多いので気をつけましょう.特にinput_shapeを持つのはtf.keras.layers.Layerに多いですが,これらの要件をSubclassing APIに適用することは少ないと感じます.

Keras公式におけるVAEの実装を次に例示します.

上記リンクから一部抜粋
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

少なくとも,VAEにおける実装は最低限トレースされてMusicVAEが構成されていることがわかると思います.

0Like

Your answer might help someone💌