LoginSignup
18
14

More than 5 years have passed since last update.

Google Colaboratoryでお手軽Deep learning

Last updated at Posted at 2018-03-12

会社で支給されたノートPCではDeep LearningのDの字も実現できず、セキュリティがきつすぎてsshで外に出ることも出来ず、困り果てたときに出会ったGoogle ColaboratoryについてGoogleへの感謝と備忘録を兼ねて。

Google Driveのマウント

[参照]
https://qiita.com/tomo_makes/items/b3c60b10f7b25a0a5935

必要なパッケージ(keras)の導入とか

Colaboratory環境だとkerasのバックエンドはデフォルトでtensorflowになるようです。
GPUを利用したランタイムを使用する際、注意事項として、絶対にtensorflowやtensorflow-gpuをpipでupgradeしないこと!
なぜかupgradeするとGPUがtensorflowから認識できなくなります。これで3日ほど無駄にしました。

!pip install -q keras
import keras
!apt-get -qq install -y graphviz && pip install -q pydot

GPUを使えるか確認

[ランタイム]-[ランタイムのタイプを変更]で[ハードウェアアクセラレータ]をGPUにセット。pythonは2でも3でもOK

from tensorflow.python.client import device_lib
device_lib.list_local_devices()

うまく認識できてれば以下のようになります。

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 4218635468701404478, name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 11297803469
 locality {
   bus_id: 1
 }
 incarnation: 11094235734753279240
 physical_device_desc: "device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7"]

Qausi-recurrent Neural Networks(QRNN)のモデルをセーブして読み込むときの注意

QRNNはLSTMにとって変わるとか変わらないとか言われているモデルだそうです。
詳しくはこちらで、詳しく解説されておりますので割愛
https://qiita.com/icoxfog417/items/d77912e10a7c60ae680e

kerasでの実装もあって、以下から取得できます。
https://github.com/DingKe/qrnn
この中のQRNNクラスをまるっとコピーすれば良いのですが、モデルをセーブして読み込む際に通常通り

model.save("hoge.h5")
new_model = load_model("hoge.h5",{"QRNN",QRNN})

みたいにすると、うまく動きません。いろいろ調べると、重みデータしかセーブされていないようなので以下のように、モデルの骨組みを作る関数を噛ませて、重みを読み込むようにしてやるといけます。

def createQRNNModel():
  in_out_neurons = 1
  hidden_neurons = 256
  model = Sequential() 
  model.add(QRNN(hidden_neurons, batch_input_shape=(None,LENGTH_OF_SEQUENCES,1),window_size=LENGTH_OF_SEQUENCES))
  model.add(Dense(units=in_out_neurons))
  model.add(Activation("linear"))  
  return model

# モデル作成
model = createQRNNModel()
model.compile(loss="mean_squared_error", optimizer="adam",metrics=['accuracy'])
early_stopping = EarlyStopping(monitor="acc", mode="auto",patience=0)
# モデル学習 
model.fit(X_train,y_train, batch_size=LENGTH_OF_SEQUENCES, epochs=30, callbacks=[early_stopping])
# モデル保存 
model.save(model_outfilename+".h5")

# モデル読み込み
loadmodel = createQRNNModel()
loadmodel.load_weights(model_infilename+".h5")

もしかするとこのQRNN以外でも自作のレイヤーを含む場合はこのやり方が安全かもしれません。

なぜかうまく動かなくなったら(強制リセット)

自分の場合Googleドライブのマウントでなぜかファイルが見えず、Googleドライブ側からGoogleSDKの接続を強制的にきったりしたら、その後マウントできなくなる症状が発生。にっちもさっちもいかなくなったときに、ランタイムを初期状態に以下の方法で戻しました。

!kill -9 -1

90分アクティブでなければ勝手にリセットされますが、そんなに待ってられないときは上記方法でリセットできます。パッケージ類やら設定は最初からやり直してください。

あとがき

現状自分がハマった部分の解決方法を記載しました。それ以外の部分は多くの先人が記載されていますのでリンクの記載にとどめてます。なにかまたハマって解決できたら追記します。

18
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
14