はじめに
だいぶ時間が空きました.
前置きはここと同じなので省略.
HalideでデバッグするときにFunc
の中身を覗きたいときに役に立つことに関して書きました.
Halide::Funcの本体と出力
Halideで関数を表すHalide::Func
ですが,その本体はHalide::Internal::Function
にあり,更に関数の定義はargsとvaluesによってなされています
例えばFunc f
でf(x,y)=2*x+y
のとき,argsは{x, y}でvaluesは{2*x+y}です.
argsは純粋定義の場合vector<string>
(純粋定義はHalide::Var
を用いて行われるが,実際は各Var
の固有な名前しか使われていない?),更新定義の場合vector<Halide::Expr>
で,ベクトル要素数=関数の次元数です.
valuesは純粋定義,更新定義ともにvector<Halide::Expr>
で,ベクトル要素数=Funcの出力をTupleと見たときの要素数です.
Halide::Expr
はostreamへ出力可能なので,argsとvaluesを引っ張ってこれれば,Func
の出力ができそうです.
using namespace Halide;
// Halide::Internal::Functionの出力
ostream& operator<<(ostream& os, const Halide::Internal::Function f)
{
if (f.has_pure_definition())
{
// 純粋定義の出力
os << f.name() << "(";
for (int i = 0; i < f.args().size(); i++)
{
os << f.args()[i];
if (i < f.args().size() - 1) os << ", ";
}
os << ")=";
if (f.values().size() > 1) // 関数がTupleの場合
{
os << "Tuple(";
for (int i = 0; i < f.values().size(); i++)
{
os << f.values()[i];
if (i < f.values().size() - 1) os << ", ";
}
os << ")" << endl;
}
else
{
os << f.values()[0] << endl;
}
}
// 更新定義の出力
for (int u = 0; u < f.updates().size(); u++)
{
os << f.name() << "(";
Halide::Internal::Definition u_def = f.update(u);
for (int i = 0; i < u_def.args().size(); i++)
{
os << u_def.args()[i];
if (i < u_def.args().size() - 1) os << ", ";
}
os << ")=";
if (u_def.values().size() > 1)
{
os << "Tuple(";
for (int i = 0; i < u_def.values().size(); i++)
{
os << u_def.values()[i];
if (i < u_def.values().size() - 1) os << ", ";
}
os << ")" << endl;
}
else
{
os << u_def.values()[0] << endl;
}
}
return os;
}
// Halide::Funcの出力
ostream& operator<<(ostream& os, const Halide::Func f)
{
os << f.function();
return os;
}
上のコードでは,Halide::Internal::Function
のostreamへの<<演算子をオーバロードして,出力の本体部分を定義します.
Func
からは.function()
で内部のFunction
にアクセスできるので,そのままos << f.function()
でok.
実験
Var x("x"), y("y");
Func clamped = BoundaryConditions::repeat_edge(src);
RDom r(-rad, 2 * rad + 1, "r");
Func blur_x("blur_x"), blur_y("blur_y"), total("total"), kernel("kernel");
Expr d = -1.f / (2.f * sigma * sigma);
total() = sum(fast_exp((r * r) * d),"total_sum");
kernel(x) = fast_exp((x * x) * d);
blur_x(x, y) = sum(kernel(r) * clamped(x + r, y), "blurx_sum") / total();
blur_y(x, y) = sum(kernel(r) * blur_x(x, y + r), "blury_sum") / total();
cout << "blur_yは..."
cout << blur_y; // 出力
blur_yは...blur_y(x, y)=((float32)blury_sum(x, y)/(float32)total())
ちゃんと出力できました.
入れ子関数も見たい
先ほどの実験ですと,一番外側の_blur_y_だけ出力されて,入れ子になっている_blury_sum_や_total_までは見れません.
そこで,Halide::Internal::populate_environment
を使います.
Halide::Internal::populate_environment
はFunction
に対して,そのFunction
で入れ子になっているFunction
たちのstd::map<std::string, Halide::Internal::Function>
を生成してくれます.
下のコードはFunction
やFunc
について,入れ子になっているFunction
を出力する関数です.
void print_all_ref_functions(Halide::Internal::Function f)
{
map<string, Halide::Internal::Function> env;
Halide::Internal::populate_environment(f, env);
for (auto it = env.begin(); it != env.end(); it++)
{
cout << it->first << "...\n";
cout << it->second;
cout << endl;
}
}
void print_all_ref_funcs(Halide::Func f)
{
print_all_ref_functions(f.function());
}
実験
先ほどの_blur_y_について,入れ子関数まで出力させてみます.
Var x("x"), y("y");
Func clamped = BoundaryConditions::repeat_edge(src);
RDom r(-rad, 2 * rad + 1, "r");
Func blur_x("blur_x"), blur_y("blur_y"), total("total"), kernel("kernel");
Expr d = -1.f / (2.f * sigma * sigma);
total() = sum(fast_exp((r * r) * d),"total_sum");
kernel(x) = fast_exp((x * x) * d);
blur_x(x, y) = sum(kernel(r) * clamped(x + r, y), "blurx_sum") / total();
blur_y(x, y) = sum(kernel(r) * blur_x(x, y + r), "blury_sum") / total();
print_all_ref_funcs(blur_y); // 出力
blur_x...
blur_x(x, y)=((float32)blurx_sum(x, y)/(float32)total())
blur_y...
blur_y(x, y)=((float32)blury_sum(x, y)/(float32)total())
blurx_sum...
blurx_sum(x, y)=0.000000f
blurx_sum(x, y)=((float32)blurx_sum(x, y) + ((float32)kernel(r$x)*(float32)repeat_edge(x + r$x, y)))
blury_sum...
blury_sum(x, y)=0.000000f
blury_sum(x, y)=((float32)blury_sum(x, y) + ((float32)kernel(r$x)*(float32)blur_x(x, y + r$x)))
kernel...
kernel(x)=(let t16 = float32((x*x)) in (let t17 = (float32)floor_f32((t16*-0.055556f)/0.693147f) in (let t18 = ((t16*-0.055556f) - (t17*0.693147f)) in (let t19 = (t18*t18) in (((((((0.013144f*t19) + 0.168739f)*t19) + 1.000000f)*t18) + ((((0.036690f*t19) + 0.499705f)*t19) + 1.000000f))*(float32)reinterpret(shift_left(max(min(int32(t17) + 127, 255), 0), (uint32)23)))))))
lambda_0...
lambda_0(_0, _1)=(float32)b0(_0, _1)
repeat_edge...
repeat_edge(_0, _1)=(float32)lambda_0(max(min(likely(_0), (0 + 512) - 1), 0), max(min(likely(_1), (0 + 512) - 1), 0))
total...
total()=(float32)total_sum()
total_sum...
total_sum()=0.000000f
total_sum()=(let t8 = float32((r$x*r$x)) in (let t9 = (float32)floor_f32((t8*-0.055556f)/0.693147f) in (let t10 = ((t8*-0.055556f) - (t9*0.693147f)) in (let t11 = (t10*t10) in ((float32)total_sum() + (((((((0.013144f*t11) + 0.168739f)*t11) + 1.000000f)*t10) + ((((0.036690f*t11) + 0.499705f)*t11) + 1.000000f))*(float32)reinterpret(shift_left(max(min(int32(t9) + 127, 255), 0), (uint32)23))))))))
こんな感じで,_blur_y_を構成するFunc
を全て吐き出してくれます.
普段sumを使うと,総和の計算をするFunc
はHalide側で自動的に生成され,その参照のExpr
が返されるので,sumはほぼブラックボックス状態ですが,このように引っ張り出してくることも可能です(応用すればsumのスケジューリングも変更可能!).
他にもBoundaryConditions::repeat_edge
やfast_exp
の実装が覗けます(ソースファイル見てもわかるけど...).
最後に
HalideのFunc
の中身を覗く方法を紹介しました.
この方法を用いればブラックボックス状態のものの中身が分かるため,デバッグが捗るかと思います.
追記
久しぶりにHalideのバージョン(13.0.4)上げたら一部削除された関数がありました.
Function
が参照する入れ子関数を全て吐き出すpopulate_environment
ですが,find_transitive_calls
に機能移植されたっぽいです.
使い方は,
env = Halide::Internal::find_transitive_calls(f);
でOK.