LoginSignup
1
1

More than 5 years have passed since last update.

Torch入門5つの簡単なステップ

Posted at

torchを使わないといけないことがあったので、
簡単なところだけ勉強しました。

元記事
http://torch.ch/docs/five-simple-examples.html#_

チュートリアルのソース
https://gist.github.com/EnsekiTT/dd174ada9247a32277d6269c215eb853

さっさと動かしてみたい方はチュートリアルのソースをそのまま実行。
なんとなく最適化されてるなーって確認が可能。

インストールからって方はこっち
http://qiita.com/masataka46/items/41f9fac011590f3cd4f1

もっとつっこんでやりたい場合はこの辺をみてやって見ると良さそう。
http://qiita.com/perrying/items/c53864d1bd2155a0635d


Torchを使い始める、5つの簡単なステップ。
このチュートリアルでは、パッケージtorch

require 'torch'

またはREPL thを使用していること(自動的にそれを必要とする)。

1.正の定符号二次形式を定義する

ここではいくつかのtorch関数に頼っています:

rand() 一様分布から引き出されたテンソルを生成する
t() これはテンソルを転置します(新しいビューを返します)
dot() 2つのテンソルの間にドット積を実行する
eye() 恒等行列を返す
* 行列演算子(行列ベクトルまたは行列行列乗算を実行する)
最初に、ランダムな種が誰にとっても同じであることを確認します

torch.manualSeed(1234)
-- choose a dimension
N = 5

-- create a random NxN matrix
A = torch.rand(N, N)

-- make it symmetric positive
A = A*A:t()

-- make it definite
A:add(0.001, torch.eye(N))

-- add a linear term
b = torch.rand(N)

-- create the quadratic form
function J(x)
   return 0.5*x:dot(A*x)-b:dot(x)
end

関数値(ここではランダムな点)を簡単に印刷することができます:

print(J(torch.rand(N)))

2.正確な最小値を求める

行列を逆算することができます(数値的に最適ではないかもしれません)

xs = torch.inverse(A)*b
print(string.format('J(x^*) = %g', J(xs)))

3.勾配降下による最小値の検索

私たちは、第1の傾きWRT定義するxのをJ(x):

function dJ(x)
  return A*x-b
end

次に、現在の解決策をいくつか定義します。

x = torch.rand(N)

そして、(与えられた学習率で勾配降下を適用lrしばらくの間):

lr = 0.01
for i=1,20000 do
  x = x - dJ(x)*lr
  -- we print the value of the objective function at each iteration
  print(string.format('at iter %d J(x) = %f', i, J(x)))
end

最小値が表示される

...
at iter 19995 J(x) = -3.135664
at iter 19996 J(x) = -3.135664
at iter 19997 J(x) = -3.135665
at iter 19998 J(x) = -3.135665
at iter 19999 J(x) = -3.135665
at iter 20000 J(x) = -3.135666

4.最適化パッケージを使用する

共役グラデーションやLBFGSなどのより高度な最適化手法を使用したいですか?optimパッケージには、その目的のためにそこにあります!まず、インストールする必要があります:

luarocks install optim

ローカル変数に関する単語
実際に、グローバル変数を使用することをお勧めしません。localを至る所で使用します。この例では、すべてをグローバルで定義しているため、インタプリタのコマンドラインでカットアンドペーストすることができます。確かに、ローカルのような定義:

local A = torch.rand(N, N)

インタープリタを実行しているときに現在の入力行に制限されている現在のスコープでのみ使用できます。後続の行はこのローカルにアクセスできません。

Luaの1はでスコープを定義することができますdo...endディレクティブ:

do
   local A = torch.rand(N, N)
   print(A)
end
print(A)

あなたはコマンドラインでこれをカットアンドペーストした場合、最初のprintは、(5×5の行列になります。Aはdo...endの間は定義される )表示されますが、次はnilで表示されない。

上向きの値を持つクロージャを定義する
我々は両方の返し閉鎖定義する必要がありますJ(x)とをdJ(x)。ここでは、との範囲を定義do...endローカル変数は、そのようなことを、nevalと上位値であるJdJ(x)のみ:JdJ(x)それを知っているであろう。スクリプトの中で、一つは持っている必要はないことに注意do...endの範囲のように、スコープを nevalスクリプトファイル(とないコマンドラインのような行の末尾)の終わりまでになります。

do
   local neval = 0
   function JdJ(x)
      local Jx = J(x)
      neval = neval + 1
      print(string.format('after %d evaluations J(x) = %f', neval, Jx))
      return Jx, dJ(x)
   end
end

最適化のトレーニング
パッケージはデフォルトでは読み込まれないため、必要になります:

require 'optim'

共役勾配の状態を最初に定義する:

state = {
   verbose = true,
   maxIter = 100
}

そして訓練する:

x = torch.rand(N)
optim.cg(JdJ, x, state)

次のようなものが表示されます。

after 120 evaluation J(x) = -3.136835
after 121 evaluation J(x) = -3.136836
after 122 evaluation J(x) = -3.136837
after 123 evaluation J(x) = -3.136838
after 124 evaluation J(x) = -3.136840
after 125 evaluation J(x) = -3.136838

5.プロット

プロットはさまざまな方法を紹介します。例えば、iTorchのパッケージを使用することができます。ここでは、gnuplotを使用しようとしています。

luarocks install gnuplot

中間関数の評価を格納する
私たちは、中間関数の評価(これまでの訓練に要した実時間)を格納するように、少し変更します。

evaluations = {}
time = {}
timer = torch.Timer()
neval = 0
function JdJ(x)
   local Jx = J(x)
   neval = neval + 1
   print(string.format('after %d evaluations, J(x) = %f', neval, Jx))
   table.insert(evaluations, Jx)
   table.insert(time, timer:time().real)
   return Jx, dJ(x)
end

これで訓練することができます:

state = {
   verbose = true,
   maxIter = 100
}

x0 = torch.rand(N)
cgx = x0:clone() -- make a copy of x0
timer:reset()
optim.cg(JdJ, cgx, state)

-- we convert the evaluations and time tables to tensors for plotting:
cgtime = torch.Tensor(time)
cgevaluations = torch.Tensor(evaluations)

確率的勾配降下のサポートを追加
optimを使用して、確率的勾配でトレーニングを追加してみましょう:

evaluations = {}
time = {}
neval = 0
state = {
  lr = 0.1
}

-- we start from the same starting point than for CG
x = x0:clone()

-- reset the timer!
timer:reset()

-- note that SGD optimizer requires us to do the loop
for i=1,1000 do
  optim.sgd(JdJ, x, state)
  table.insert(evaluations, Jx)
end

sgdtime = torch.Tensor(time)
sgdevaluations = torch.Tensor(evaluations)

最終プロット
グラフをプロットすることができます。最初の単純なアプローチは、使用することですgnuplot.plot(x, y)。gnuplot.figure()のプロットは異なる図面上にあることを確認します。

require 'gnuplot'
gnuplot.figure(1)
gnuplot.title('CG loss minimisation over time')
gnuplot.plot(cgtime, cgevaluations)

gnuplot.figure(2)
gnuplot.title('SGD loss minimisation over time')
gnuplot.plot(sgdtime, sgdevaluations)

同じグラフ上のすべてをプロットするより高度な方法は、次のようになります。ここではすべてをPNGファイルに保存します。

gnuplot.pngfigure('plot.png')
gnuplot.plot(
   {'CG',  cgtime,  cgevaluations,  '-'},
   {'SGD', sgdtime, sgdevaluations, '-'})
gnuplot.xlabel('time (s)')
gnuplot.ylabel('J(x)')
gnuplot.plotflush()
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1