LoginSignup
6
7

More than 5 years have passed since last update.

Genericなコードで四則演算(演算子)ができないのをExpressionTreeでILコードを自動生成してできるようにする方法

Last updated at Posted at 2016-11-15

Genericは大変便利ですが欠点もあります。Genericな変数に対して四則演算などができないという欠点があります。

public static int Sum<T>(IEnamerable<T> list)
{
    var sum = 0;
    foreach (var item in list)
    {
        sum = sum + item;
    }
}

このコードはコンパイルエラーになります。Tにどんな型がくるかわからないので当然といえば当然です。しかしTが+演算子に対応している場合(Int,Decimal,独自のVectorクラスとか)は値を足してその結果を返して欲しいです。対応していない場合(例えば独自のPersonクラスとか)は例外を投げてほしいです。

ExpressionTreeでAdd演算子を動的に作成する

ExpressionTreeを使うと動的にAdd演算を行うことが可能です。任意の型2つをAddするFuncを作成するには以下のように記述します。

public static Func<TArg1, TArg2, TResult> CreateAddExpression<TArg1, TArg2, TResult>()
{
    ParameterExpression leftParameter = Expression.Parameter(typeof(TArg1), "leftParameter");
    ParameterExpression rightParameter = Expression.Parameter(typeof(TArg2), "rightParameter");
    try
    {
        //Func<...> f1 = (leftParameter,rightParameter) => return leftParameter + rightParameter; というメソッドを作成し戻り値として返します。
        return Expression.Lambda<Func<TArg1, TArg2, TResult>>(Expression.Add(leftParameter, rightParameter), leftParameter, rightParameter).Compile();
    }
    catch (InvalidOperationException ex)
    {
        //Func<...> f1 = (_, __) => throw new InvalidOperationException(""); ただ例外を投げるメソッドを作成し戻り値として返します。
        return delegate { throw new InvalidOperationException(ex.Message); };
    }
}

これでAdd演算をGenericに対して実行できるようになりました。TがAdd演算子に対応していない場合は例外が発生します。

作成したメソッドを使うと

public static int Sum<T>(IEnamerable<T> list)
{
    var f = CreateAddExpression<int, int, int>();
    var sum = 0;
    foreach (var item in list)
    {
        sum = f(sum, item);
    }
}

という形でAddができるようになりました。

Subtract演算用のも作ってみます。Expression.AddがExpression.Subtractになるだけなので簡単です。

public static Func<TArg1, TArg2, TResult> CreateSubtractExpression<TArg1, TArg2, TResult>()
{
    ParameterExpression leftParameter = Expression.Parameter(typeof(TArg1), "leftParameter");
    ParameterExpression rightParameter = Expression.Parameter(typeof(TArg2), "rightParameter");
    try
    {
        return Expression.Lambda<Func<TArg1, TArg2, TResult>>(Expression.Subtract(leftParameter, rightParameter), leftParameter, rightParameter).Compile();
    }
    catch (InvalidOperationException ex)
    {
        return delegate { throw new InvalidOperationException(ex.Message); };
    }
}

コードをよく見ると共通化できそうです。Expressionを引数で受け取るようにしてメソッドを共通化します。

public static Func<TArg1, TArg2, TResult> CreateExpression<TArg1, TArg2, TResult>(Func<Expression, Expression, BinaryExpression> body)
{
    ParameterExpression leftParameter = Expression.Parameter(typeof(TArg1), "leftParameter");
    ParameterExpression rightParameter = Expression.Parameter(typeof(TArg2), "rightParameter");
    try
    {
        return Expression.Lambda<Func<TArg1, TArg2, TResult>>(body(leftParameter, rightParameter), leftParameter, rightParameter).Compile();
    }
    catch (InvalidOperationException ex)
    {
        return delegate { throw new InvalidOperationException(ex.Message); };
    }
}

これによりExpressionを渡せばやりたい演算が可能になります。

public static int Sum<T>(IEnamerable<T> list)
{
    var f = CreateExpression<int, int, int>(Expression.Add);//+演算を行う
    var sum = 0;
    foreach (var item in list)
    {
        sum = f(sum, item);
    }
}

次のセクションでは二つの型が違う場合について解説します。

2つの型が違う場合のExpression

2つの型が違う場合のAdd演算としてDateTime型とTimeSpan型のAdd演算があります。

var d1 = new DateTime(2016, 11, 15, 12, 22, 0);
var ts = TimeSpan.FromHours(2);
var d2 = d1 + ts;
//d2 は2016/11/15 14:22:00になっているはず

Expressionを使うと

var d1 = new DateTime(2016, 11, 15, 12, 22, 0);
var ts = TimeSpan.FromHours(2);
var f = CreateExpression<DateTime, TimeSpan, DateTime>(Expression.Add);//+演算を行う
var d2 = f(d1, ts);

という形になります。

今度はInt16とInt32でやってみます。

Int32 x1 = 2;
Int16 x2 = 1;

var f = CreateExpression<Int32, Int16, Int32>(Expression.Add);//+演算を行う
Int32 x3 = f(x1, x2);//例外が発生!

例外が発生します。これはInt32とInt16を足し合わせる演算子が定義されていないためです。しかしInt16からInt32への変換は安全であり、暗黙的なキャストも定義されているのでこれを使用してうまいことやってほしいところです。そこで以下のようにCreateExpressionメソッドを改良します。

public static Func<TArg1, TArg2, TResult> CreateExpression<TArg1, TArg2, TResult>(Func<Expression, Expression, BinaryExpression> body, Boolean canCast)
{
    ParameterExpression leftParameter = Expression.Parameter(typeof(TArg1), "leftParameter");
    ParameterExpression rightParameter = Expression.Parameter(typeof(TArg2), "rightParameter");
    try
    {
        return Expression.Lambda<Func<TArg1, TArg2, TResult>>(body(leftParameter, rightParameter), leftParameter, rightParameter).Compile();
    }
    catch (InvalidOperationException ex)
    {
        if (canCast == false)
        {
            return delegate { throw new InvalidOperationException(ex.Message); };
        }
    }
    try
    {
        if (typeof(TArg1) != typeof(TResult) || typeof(TArg2) != typeof(TResult))
        {
            Expression castLeftParameter = null;
            Expression castRightParameter = null;

            if (typeof(TArg1) == typeof(TResult))
            {
                castLeftParameter = (Expression)leftParameter;
            }
            else
            {
                castLeftParameter = (Expression)Expression.Convert(leftParameter, typeof(TResult));
            }
            if (typeof(TArg2) == typeof(TResult))
            {
                castRightParameter = (Expression)rightParameter;
            }
            else
            {
                castRightParameter = (Expression)Expression.Convert(rightParameter, typeof(TResult));
            }
            return Expression.Lambda<Func<TArg1, TArg2, TResult>>(body(castLeftParameter, castRightParameter)
                , leftParameter, rightParameter).Compile();
        }
    }
    catch (Exception ex)
    {
        return delegate { throw new InvalidOperationException(ex.Message); };
    }
    return delegate { throw new InvalidOperationException(); };
}

これでTResultに変換可能な場合はうまいことやってくれるようになりました。

Expressionの生成処理をキャッシュする

Expression.Lambda<>.CompileによるILコードの動的生成でのFuncの作成はかなり重い処理になり、何回も呼ぶとパフォーマンスが低下します。今回の場合、TArg1, TArg2, TResultごとでキャッシュしておけばよいのでキャッシュ用のクラスを作成します。Dictionaryを使用する方法が思い浮かびますがもっと高速なGenericクラスのstaticプロパティでのキャッシュを行います。

Operator.cs
namespace HigLabo.Core
{
    public class Operator
    {
        public static Boolean HasValue<T>(T value)
        {
            return value != null;
        }
        public static T Negate<T>(T value)
        {
            return Operator<T>.Negate(value);
        }
        public static T Not<T>(T value)
        {
            return Operator<T>.Not(value);
        }
        public static T Or<T>(T value1, T value2)
        {
            return Operator<T>.Or(value1, value2);
        }
        public static T And<T>(T value1, T value2)
        {
            return Operator<T>.And(value1, value2);
        }
        public static T Xor<T>(T value1, T value2)
        {
            return Operator<T>.Xor(value1, value2);
        }
        public static TResult Convert<TFrom, TResult>(TFrom value)
        {
            return Operator<TFrom, TResult>.Convert(value);
        }
        public static T Add<T>(T value1, T value2)
        {
            return Operator<T>.Add(value1, value2);
        }
        public static TArg1 Add<TArg1, TArg2>(TArg1 value1, TArg2 value2)
        {
            return Operator<TArg2, TArg1>.Add(value1, value2);
        }
        public static T Subtract<T>(T value1, T value2)
        {
            return Operator<T>.Subtract(value1, value2);
        }
        public static TArg1 Subtract<TArg1, TArg2>(TArg1 value1, TArg2 value2)
        {
            return Operator<TArg2, TArg1>.Subtract(value1, value2);
        }
        public static T Multiply<T>(T value1, T value2)
        {
            return Operator<T>.Multiply(value1, value2);
        }
        public static TArg1 Multiply<TArg1, TArg2>(TArg1 value1, TArg2 value2)
        {
            return Operator<TArg2, TArg1>.Multiply(value1, value2);
        }
        public static T Divide<T>(T value1, T value2)
        {
            return Operator<T>.Divide(value1, value2);
        }
        public static TArg1 Divide<TArg1, TArg2>(TArg1 value1, TArg2 value2)
        {
            return Operator<TArg2, TArg1>.Divide(value1, value2);
        }
        public static Boolean Equal<T>(T value1, T value2)
        {
            return Operator<T>.Equal(value1, value2);
        }
        public static Boolean NotEqual<T>(T value1, T value2)
        {
            return Operator<T>.NotEqual(value1, value2);
        }
        public static Boolean GreaterThan<T>(T value1, T value2)
        {
            return Operator<T>.GreaterThan(value1, value2);
        }
        public static Boolean LessThan<T>(T value1, T value2)
        {
            return Operator<T>.LessThan(value1, value2);
        }
        public static Boolean GreaterThanOrEqual<T>(T value1, T value2)
        {
            return Operator<T>.GreaterThanOrEqual(value1, value2);
        }
        public static Boolean LessThanOrEqual<T>(T value1, T value2)
        {
            return Operator<T>.LessThanOrEqual(value1, value2);
        }
        public static T Divide<T>(T value, Int32 divisor)
        {
            return Operator<Int32, T>.Divide(value, divisor);
        }

        public static Func<TArg1, TResult> CreateExpression<TArg1, TResult>(Func<Expression, UnaryExpression> body)
        {
            ParameterExpression p = Expression.Parameter(typeof(TArg1), "value");
            try
            {
                return Expression.Lambda<Func<TArg1, TResult>>(body(p), p).Compile();
            }
            catch (Exception ex)
            {
                return delegate { throw new InvalidOperationException(ex.Message); };
            }
        }
        public static Func<TArg1, TArg2, TResult> CreateExpression<TArg1, TArg2, TResult>(Func<Expression, Expression, BinaryExpression> body)
        {
            return CreateExpression<TArg1, TArg2, TResult>(body, false);
        }
        public static Func<TArg1, TArg2, TResult> CreateExpression<TArg1, TArg2, TResult>(Func<Expression, Expression, BinaryExpression> body, Boolean canCast)
        {
            ParameterExpression leftParameter = Expression.Parameter(typeof(TArg1), "leftParameter");
            ParameterExpression rightParameter = Expression.Parameter(typeof(TArg2), "rightParameter");

            try
            {
                return Expression.Lambda<Func<TArg1, TArg2, TResult>>(body(leftParameter, rightParameter), leftParameter, rightParameter).Compile();
            }
            catch (InvalidOperationException ex)
            {
                if (canCast == false)
                {
                    return delegate { throw new InvalidOperationException(ex.Message); };
                }
            }
            try
            {
                if (typeof(TArg1) != typeof(TResult) || typeof(TArg2) != typeof(TResult))
                {
                    Expression castLeftParameter = null;
                    Expression castRightParameter = null;

                    if (typeof(TArg1) == typeof(TResult))
                    {
                        castLeftParameter = (Expression)leftParameter;
                    }
                    else
                    {
                        castLeftParameter = (Expression)Expression.Convert(leftParameter, typeof(TResult));
                    }
                    if (typeof(TArg2) == typeof(TResult))
                    {
                        castRightParameter = (Expression)rightParameter;
                    }
                    else
                    {
                        castRightParameter = (Expression)Expression.Convert(rightParameter, typeof(TResult));
                    }
                    return Expression.Lambda<Func<TArg1, TArg2, TResult>>(body(castLeftParameter, castRightParameter)
                        , leftParameter, rightParameter).Compile();
                }
            }
            catch (Exception ex)
            {
                return delegate { throw new InvalidOperationException(ex.Message); };
            }
            return delegate { throw new InvalidOperationException(); };
        }
    }
    internal static class Operator<T>
    {
        private static readonly Func<T, T> _Negate = null;
        private static readonly Func<T, T> _Not = null;
        private static readonly Func<T, T, T> _Or = null;
        private static readonly Func<T, T, T> _And = null;
        private static readonly Func<T, T, T> _Xor = null;
        private static readonly Func<T, T, T> _Add = null;
        private static readonly Func<T, T, T> _Subtract = null;
        private static readonly Func<T, T, T> _Multiply = null;
        private static readonly Func<T, T, T> _Divide = null;
        private static readonly Func<T, T, Boolean> _Equal = null;
        private static readonly Func<T, T, Boolean> _NotEqual = null;
        private static readonly Func<T, T, Boolean> _GreaterThan = null;
        private static readonly Func<T, T, Boolean> _GreaterThanOrEqual = null;
        private static readonly Func<T, T, Boolean> _LessThan = null;
        private static readonly Func<T, T, Boolean> _LessThanOrEqual = null;

        public static Func<T, T> Negate
        {
            get { return _Negate; }
        }
        public static Func<T, T> Not
        {
            get { return _Not; }
        }
        public static Func<T, T, T> Or
        {
            get { return _Or; }
        }
        public static Func<T, T, T> And
        {
            get { return _And; }
        }
        public static Func<T, T, T> Xor
        {
            get { return _Xor; }
        }

        public static Func<T, T, T> Add
        {
            get { return _Add; }
        }
        public static Func<T, T, T> Subtract
        {
            get { return _Subtract; }
        }
        public static Func<T, T, T> Multiply
        {
            get { return _Multiply; }
        }
        public static Func<T, T, T> Divide
        {
            get { return _Divide; }
        }

        public static Func<T, T, Boolean> Equal
        {
            get { return _Equal; }
        }
        public static Func<T, T, Boolean> NotEqual
        {
            get { return _NotEqual; }
        }
        public static Func<T, T, Boolean> GreaterThan
        {
            get { return _GreaterThan; }
        }
        public static Func<T, T, Boolean> LessThan
        {
            get { return _LessThan; }
        }
        public static Func<T, T, Boolean> GreaterThanOrEqual
        {
            get { return _GreaterThanOrEqual; }
        }
        public static Func<T, T, Boolean> LessThanOrEqual
        {
            get { return _LessThanOrEqual; }
        }

        static Operator()
        {
            _Add = Operator.CreateExpression<T, T, T>(Expression.Add);
            _Subtract = Operator.CreateExpression<T, T, T>(Expression.Subtract);
            _Divide = Operator.CreateExpression<T, T, T>(Expression.Divide);
            _Multiply = Operator.CreateExpression<T, T, T>(Expression.Multiply);

            _GreaterThan = Operator.CreateExpression<T, T, bool>(Expression.GreaterThan);
            _GreaterThanOrEqual = Operator.CreateExpression<T, T, bool>(Expression.GreaterThanOrEqual);
            _LessThan = Operator.CreateExpression<T, T, bool>(Expression.LessThan);
            _LessThanOrEqual = Operator.CreateExpression<T, T, bool>(Expression.LessThanOrEqual);
            _Equal = Operator.CreateExpression<T, T, bool>(Expression.Equal);
            _NotEqual = Operator.CreateExpression<T, T, bool>(Expression.NotEqual);

            _Negate = Operator.CreateExpression<T, T>(Expression.Negate);
            _And = Operator.CreateExpression<T, T, T>(Expression.And);
            _Or = Operator.CreateExpression<T, T, T>(Expression.Or);
            _Not = Operator.CreateExpression<T, T>(Expression.Not);
            _Xor = Operator.CreateExpression<T, T, T>(Expression.ExclusiveOr);
        }
    }
    internal static class Operator<TValue, TResult>
    {
        private static readonly Func<TValue, TResult> _Convert = null;
        private static readonly Func<TResult, TValue, TResult> _Add = null;
        private static readonly Func<TResult, TValue, TResult> _Subtract = null;
        private static readonly Func<TResult, TValue, TResult> _Multiply = null;
        private static readonly Func<TResult, TValue, TResult> _Divide = null;

        static Operator()
        {
            _Convert = Operator.CreateExpression<TValue, TResult>(body => Expression.Convert(body, typeof(TResult)));
            _Add = Operator.CreateExpression<TResult, TValue, TResult>(Expression.Add, true);
            _Subtract = Operator.CreateExpression<TResult, TValue, TResult>(Expression.Subtract, true);
            _Multiply = Operator.CreateExpression<TResult, TValue, TResult>(Expression.Multiply, true);
            _Divide = Operator.CreateExpression<TResult, TValue, TResult>(Expression.Divide, true);
        }

        public static Func<TValue, TResult> Convert
        {
            get { return _Convert; }
        }
        public static Func<TResult, TValue, TResult> Add
        {
            get { return _Add; }
        }
        public static Func<TResult, TValue, TResult> Subtract
        {
            get { return _Subtract; }
        }
        public static Func<TResult, TValue, TResult> Multiply
        {
            get { return _Multiply; }
        }
        public static Func<TResult, TValue, TResult> Divide
        {
            get { return _Divide; }
        }
    }
}

Dictionaryだと実行時にキーの走査が必要ですが、Operator<...>クラスを定義することでコンパイル時に使用する型の分だけ動的にクラスが生成され、そのクラスを直接呼び出すILコードが生成されます。ですのでこっちのほうが高速です。

staticコンストラクタにコンパイル処理を書いています。このクラスへの初回アクセス時に動的にILコードを生成しキャッシュします。使用するにはこんな感じです。

public static int Sum<T>(IEnamerable<T> list)
{
    var sum = 0;
    foreach (var item in list)
    {
        sum = Operator.Add(sum, item);
    }
}
public static Main()
{
    var n1 = new int[] { 1, 2, 3, 4 };
    //1回目にILコードが生成されキャッシュされる。ILコードの生成の分、遅いはず。
    var result1 = Sum<int>(n1);

    var n2 = new int[] { 1, 2, 3, 4, 5 };
    //キャッシュされたFuncが使用される。直接ハードコーディングした場合とほぼ同じ速度が出るはず。
    var result2 = Sum<int>(n2);

    var n3 = new Int64[] { 1, 2, 3, 4 };
    //Int64なのでInt64用のILコードが生成されキャッシュされる。
    var result3 = Sum<int>(n3);
}

これで四則演算するGenericなメソッドを作って活用できるようになりました。

using staticでより使いやすくする

using staticを利用すればより簡潔に記述できます。

using static HigLabo.Core.Operator;

public static int Sum<T>(IEnamerable<T> list)
{
    var sum = 0;
    foreach (var item in list)
    {
        sum = Add(sum, item);
    }
}

Operator.csをコピーすれば動きます。ご利用は自由にどうぞ。
Operator.cs以外は脳内デバッグしかしてないのでうまく動かないよとか認識違いなどありましたらまたコメントください。修正します。

以上です。

6
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
7