LoginSignup
9
8

More than 5 years have passed since last update.

C++でTensorFlow: Opを実装する( 2 )

Last updated at Posted at 2016-03-09

Opの実装2回目

1回目は最も単純な実装をやってみました。
今回は、InputOutput以外に定義できるAttrというものをMyOpに追加してみたいと思います。

前回長々と書いたOpの追加の流れを3行でまとめると、
1. REGISTER_OPマクロでOpを定義する
2. OpKernelクラスを継承、Computeメソッドをオーバーライドしてカーネル部分を実装する
3. REGISTER_KERNEL_BUILDERマクロでカーネルを定義する

これだけです。
正直この3行と、今回のコードを見てもらえれば、前回の内容は見なくても大体わかると思います…

Attrについて

OpはAttr(attributeかな?)というものを持つことができて、用途は二つあり、
1. Input以外の入力
2. Opのポリモーフィズム
です。

InputとAttrの区別

Attr は学習ステップ間を通して変わらない値や、後述するポリモーフィズムのために使用することが主な用途になります。それ以外の値に関しては極力Inputを使用することが推奨されています。
この理由は、Attrはコンストラクタで設定される、つまり、Computation Graphにノードを追加する際に値が設定されるものなので、Inputとして定義することが可能ならば、ステップ毎の変化などに対応できて良いということです。

Attrの定義

AttrREGISTER_OPでOpの設定を定義する際に、Inputなどと一緒にAttr("<name>:<type>")の形式で定義します。
例示した方が早いと思うので、前回のMyOpAttrを追加したものを載せておきます。

REGISTER_OP("My")
    .Attr("op: int32")     // int32の"op"というAttr
    .Input("msg: string")  // stringの入力Tensor
    .Input("a: int32")     // int32の入力Tensor
    .Output("b: int32")    // int32の出力Tensor

カーネル内でのAttrへのアクセス

まずは、MyOpクラスのコンストラクタ内で、メンバ変数にAttrの値を設定します。

// MyOpクラスのコンストラクタ内
// op_はMyOpクラスののメンバ変数
OP_REQUIRES_OK(context,
               context->GetAttr("op", &op_));

こうすることで、Computeメソッド内からAttrの値を、メンバ変数として使用可能

制限付きAttrの定義

あるAttrの取り得る値や、タイプを制限することも可能

  1. 値による制限
.Attr("e: {'apple', 'orange'}")

これでeというAttrの取り得る値をappleorangeに制限することができます。
2. タイプによる制限
tensor typesに記載されているタイプによってAttrの取り得るタイプを制限することができます。

.Attr("t: {int32, float, bool}")

この場合、tint32floatbool

(numbertyperealnumbertypeなど便利なタイプもあるらしい…)

  1. 比較演算子による制限 数値の大小やリストの長さによる制限も可能
.Attr("a: int >= 2") // aはintで2以上でなければならないという制限

// リストbはint32のリストまたはfloatのリストかつ長さは3以上という制限
.Attr("b: list({int32, float}) >= 3") 

Attrのデフォルト値の設定

Attr はデフォルト値を設定することが推奨されていて、
以下はすべてのタイプのデフォルト値の設定法のまとめ
気になったら試してみてください

.Attr("s: string = 'foo'")
.Attr("i: int = 0")
.Attr("f: float = 1.0")
.Attr("b: bool = true")
.Attr("ty: type = DT_INT32")
.Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
.Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
.Attr("l_empty: list(int) = []")
.Attr("l_int: list(int) = [2, 3, 5, 7]");

グラフ生成時の引数順序

main.cc
Node* My(NodeOut msg, NodeOut a, int64 op, const GraphDefBuilder::Options&

ということなので、Computation Graphを生成する際の引数順序は、(REGISTER_OPで先に定義しても)InputAttrの順でした。


ポリモーフィズム

次は、上で説明したAttrを利用したポリモーフィズムについて

REGISTER_OP("My")
    .Attr("T: {float, int32}") // Tはfloatまたはint32
    .Input("in: T")            // 入力Tensorはfloatまたはint32
    .Output("out: T")          // 出力Tensorはfloatまたはint32

上のようにTというAttrInputOutputの定義の中で使うことで、InputOutputfloatまたはint32であるというような定義が可能
(ちなみにここでも、Tのデフォルトタイプを指定することが推奨される)

このように、複数のタイプを指定した場合、それに応じて、カーネルも複数タイプ定義する必要があります。具体的には、REGISTER_KERNEL_BUILDERでカーネルを定義する際、TypeConstraint<type>("T")の制限を付加して、対応するカーネルを定義します。

REGISTER_KERNEL_BUILDER(Name("My")
                        .Device(DEVICE_CPU)
                        .TypeConstraint<int32>("T"),
                        MyOpInt); //int32に対応する方をインスタンス化する
REGISTER_KERNEL_BUILDER(Name("My")
                        .Device(DEVICE_CPU)
                        .TypeConstraint<float>("T"),
                        MyOpFloat); //floatに対応する方をインスタンス化する

この例では、int32floatの場合で明示的にクラスを分けてますが、もちろんテンプレートクラスで定義することも可能(下の例)

ここで、Nameがどちらの場合も同じなので、Myについてポリモーフィズムが実現可能


Attrを使ったコード例

以下にAttrを追加したコード例を載せておきます。

Opの定義部分

my_op.cc
REGISTER_OP("My")                // Opの名前
    .Attr("T: {int32, float}")   // Tはint32またはfloat
    .Attr("op: int32 = 0")       // opはint32でデフォルト値は0
    .Input("msg: string")        // 一つ目の入力Tensorはstring
    .Input("a: T")               // 二つ目の入力Tensorはint32またはfloat
    .Output("b: T")              // 出力Tensorはint32またはfloat
    .Doc(R"doc(                  // ヘッダファイル生成時のコメント                                                                                                           
MyOp                                                                                                                
)doc");

カーネルの実装部分

my_op.cc
template <typename T>
// OpKernelを継承
class MyOp : public OpKernel {
#define SUM 0
#define MUL 1

public:
    explicit MyOp(OpKernelConstruction* context) : OpKernel(context) {
        // Attrをメンバ変数に書き込む
        OP_REQUIRES_OK(context,
                       context->GetAttr("op", &op_)); 
    } 
    // Computeをオーバーライド
    void Compute(OpKernelContext* context) override { 
        // 一つ目の入力msgを取り出す
        const Tensor& msg_tensor = context->input(0);
        auto msg = msg_tensor.flat<string>();
        std::cout << msg(0) << std::endl;

        // 二つ目の入力aを取り出す
        const Tensor& input_tensor = context->input(1);
        auto input = input_tensor.flat<T>();

        // 出力bを生成する
        Tensor* output_tensor = NULL;
        OP_REQUIRES_OK(context,
                       context->allocate_output(0, input_tensor.shape(), &output_tensor));
        auto output = output_tensor->template flat<T>();

        // Attr opで指定された値が0(SUM)なら入力に2を加算、1(MUL)なら2を乗算
        if (op_ == SUM)
            for (int i = 0; i < input.size(); ++i)
                output(i) = input(i) + 2;
        else if (op_ == MUL)
            for (int i = 0; i < input.size(); ++i)
                output(i) = input(i)*2;
    }
private:
    int op_; // Attr用のメンバ変数
};

カーネルの定義部分

my_op.cc
REGISTER_KERNEL_BUILDER(Name("My") // カーネルの名前
                    .Device(DEVICE_CPU) // 動作デバイスはCPU
                    .TypeConstraint<int32>("T"), // このカーネルをint32に制限
                    MyOp<int32>); // int32の方をインスタンス化
REGISTER_KERNEL_BUILDER(Name("My") // カーネルの名前
                    .Device(DEVICE_CPU) // 動作デバイスはCPU
                    .TypeConstraint<float>("T"), // このカーネルをfloatに制限
                    MyOp<float>); // floatの方をインスタンス化

Computation Graph生成部分

main.cc
 // 一つ目の入力msg用のConstノード
 Node* msg = Const((string)"my op!", b.opts());
 // 二つ目の入力a用のConstノード
 Node* a = Const({3, 2}, b.opts());

 // MyOpノードを追加
 // Attr opは1に設定 (Constノードとかでなくそのまま渡す)
 // このノードの名前はmy_ops_out
 My(msg, a, 1, b.opts().WithName("my_ops_out"));

Session実行部分

main.cc
// 出力の欲しいノードにmy_ops_outを指定
session->Run({},{"my_ops_out"},{}, &outputs);

結果

// Attr op を 0に
my op!
5,4,

// Attr op を 1に
my op!
6,4,

まとめ

前回と今回を通して、TensorFlowに自分のOpノードを追加して、C++でそれを使用することができた!

それより先に、 C++で既存のOpを使って機械学習を実装しろっていう話なので、次回以降がんばりたいと思います

9
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
8