juliaで型を安定化するためのテクニックをまとめる。
型が不安定になる場合
function relu(x::Float64)
if x > 0
# Float64
return x
else
# Int64 typeof(0)
return 0
end
end
typeof(relu(2rand() - 1))
上記を実行すると、 Int64 or Float64 になったりして、型が不安定になる。これだと高速化の妨げになってしまうので、型を安定化するためのテクニックが必要になってくる。
実際に、@code_warntype relu(2rand() - 1)
を実行すると、以下の表示が出てきて、赤色の箇所が型が不安定であることを示している。
MethodInstance for Main.var"workspace#227".relu(::Float64)
from relu(x::Float64) in Main.var"workspace#227" at /Users/sakurairihito/.julia/pluto_notebooks/Cute experiment.jl#==#82d6ec5e-ea9a-4817-b973-3d76fc787557:1
Arguments
#self#::Core.Const(Main.var"workspace#227".relu)
x::Float64
Body::Union{Float64, Int64}
1 ─ %1 = (x > 0)::Bool
└── goto #3 if not %1
2 ─ return x
3 ─ return 0
型の安定化
function relu2(x)
if x > 0
# Float64
return x
else
# Int64 typeof(0)
T = typeof(x)
return zero(T)
#return 0.0
#return T
end
end
上記の関数は以下と等価である。
function relu3(x)
if x > 0
# Float64
return x
else
# xを要素とする代数的な演算が入っている集合の加法の単位元を返す。いい感じの0を返す。
return zero(x) #koredemoyoi
end
end
このままでは、stringのような意図していない型を持つ変数まで許してしまう。
引数を制限すると
function relu4(x::Float64)
if x > 0
# Float64
return x
else
return zero(x)
end
end
となるが、relu4(1.0f0)を実行すると、FLoat32に定義されていないため、エラーが起こってしまう。
このまま不便だから、汎用的な型にしたい。そこで、パラメトリックタイプを活用する気が起こる。
function relu8(x::T) where T
if x > 0
# Float64
return x
else
return zero(T) #入力の変数の型を変数だと思ってプログラムの中で使うこともできる。
end
end
常識的なものだけを引数にくるように制限したい、例えば、ストリングは来ないようにしたいので、
# "where T" は、 "where T <: Any"の略
function relu9(x::T) where T <: Real
if x > 0
# Float64
return x
else
return zero(T) #入力の変数の型を変数だと思ってプログラムの中で使うこともできる。
end
end
おまけ
@code_lowered
を使うと、以下の二つの関数は全く同じであることを知ることができる。
relu(x) = x > 0 ? x : 0
@code_lowered relu(1.0)
function relu2(x)
if x > 0
# Float64
return x
else
# Int64 typeof(0)
return 0
end
end
@code_lowered relu2(1.0)
relu1とrelu2は全く同じコードであることが@code_loweredを使えば知ることができる。
ただ、型の情報を知ることができないので、@code_typedを使って知ることが可能になる。
juliaでは、この順番でプログラムの解析が行われている
@code_lowered
@code_typed
@code_llvm
@code_native