会社で支給されたノートPCではDeep LearningのDの字も実現できず、セキュリティがきつすぎてsshで外に出ることも出来ず、困り果てたときに出会ったGoogle ColaboratoryについてGoogleへの感謝と備忘録を兼ねて。
Google Driveのマウント
[参照]
https://qiita.com/tomo_makes/items/b3c60b10f7b25a0a5935
必要なパッケージ(keras)の導入とか
Colaboratory環境だとkerasのバックエンドはデフォルトでtensorflowになるようです。
GPUを利用したランタイムを使用する際、注意事項として、絶対にtensorflowやtensorflow-gpuをpipでupgradeしないこと!
なぜかupgradeするとGPUがtensorflowから認識できなくなります。これで3日ほど無駄にしました。
!pip install -q keras
import keras
!apt-get -qq install -y graphviz && pip install -q pydot
GPUを使えるか確認
[ランタイム]-[ランタイムのタイプを変更]で[ハードウェアアクセラレータ]をGPUにセット。pythonは2でも3でもOK
from tensorflow.python.client import device_lib
device_lib.list_local_devices()
うまく認識できてれば以下のようになります。
[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 4218635468701404478, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 11297803469
locality {
bus_id: 1
}
incarnation: 11094235734753279240
physical_device_desc: "device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7"]
Qausi-recurrent Neural Networks(QRNN)のモデルをセーブして読み込むときの注意
QRNNはLSTMにとって変わるとか変わらないとか言われているモデルだそうです。
詳しくはこちらで、詳しく解説されておりますので割愛
https://qiita.com/icoxfog417/items/d77912e10a7c60ae680e
kerasでの実装もあって、以下から取得できます。
https://github.com/DingKe/qrnn
この中のQRNNクラスをまるっとコピーすれば良いのですが、モデルをセーブして読み込む際に通常通り
model.save("hoge.h5")
new_model = load_model("hoge.h5",{"QRNN",QRNN})
みたいにすると、うまく動きません。いろいろ調べると、重みデータしかセーブされていないようなので以下のように、モデルの骨組みを作る関数を噛ませて、重みを読み込むようにしてやるといけます。
def createQRNNModel():
in_out_neurons = 1
hidden_neurons = 256
model = Sequential()
model.add(QRNN(hidden_neurons, batch_input_shape=(None,LENGTH_OF_SEQUENCES,1),window_size=LENGTH_OF_SEQUENCES))
model.add(Dense(units=in_out_neurons))
model.add(Activation("linear"))
return model
# モデル作成
model = createQRNNModel()
model.compile(loss="mean_squared_error", optimizer="adam",metrics=['accuracy'])
early_stopping = EarlyStopping(monitor="acc", mode="auto",patience=0)
# モデル学習
model.fit(X_train,y_train, batch_size=LENGTH_OF_SEQUENCES, epochs=30, callbacks=[early_stopping])
# モデル保存
model.save(model_outfilename+".h5")
# モデル読み込み
loadmodel = createQRNNModel()
loadmodel.load_weights(model_infilename+".h5")
もしかするとこのQRNN以外でも自作のレイヤーを含む場合はこのやり方が安全かもしれません。
なぜかうまく動かなくなったら(強制リセット)
自分の場合Googleドライブのマウントでなぜかファイルが見えず、Googleドライブ側からGoogleSDKの接続を強制的にきったりしたら、その後マウントできなくなる症状が発生。にっちもさっちもいかなくなったときに、ランタイムを初期状態に以下の方法で戻しました。
!kill -9 -1
90分アクティブでなければ勝手にリセットされますが、そんなに待ってられないときは上記方法でリセットできます。パッケージ類やら設定は最初からやり直してください。
あとがき
現状自分がハマった部分の解決方法を記載しました。それ以外の部分は多くの先人が記載されていますのでリンクの記載にとどめてます。なにかまたハマって解決できたら追記します。