LoginSignup
12
12

More than 5 years have passed since last update.

PyTorch is 何

Last updated at Posted at 2018-12-04

PyTorch #とは 
どうも、PyTorch Advent Calendar 1日目担当のアンサンブル提案おじさんです。
初日ということで、PyTorchの最速入門を目指す記事になっております。

 PyTorch is 何

PyTorchとは、DeepLearningフレームワークである。

DeepLearningフレームワークとは、ニューラルネットを簡単に構成できるフレームワークのことで、構築から、入出力、そして学習までをトータルでお任せできるソフトウェアになります。有名なプロダクトで言えばTensorFlowやKerasといったものが存在します。 PyTorchには元となるフレームワークTorchの存在があります。 TorchはLua言語で書かれたDeepLearningフレームワークです。そのTorchをPythonで使いたいと開発を始めたのがPyTorchです。

 PyTorch の特徴

PyTorchの特徴は、

・ Defined by Run
・ 自由にグラフ構築
・ 直感的でわかりやすい実装
・ 比較的早い

などがあります。
KerasのDataloader関連で遅すぎてブチギレたことのある人なら、ぜひPyTorchを使ってみることをオススメします

PyTorch is Smart : Defined by Run

PyTorchの一番の特徴にDefined by Runがあります。これはTensor計算をするごとに自動的にモデルを組むことができるというもので、最近のDeepLearningフレームワークでは多く採用されている形になります。このため、コードが非常に読みやすく、Keras並みに直感的なグラフ構築をすることが可能です。

PyTorch is Flexible : 自由にグラフ構築が可能

Kerasでも、自由にグラフ構築が可能でしたが、もちろんPyTorchでもかなり自由にグラフが構築できます。

PyTrochでは、forward関数作ることでモデル出力を定義します。 具体例としては以下な感じ

class FC_model(nn.Module):
    def __init__(self):
        super(FC_model,self).__init__()#ここはお決まりの書き方
        self.FC = nn.Linear(512,1)#重みを学習するためにインスタンスを保持

    def forward(self, x):
        #ここで入出力モデルを定義。
        #入力はxに入り、出力はreturnで返すことでモデルの入出力を定義
        x=self.FC(x)
        return torch.sigmoid(x).squeeze(1)

つまり、ここの部分を見れば、具体的に何を行なっているのかがわかるという訳で、非常に分かりやすいです。
また、このforward関数、なんでも挟むことができて、

    def forward(self, x):
        #ここで入出力モデルを定義。
        #入力はxに入り、出力はreturnで返すことでモデルの入出力を定義
        x=self.FC(x)
        print(x.shape) #(batch_size,1)
        return torch.sigmoid(x).squeeze(1)

printデバッグを挟むことも、if文を組むことだって可能です。

PyTorch is Beautiful : 可読性の高さ

PyTorchはコードの可読性が非常に高いです。
例えば、ResNetで有名なSkip-Connectionであれば

def forward(self,x):

    f_x=self.ResBlock1(x)
    x = f_x + x #connection
    f_x=self.ResBlock2(x)
    x = f_x + x #connection
    f_x=self.ResBlock3(x)
    x = f_x + x #connection

    return x

といった非常に直感的な書き方で書くことができます。アンサンブルだって簡単

def forward(self,x):
    #word2vec 単体
    emb_vec = self.Embedding_Layer(x)
    emb_mean = torch.mean(emb_vec, dim=1).squeeze(1) 
    #文章のword2vec平均を出力 (batch,seq,emb_dim)データなので、dim=1で平均化
    #(batch,1,emb_dim)=>(batch,emb_dim)に圧縮
    after_w2v = self.FC_w2v(emb_mean) #(batch,emb_dim) => (batch, FC_dim)

    #conv1d
    conv_vec = self.Conv1d(emb_vec)#(batch,seq,emb_dim) => (batch,after_conv_dim)
    after_conv = self.FC_conv(conv_vec)#(batch,after_conv_dim) => (batch,FC_dim)


    #lstm
    hidden = self.LSTM_layer(emb_vec) #(batch,seq,emb_dim) => (batch,hid_dim)
    after_lstm = self.FC_lstm(hidden) #(batch,hid_dim) => (batch, FC_dim)


    # Ensemble
    # それぞれの結果をFC_dimに揃えて加算して平均化
    ensemble_vec = (after_w2v + after_conv + after_lstm)/3 #(batch, FC_dim)
    result = self.FC_last(ensemble_vec) #(batch,FC_dim) => (batch, 1)
    return torch.sigmoid(result) 

グラフのコード自体が非常に読みやすいので、実装のイメージがつきやすいのがPyTorchの一番の特徴と言えるでしょう。

 PyTorchを最速で入門してみよう。

これは後ほど!

12
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
12