Opの実装2回目
1回目は最も単純な実装をやってみました。
今回は、Input
とOutput
以外に定義できるAttr
というものをMyOp
に追加してみたいと思います。
前回長々と書いたOpの追加の流れを3行でまとめると、
1. REGISTER_OP
マクロでOpを定義する
2. OpKernel
クラスを継承、Compute
メソッドをオーバーライドしてカーネル部分を実装する
3. REGISTER_KERNEL_BUILDER
マクロでカーネルを定義する
これだけです。
正直この3行と、今回のコードを見てもらえれば、前回の内容は見なくても大体わかると思います…
Attrについて
OpはAttr
(attributeかな?)というものを持つことができて、用途は二つあり、
1. Input
以外の入力
2. Opのポリモーフィズム
です。
InputとAttrの区別
Attr
は学習ステップ間を通して変わらない値や、後述するポリモーフィズムのために使用することが主な用途になります。それ以外の値に関しては極力Input
を使用することが推奨されています。
この理由は、Attr
はコンストラクタで設定される、つまり、Computation Graphにノードを追加する際に値が設定されるものなので、Input
として定義することが可能ならば、ステップ毎の変化などに対応できて良いということです。
Attrの定義
Attr
は REGISTER_OP
でOpの設定を定義する際に、Input
などと一緒にAttr("<name>:<type>")
の形式で定義します。
例示した方が早いと思うので、前回のMyOp
にAttr
を追加したものを載せておきます。
REGISTER_OP("My")
.Attr("op: int32") // int32の"op"というAttr
.Input("msg: string") // stringの入力Tensor
.Input("a: int32") // int32の入力Tensor
.Output("b: int32") // int32の出力Tensor
カーネル内でのAttr
へのアクセス
まずは、MyOp
クラスのコンストラクタ内で、メンバ変数にAttr
の値を設定します。
// MyOpクラスのコンストラクタ内
// op_はMyOpクラスののメンバ変数
OP_REQUIRES_OK(context,
context->GetAttr("op", &op_));
こうすることで、Compute
メソッド内からAttr
の値を、メンバ変数として使用可能
制限付きAttrの定義
あるAttr
の取り得る値や、タイプを制限することも可能
- 値による制限
.Attr("e: {'apple', 'orange'}")
これでe
というAttr
の取り得る値をapple
とorange
に制限することができます。
2. タイプによる制限
tensor typesに記載されているタイプによってAttr
の取り得るタイプを制限することができます。
.Attr("t: {int32, float, bool}")
この場合、t
はint32
かfloat
かbool
(numbertype
やrealnumbertype
など便利なタイプもあるらしい…)
- 比較演算子による制限 数値の大小やリストの長さによる制限も可能
.Attr("a: int >= 2") // aはintで2以上でなければならないという制限
// リストbはint32のリストまたはfloatのリストかつ長さは3以上という制限
.Attr("b: list({int32, float}) >= 3")
Attrのデフォルト値の設定
Attr
はデフォルト値を設定することが推奨されていて、
以下はすべてのタイプのデフォルト値の設定法のまとめ
気になったら試してみてください
.Attr("s: string = 'foo'")
.Attr("i: int = 0")
.Attr("f: float = 1.0")
.Attr("b: bool = true")
.Attr("ty: type = DT_INT32")
.Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
.Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
.Attr("l_empty: list(int) = []")
.Attr("l_int: list(int) = [2, 3, 5, 7]");
グラフ生成時の引数順序
Node* My(NodeOut msg, NodeOut a, int64 op, const GraphDefBuilder::Options&
ということなので、Computation Graphを生成する際の引数順序は、(REGISTER_OP
で先に定義しても)Input
、Attr
の順でした。
ポリモーフィズム
次は、上で説明したAttr
を利用したポリモーフィズムについて
REGISTER_OP("My")
.Attr("T: {float, int32}") // Tはfloatまたはint32
.Input("in: T") // 入力Tensorはfloatまたはint32
.Output("out: T") // 出力Tensorはfloatまたはint32
上のようにT
というAttr
をInput
とOutput
の定義の中で使うことで、Input
とOutput
がfloat
またはint32
であるというような定義が可能
(ちなみにここでも、T
のデフォルトタイプを指定することが推奨される)
このように、複数のタイプを指定した場合、それに応じて、カーネルも複数タイプ定義する必要があります。具体的には、REGISTER_KERNEL_BUILDER
でカーネルを定義する際、TypeConstraint<type>("T")
の制限を付加して、対応するカーネルを定義します。
REGISTER_KERNEL_BUILDER(Name("My")
.Device(DEVICE_CPU)
.TypeConstraint<int32>("T"),
MyOpInt); //int32に対応する方をインスタンス化する
REGISTER_KERNEL_BUILDER(Name("My")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
MyOpFloat); //floatに対応する方をインスタンス化する
この例では、int32
とfloat
の場合で明示的にクラスを分けてますが、もちろんテンプレートクラスで定義することも可能(下の例)
ここで、Name
がどちらの場合も同じなので、My
についてポリモーフィズムが実現可能
Attrを使ったコード例
以下にAttr
を追加したコード例を載せておきます。
Opの定義部分
REGISTER_OP("My") // Opの名前
.Attr("T: {int32, float}") // Tはint32またはfloat
.Attr("op: int32 = 0") // opはint32でデフォルト値は0
.Input("msg: string") // 一つ目の入力Tensorはstring
.Input("a: T") // 二つ目の入力Tensorはint32またはfloat
.Output("b: T") // 出力Tensorはint32またはfloat
.Doc(R"doc( // ヘッダファイル生成時のコメント
MyOp
)doc");
カーネルの実装部分
template <typename T>
// OpKernelを継承
class MyOp : public OpKernel {
#define SUM 0
#define MUL 1
public:
explicit MyOp(OpKernelConstruction* context) : OpKernel(context) {
// Attrをメンバ変数に書き込む
OP_REQUIRES_OK(context,
context->GetAttr("op", &op_));
}
// Computeをオーバーライド
void Compute(OpKernelContext* context) override {
// 一つ目の入力msgを取り出す
const Tensor& msg_tensor = context->input(0);
auto msg = msg_tensor.flat<string>();
std::cout << msg(0) << std::endl;
// 二つ目の入力aを取り出す
const Tensor& input_tensor = context->input(1);
auto input = input_tensor.flat<T>();
// 出力bを生成する
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_tensor.shape(), &output_tensor));
auto output = output_tensor->template flat<T>();
// Attr opで指定された値が0(SUM)なら入力に2を加算、1(MUL)なら2を乗算
if (op_ == SUM)
for (int i = 0; i < input.size(); ++i)
output(i) = input(i) + 2;
else if (op_ == MUL)
for (int i = 0; i < input.size(); ++i)
output(i) = input(i)*2;
}
private:
int op_; // Attr用のメンバ変数
};
カーネルの定義部分
REGISTER_KERNEL_BUILDER(Name("My") // カーネルの名前
.Device(DEVICE_CPU) // 動作デバイスはCPU
.TypeConstraint<int32>("T"), // このカーネルをint32に制限
MyOp<int32>); // int32の方をインスタンス化
REGISTER_KERNEL_BUILDER(Name("My") // カーネルの名前
.Device(DEVICE_CPU) // 動作デバイスはCPU
.TypeConstraint<float>("T"), // このカーネルをfloatに制限
MyOp<float>); // floatの方をインスタンス化
Computation Graph生成部分
// 一つ目の入力msg用のConstノード
Node* msg = Const((string)"my op!", b.opts());
// 二つ目の入力a用のConstノード
Node* a = Const({3, 2}, b.opts());
// MyOpノードを追加
// Attr opは1に設定 (Constノードとかでなくそのまま渡す)
// このノードの名前はmy_ops_out
My(msg, a, 1, b.opts().WithName("my_ops_out"));
Session実行部分
// 出力の欲しいノードにmy_ops_outを指定
session->Run({},{"my_ops_out"},{}, &outputs);
結果
// Attr op を 0に
my op!
5,4,
// Attr op を 1に
my op!
6,4,
まとめ
前回と今回を通して、TensorFlowに自分のOpノードを追加して、C++でそれを使用することができた!
それより先に、 C++で既存のOpを使って機械学習を実装しろっていう話なので、次回以降がんばりたいと思います