1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

C#で小さい複素数の積を大量に計算したかったのでSIMD実装とか試してみた

Last updated at Posted at 2021-10-26

まえがき

SDR(ソフトウェア無線)のデータをハンドリングするに当たり、ある程度高速な複素数の配列の積が必要になった。最初は動作確認用としてSystem.Numerics.Complex(doubleベース)を使用していたが、どの程度速くなるのかを確認するためにいくつか試してみた。

System.Numerics.Complex(doubleベース)とMathNet.Numerics.Complex32(floatベース)、それから整数演算(後述)の3種類を、それぞれ単純なforループとSIMD実装の2種類、計6種類を試した。

速度

同じ環境でも実行するたびに結果が変わる。相対的な尺度として参考程度に。

  • doubleベース
    • for 80million毎秒
    • SIMD 145million毎秒(約1.8倍)
  • floatベース
    • for 79.8million毎秒
    • SIMD 280million毎秒(約3.5倍、double for比約3.5倍)
  • shortベース
    • for 110million毎秒
    • SIMD 565million毎秒(約5.1倍、double for比約7.0倍)

doubleベースとfloatベースのforは概ね同じ速度、doubleベースのSIMD化は2倍弱、floatベースのSIMD化はdoubleベースSIMDの2倍程度、といった感じ。なんとなく納得感のある結果。

shortベースのforはdoubleベースのforに比べて3割増し程度で速い(メモリ帯域幅で有利?)。
shortベースのSIMDはdoubleベースのSIMDに比べて4倍弱速い。shortは16bit、doubleは64bitなので、単純に4倍のデータを同時に計算できるから、妥当な気がする。

トータルでは、doubleベースのforに比べてshortベースのSIMDは7倍程度速くなっている。もう少し劇的な高速化を期待していただけに、少し残念。
古いPCなので、最近のCPU/RAMならもっと速く処理できるはず。実際の処理ではマルチスレッドで処理できる場合もあるので、それだけでも単純に数倍程度速くなる。

ソース

整数演算

leftはbyte(+128のオフセットあり)の実部・虚部を交互に並べた配列、rightはsbyte(オフセットなし)の実部・虚部をushort 1個に詰めた配列(下位バイトが実部、上位バイトが虚部)を受け取り、shortの複素数を返すタイプ(具体的な動作はfor実装を参照)。
今回はleftをバイナリファイルのフォーマットに合わせ、rightはLUT参照で基準信号を作りやすいように、という理由でこのような実装になった。

/// <summary>
/// Multiplies two arrays of 8-bit complex samples element by element (scalar implementation).
/// </summary>
/// <param name="left">
/// Interleaved samples: <c>left[2*i]</c> is the real part and <c>left[2*i+1]</c> the imaginary part
/// of element <c>i</c>, each stored with a +128 offset.
/// </param>
/// <param name="right">
/// One packed sample per element: the low byte is the real part and the high byte the imaginary
/// part, each reinterpreted as a signed byte (no offset).
/// </param>
/// <param name="result">Receives the products; must have the same length as <paramref name="right"/>.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
/// <exception cref="ArgumentException">The array lengths are inconsistent.</exception>
static void Multiply_for(byte[] left, ushort[] right, (short Re, short Im)[] result)
{
    if (left is null) { throw new ArgumentNullException(nameof(left)); }
    if (right is null) { throw new ArgumentNullException(nameof(right)); }
    if (result is null) { throw new ArgumentNullException(nameof(result)); }

    if (left.Length != right.Length * 2 ||
        result.Length != right.Length)
    {
        throw new ArgumentException(
            "left must hold two bytes per element of right, and result must have the same length as right.");
    }

    // The narrowing casts below are meant to truncate/wrap, exactly like the SIMD variant does,
    // so make that independent of the project's checked/unchecked compiler setting.
    unchecked
    {
        for (int i = 0; i < right.Length; i++)
        {
            int a = left[i * 2 + 0] - 128;      // real part of left, offset removed
            int b = left[i * 2 + 1] - 128;      // imaginary part of left, offset removed
            int c = (sbyte)right[i];            // low byte  -> real part of right
            int d = (sbyte)(right[i] >> 8);     // high byte -> imaginary part of right

            // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
            result[i].Re = (short)(a * c - b * d);
            result[i].Im = (short)(a * d + b * c);
        }
    }
}

/// <summary>
/// Multiplies two arrays of 8-bit complex samples element by element, using AVX2 when the CPU
/// supports it and a scalar loop otherwise (and for the remaining tail elements).
/// </summary>
/// <param name="left">
/// Interleaved samples: <c>left[2*i]</c> is the real part and <c>left[2*i+1]</c> the imaginary part
/// of element <c>i</c>, each stored with a +128 offset.
/// </param>
/// <param name="right">
/// One packed sample per element, read through a signed-byte pointer: first byte is the real part,
/// second byte the imaginary part (no offset).
/// </param>
/// <param name="result">Receives the products; must have the same length as <paramref name="right"/>.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
/// <exception cref="ArgumentException">The array lengths are inconsistent.</exception>
static void Multiply_SIMD(byte[] left, ushort[] right, (short Re, short Im)[] result)
{
    if (left is null) { throw new ArgumentNullException(nameof(left)); }
    if (right is null) { throw new ArgumentNullException(nameof(right)); }
    if (result is null) { throw new ArgumentNullException(nameof(result)); }

    if (left.Length != right.Length * 2 ||
        result.Length != right.Length)
    {
        throw new ArgumentException(
            "left must hold two bytes per element of right, and result must have the same length as right.");
    }

    unsafe
    {
        // NOTE(review): the pointer casts below treat each ushort as (Re byte, Im byte) in memory
        // (i.e. little-endian) and each (short Re, short Im) tuple as two consecutive shorts with
        // Re first. ValueTuple does not promise a sequential layout - confirm on the target runtime.
        fixed (byte* ptr1 = left)
        fixed (ushort* ptr2 = right)
        fixed ((short Re, short Im)* ptr3 = result)
        {
            byte* p1 = ptr1;
            sbyte* p2 = (sbyte*)ptr2;
            short* p3 = (short*)ptr3;

            // i counts scalar components (2 per complex element) in all three buffers.
            int i = 0;

            // Without this guard the intrinsics throw PlatformNotSupportedException on CPUs
            // lacking AVX2; the scalar loop below then handles the whole input instead.
            if (Avx2.IsSupported)
            {
                var sub = Vector256.Create((short)128);

                // 32 components = 16 complex elements per iteration.
                for (; i + 32 <= result.Length * 2; i += 32)
                {
                    // Widen 16 bytes each to 16-bit lanes; left is zero-extended then de-offset,
                    // right is sign-extended.
                    var a = Avx2.Subtract(Avx2.ConvertToVector256Int16(p1 + i), sub);
                    var b = Avx2.Subtract(Avx2.ConvertToVector256Int16(p1 + i + 16), sub);
                    var c = Avx2.ConvertToVector256Int16(p2 + i);
                    var d = Avx2.ConvertToVector256Int16(p2 + i + 16);

                    // Real parts: re*re - im*im of adjacent pairs.
                    var foo = Avx2.HorizontalSubtract(
                        Avx2.MultiplyLow(a, c),
                        Avx2.MultiplyLow(b, d));
                    // Imaginary parts: swap re/im of the right operand, then re*im + im*re.
                    var bar = Avx2.HorizontalAdd(
                        Avx2.MultiplyLow(a, Avx2.ShuffleHigh(Avx2.ShuffleLow(c, 0xB1), 0xB1)),
                        Avx2.MultiplyLow(b, Avx2.ShuffleHigh(Avx2.ShuffleLow(d, 0xB1), 0xB1)));
                    // control 0xB1: _MM_SHUFFLE(2, 3, 0, 1)

                    // The horizontal ops work per 128-bit lane, so foo/bar are permuted;
                    // interleaving low/high halves restores (Re, Im) pairs in element order.
                    Avx.Store(p3 + i, Avx2.UnpackLow(foo, bar));
                    Avx.Store(p3 + i + 16, Avx2.UnpackHigh(foo, bar));
                }
            }

            // Scalar tail (or full scalar fallback). Wrap on overflow like the vector path does,
            // regardless of the project's checked/unchecked setting.
            unchecked
            {
                for (; i < result.Length * 2; i += 2)
                {
                    var a = p1[i + 0] - 128;
                    var b = p1[i + 1] - 128;
                    var c = p2[i + 0];
                    var d = p2[i + 1];
                    p3[i + 0] = (short)(a * c - b * d);
                    p3[i + 1] = (short)(a * d + b * c);
                }
            }
        }
    }
}

float(MathNet.Numerics.Complex32)

/// <summary>
/// Multiplies two single-precision complex arrays element by element using the
/// <c>Complex32</c> multiplication operator.
/// </summary>
/// <param name="left">Left operands.</param>
/// <param name="right">Right operands; same length as <paramref name="left"/>.</param>
/// <param name="result">Receives <c>left[i] * right[i]</c>; same length as <paramref name="left"/>.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
/// <exception cref="ArgumentException">The three arrays do not have the same length.</exception>
static void Multiply_for(Complex32[] left, Complex32[] right, Complex32[] result)
{
    if (left is null) { throw new ArgumentNullException(nameof(left)); }
    if (right is null) { throw new ArgumentNullException(nameof(right)); }
    if (result is null) { throw new ArgumentNullException(nameof(result)); }

    if (left.Length != right.Length ||
        left.Length != result.Length)
    {
        throw new ArgumentException("left, right and result must all have the same length.");
    }

    for (int i = 0; i < result.Length; i++)
    {
        result[i] = left[i] * right[i];
    }
}

/// <summary>
/// Multiplies two single-precision complex arrays element by element, using AVX when the CPU
/// supports it and a scalar loop otherwise (and for the remaining tail elements).
/// </summary>
/// <param name="left">Left operands.</param>
/// <param name="right">Right operands; same length as <paramref name="left"/>.</param>
/// <param name="result">Receives the products; same length as <paramref name="left"/>.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
/// <exception cref="ArgumentException">The three arrays do not have the same length.</exception>
static void Multiply_SIMD(Complex32[] left, Complex32[] right, Complex32[] result)
{
    if (left is null) { throw new ArgumentNullException(nameof(left)); }
    if (right is null) { throw new ArgumentNullException(nameof(right)); }
    if (result is null) { throw new ArgumentNullException(nameof(result)); }

    if (left.Length != right.Length ||
        left.Length != result.Length)
    {
        throw new ArgumentException("left, right and result must all have the same length.");
    }

    unsafe
    {
        fixed (Complex32* ptr1 = left)
        fixed (Complex32* ptr2 = right)
        fixed (Complex32* ptr3 = result)
        {
            // NOTE(review): treats each Complex32 as two consecutive floats, real part first -
            // confirm against the Complex32 struct layout.
            float* p1 = (float*)ptr1;
            float* p2 = (float*)ptr2;
            float* p3 = (float*)ptr3;

            // i counts floats (2 per complex element).
            int i = 0;

            // Without this guard the intrinsics throw PlatformNotSupportedException on CPUs
            // lacking AVX; the scalar loop below then handles the whole input instead.
            if (Avx.IsSupported)
            {
                // 16 floats = 8 complex elements per iteration.
                for (; i + 16 <= result.Length * 2; i += 16)
                {
                    var a = Avx.LoadVector256(p1 + i);
                    var b = Avx.LoadVector256(p1 + i + 8);
                    var c = Avx.LoadVector256(p2 + i);
                    var d = Avx.LoadVector256(p2 + i + 8);

                    // Real parts: re*re - im*im of adjacent pairs.
                    var foo = Avx.HorizontalSubtract(
                            Avx.Multiply(a, c),
                            Avx.Multiply(b, d));
                    // Imaginary parts: swap re/im of the right operand, then re*im + im*re.
                    var bar = Avx.HorizontalAdd(
                            Avx.Multiply(a, Avx.Shuffle(c, c, 0xB1)),
                            Avx.Multiply(b, Avx.Shuffle(d, d, 0xB1)));
                    // control 0xB1: _MM_SHUFFLE(2, 3, 0, 1)

                    // The horizontal ops work per 128-bit lane, so foo/bar are permuted;
                    // interleaving low/high halves restores (Re, Im) pairs in element order.
                    Avx.Store(p3 + i, Avx.UnpackLow(foo, bar));
                    Avx.Store(p3 + i + 8, Avx.UnpackHigh(foo, bar));
                }
            }

            // Scalar tail (or full scalar fallback).
            for (; i < result.Length * 2; i += 2)
            {
                var a = p1[i + 0];
                var b = p1[i + 1];
                var c = p2[i + 0];
                var d = p2[i + 1];
                p3[i + 0] = a * c - b * d;
                p3[i + 1] = a * d + b * c;
            }
        }
    }
}

double(System.Numerics.Complex)

/// <summary>
/// Multiplies two double-precision complex arrays element by element using the
/// <see cref="Complex"/> multiplication operator.
/// </summary>
/// <param name="left">Left operands.</param>
/// <param name="right">Right operands; same length as <paramref name="left"/>.</param>
/// <param name="result">Receives <c>left[i] * right[i]</c>; same length as <paramref name="left"/>.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
/// <exception cref="ArgumentException">The three arrays do not have the same length.</exception>
static void Multiply_for(Complex[] left, Complex[] right, Complex[] result)
{
    if (left is null) { throw new ArgumentNullException(nameof(left)); }
    if (right is null) { throw new ArgumentNullException(nameof(right)); }
    if (result is null) { throw new ArgumentNullException(nameof(result)); }

    if (left.Length != right.Length ||
        left.Length != result.Length)
    {
        throw new ArgumentException("left, right and result must all have the same length.");
    }

    for (int i = 0; i < result.Length; i++)
    {
        result[i] = left[i] * right[i];
    }
}

/// <summary>
/// Multiplies two double-precision complex arrays element by element, using AVX when the CPU
/// supports it and a scalar loop otherwise (and for the remaining tail elements).
/// </summary>
/// <param name="left">Left operands.</param>
/// <param name="right">Right operands; same length as <paramref name="left"/>.</param>
/// <param name="result">Receives the products; same length as <paramref name="left"/>.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
/// <exception cref="ArgumentException">The three arrays do not have the same length.</exception>
static void Multiply_SIMD(Complex[] left, Complex[] right, Complex[] result)
{
    if (left is null) { throw new ArgumentNullException(nameof(left)); }
    if (right is null) { throw new ArgumentNullException(nameof(right)); }
    if (result is null) { throw new ArgumentNullException(nameof(result)); }

    if (left.Length != right.Length ||
        left.Length != result.Length)
    {
        throw new ArgumentException("left, right and result must all have the same length.");
    }

    unsafe
    {
        fixed (Complex* ptr1 = left)
        fixed (Complex* ptr2 = right)
        fixed (Complex* ptr3 = result)
        {
            // NOTE(review): treats each Complex as two consecutive doubles, real part first -
            // confirm against the Complex struct layout on the target runtime.
            var p1 = (double*)ptr1;
            var p2 = (double*)ptr2;
            var p3 = (double*)ptr3;

            // i counts doubles (2 per complex element).
            int i = 0;

            // Without this guard the intrinsics throw PlatformNotSupportedException on CPUs
            // lacking AVX; the scalar loop below then handles the whole input instead.
            if (Avx.IsSupported)
            {
                // 8 doubles = 4 complex elements per iteration.
                for (; i + 8 <= result.Length * 2; i += 8)
                {
                    var a = Avx.LoadVector256(p1 + i);
                    var b = Avx.LoadVector256(p1 + i + 4);
                    var c = Avx.LoadVector256(p2 + i);
                    var d = Avx.LoadVector256(p2 + i + 4);

                    // Real parts: re*re - im*im of adjacent pairs.
                    var foo = Avx.HorizontalSubtract(
                        Avx.Multiply(a, c),
                        Avx.Multiply(b, d));
                    // Imaginary parts: control 5 (0b0101) swaps re/im inside each 128-bit lane
                    // of the right operand, then re*im + im*re.
                    var bar = Avx.HorizontalAdd(
                        Avx.Multiply(a, Avx.Shuffle(c, c, 5)),
                        Avx.Multiply(b, Avx.Shuffle(d, d, 5)));

                    // The horizontal ops work per 128-bit lane, so foo/bar are permuted;
                    // interleaving low/high halves restores (Re, Im) pairs in element order.
                    Avx.Store(p3 + i, Avx.UnpackLow(foo, bar));
                    Avx.Store(p3 + i + 4, Avx.UnpackHigh(foo, bar));
                }
            }

            // Scalar tail (or full scalar fallback).
            for (; i < result.Length * 2; i += 2)
            {
                var a = p1[i + 0];
                var b = p1[i + 1];
                var c = p2[i + 0];
                var d = p2[i + 1];
                p3[i + 0] = a * c - b * d;
                p3[i + 1] = a * d + b * c;
            }
        }
    }
}

おまけ:byte配列→複素数配列

  • 32bit
    • for 98.1million毎秒
    • SIMD 554million毎秒
  • 64bit
    • for 93.9million毎秒
    • SIMD 303million毎秒

forはなぜか複素数の積と同じ程度の速度しか出ない(データ量的にも命令数的にももっと速くて良いはず)。SIMDは複素数の積に比べてもかなり速い。

/// <summary>
/// Converts interleaved offset-binary bytes into single-precision complex values (scalar implementation).
/// </summary>
/// <param name="src">
/// Two bytes per element: <c>src[2*i]</c> is the real part and <c>src[2*i+1]</c> the imaginary part,
/// each stored with a +128 offset.
/// </param>
/// <param name="dst">Receives the converted values; <paramref name="src"/> must be exactly twice as long.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
/// <exception cref="ArgumentException"><paramref name="src"/> is not twice as long as <paramref name="dst"/>.</exception>
static void Convert_for(byte[] src, Complex32[] dst)
{
    if (src is null) { throw new ArgumentNullException(nameof(src)); }
    if (dst is null) { throw new ArgumentNullException(nameof(dst)); }

    if (src.Length != dst.Length * 2)
    {
        throw new ArgumentException("src must hold exactly two bytes per element of dst.");
    }

    for (int i = 0; i < dst.Length; i++)
    {
        dst[i] = new Complex32(src[i * 2] - 128, src[i * 2 + 1] - 128);
    }
}

/// <summary>
/// Converts interleaved offset-binary bytes into single-precision complex values, using AVX2 when
/// the CPU supports it and a scalar loop otherwise (and for the remaining tail elements).
/// </summary>
/// <param name="src">
/// Two bytes per element: real part then imaginary part, each stored with a +128 offset.
/// </param>
/// <param name="dst">Receives the converted values; <paramref name="src"/> must be exactly twice as long.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
/// <exception cref="ArgumentException"><paramref name="src"/> is not twice as long as <paramref name="dst"/>.</exception>
static void Convert_SIMD(byte[] src, Complex32[] dst)
{
    if (src is null) { throw new ArgumentNullException(nameof(src)); }
    if (dst is null) { throw new ArgumentNullException(nameof(dst)); }

    if (src.Length != dst.Length * 2)
    {
        throw new ArgumentException("src must hold exactly two bytes per element of dst.");
    }

    unsafe
    {
        fixed (byte* p1 = src)
        fixed (Complex32* p2 = dst)
        {
            // NOTE(review): treats each Complex32 as two consecutive floats, real part first -
            // confirm against the Complex32 struct layout.
            var p3 = (float*)p2;

            // i counts scalar components; one source byte maps to one destination float.
            int i = 0;

            // Without this guard the intrinsics throw PlatformNotSupportedException on CPUs
            // lacking AVX2; the scalar loop below then handles the whole input instead.
            if (Avx2.IsSupported)
            {
                Vector256<int> sub = Vector256.Create(128);

                // 8 bytes -> 8 ints -> remove offset -> 8 floats (4 complex elements).
                for (; i + 8 <= dst.Length * 2; i += 8)
                {
                    Avx.Store(p3 + i,
                        Avx.ConvertToVector256Single(
                            Avx2.Subtract(
                                Avx2.ConvertToVector256Int32(p1 + i), sub)));
                }
            }

            // Scalar tail (or full scalar fallback).
            for (; i < dst.Length * 2; i += 2)
            {
                p3[i] = p1[i] - 128;
                p3[i + 1] = p1[i + 1] - 128;
            }
        }
    }
}

/// <summary>
/// Converts interleaved offset-binary bytes into double-precision complex values (scalar implementation).
/// </summary>
/// <param name="src">
/// Two bytes per element: <c>src[2*i]</c> is the real part and <c>src[2*i+1]</c> the imaginary part,
/// each stored with a +128 offset.
/// </param>
/// <param name="dst">Receives the converted values; <paramref name="src"/> must be exactly twice as long.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
/// <exception cref="ArgumentException"><paramref name="src"/> is not twice as long as <paramref name="dst"/>.</exception>
static void Convert_for(byte[] src, Complex[] dst)
{
    if (src is null) { throw new ArgumentNullException(nameof(src)); }
    if (dst is null) { throw new ArgumentNullException(nameof(dst)); }

    if (src.Length != dst.Length * 2)
    {
        throw new ArgumentException("src must hold exactly two bytes per element of dst.");
    }

    for (int i = 0; i < dst.Length; i++)
    {
        dst[i] = new Complex(src[i * 2] - 128, src[i * 2 + 1] - 128);
    }
}

/// <summary>
/// Converts interleaved offset-binary bytes into double-precision complex values, using
/// SSE4.1/AVX when the CPU supports them and a scalar loop otherwise (and for the remaining
/// tail elements).
/// </summary>
/// <param name="src">
/// Two bytes per element: real part then imaginary part, each stored with a +128 offset.
/// </param>
/// <param name="dst">Receives the converted values; <paramref name="src"/> must be exactly twice as long.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
/// <exception cref="ArgumentException"><paramref name="src"/> is not twice as long as <paramref name="dst"/>.</exception>
static void Convert_SIMD(byte[] src, Complex[] dst)
{
    if (src is null) { throw new ArgumentNullException(nameof(src)); }
    if (dst is null) { throw new ArgumentNullException(nameof(dst)); }

    if (src.Length != dst.Length * 2)
    {
        throw new ArgumentException("src must hold exactly two bytes per element of dst.");
    }

    unsafe
    {
        fixed (byte* p1 = src)
        fixed (Complex* p2 = dst)
        {
            // NOTE(review): treats each Complex as two consecutive doubles, real part first -
            // confirm against the Complex struct layout on the target runtime.
            var p3 = (double*)p2;

            // i counts scalar components; one source byte maps to one destination double.
            int i = 0;

            // Without this guard the intrinsics throw PlatformNotSupportedException on CPUs
            // lacking these instruction sets; the scalar loop below then handles everything.
            if (Avx.IsSupported && Sse41.IsSupported)
            {
                Vector128<int> sub = Vector128.Create(128);

                // 4 bytes -> 4 ints -> remove offset -> 4 doubles (2 complex elements).
                for (; i + 4 <= dst.Length * 2; i += 4)
                {
                    Avx.Store(p3 + i,
                        Avx.ConvertToVector256Double(
                                Sse2.Subtract(
                                    Sse41.ConvertToVector128Int32(p1 + i), sub)));
                }
            }

            // Scalar tail (or full scalar fallback).
            for (; i < dst.Length * 2; i += 2)
            {
                p3[i] = p1[i] - 128;
                p3[i + 1] = p1[i + 1] - 128;
            }
        }
    }
}
1
1
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?