LoginSignup
7
8

More than 1 year has passed since last update.

[C#] Vector<T>で配列の最大値を高速取得

Last updated at Posted at 2022-02-20

(2023/02/09 追記)

記事を書いた 2022/02/20 時点は.NET6による計測で、既に古い内容となっております。.NET7以降はLINQの一部にもSIMDが適用されるなど状況が変わってきています。
SimdLinq のように、より広範囲にSIMD適用するライブラリも登場し、ますますLINQの強みが増したと言えるでしょう。

Vector<T> について

Vector<T>を使うと、Releaseビルドで実行した場合、JITコンパイル時にCPUにマッチしたSIMD命令を吐き出してくれます。アセンブリ言語やSSE、AVX等のSIMD命令を覚えなくても、その恩恵を受ける事が出来てしまう訳ですね。素晴らしい!

SIMDは、雑に説明するとCPUが1命令で複数データをまとめて処理できるような仕組みです。例えば、int型は32bitですが、256bitのSIMD命令なら、int型のデータ8個を1命令で同時に演算出来ます。データ量が多い場合は、かなりの高速化が見込めそうです。

※注意※
ハードウェアアクセラレーションのサポートされるデータ型と演算子の組み合わせは限定されているので、詳細はMicrosoft Docsを確認してください。また、CPUによってSIMD命令の有無があるので、適切な命令が存在しなかった場合、逆に遅くなる事も有り得ます。

計測準備

Vector<T>での最大値取得用に、下記のような拡張メソッドを作成しました。
※decimalのように、Vector<T>に対応していない型を使用すると例外が発生します

using System;
using System.Numerics;
using System.Runtime.InteropServices;

public static class SpanExtentions
{
    public static T VectorMax<T>(this Span<T> span) where T : struct, IComparable<T>
    {
        var spanVec = new Span<Vector<T>> { };
        T result = span[0];

        // Vector<T>で最大値算出を行う。
        // Vector<T>.Countは、TをCPUのレジスタで同時処理可能な要素数になる。
        if (span.Length >= Vector<T>.Count)
        {
            spanVec = MemoryMarshal.Cast<T, Vector<T>>(span);
            var vecMax = spanVec[0];
            foreach (var vec in spanVec)
            {
                vecMax = Vector.Max(vecMax, vec);
            }

            for (int i = 0; i < Vector<T>.Count; i++)
            {
                if (result.CompareTo(vecMax[i]) < 0)
                {
                    result = vecMax[i];
                }
            }
        }

        // Vector<T>で処理出来なかった余りの要素との最大値を取る。
        // MemoryMarshal.Cast<T, Vector<T>>で取得したSpanは、
        // Vector<T>.Countのサイズで処理できない要素が含まれないので、
        // このような処理が必要になる。
        span = span.Slice(MemoryMarshal.Cast<Vector<T>, T>(spanVec).Length);
        foreach (var value in span)
        {
            if (result.CompareTo(value) < 0)
            {
                result = value;
            }
        }

        return result;
    }
}

次に、BenchMarkDotNetで、ベンチマーク用のクラスを作成します。
ランダムなデータの配列を作成し、
・ForEachMax1(foreach+IComparable<T>判定)
・ForEachMax2(foreach+Comparer<T>判定)
・VectorMax(Vector<T>使用)
・LinqMax(LINQ使用)
をint,long,double型について、それぞれ計測を行いました。

using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Running;
using System;
using System.Linq;
using System.Collections.Generic;

[ShortRunJob]
[MemoryDiagnoser(false)]
public class MaxBenchmark<T> where T : struct, IComparable<T>
{
    public const int N = 1000015;
    private T[] _Array = Array.Empty<T>();

    public void MakeData()
    {
        var rand = new Random();
        _Array = new T[N];
        for (int i = 0; i < _Array.Length; i++)
        {
            _Array[i] = (T)(dynamic)((double)rand.Next() * rand.NextDouble());
        }
    }

    [GlobalSetup]
    public void Setup()
    {
        MakeData();
    }

    [Benchmark]
    public T ForEachMax1()
    {
        T result = _Array[0];
        foreach (var i in _Array)
        {
            if (result.CompareTo(i) < 0)
            {
                result = i;
            }
        }
        return result;
    }

    [Benchmark]
    public T ForEachMax2()
    {
        T result = _Array[0];
        var comparer = Comparer<T>.Default;
        foreach (var i in _Array)
        {
            if (comparer.Compare(result, i) < 0)
            {
                result = i;
            }
        }
        return result;
    }

    [Benchmark]
    public T VectorMax() => _Array.AsSpan().VectorMax();

    [Benchmark]
    public T LinqMax() => _Array.Max();

    public static void Test()
    {
        var b = new MaxBenchmark<T>();
        b.Setup();
        Console.WriteLine(b.ForEachMax1());
        Console.WriteLine(b.ForEachMax2());
        Console.WriteLine(b.VectorMax());
        Console.WriteLine(b.LinqMax());
    }

    public static void Run() => BenchmarkRunner.Run<MaxBenchmark<T>>();
}

public class Program
{
    public static void Main()
    { 
        MaxBenchmark<int>.Run();
        MaxBenchmark<long>.Run();
        MaxBenchmark<double>.Run();
    }
}

計測結果

【計測環境】
BenchmarkDotNet=v0.13.1, OS=Windows 10.0.19043.1526 (21H1/May2021Update)
Intel Core i5-6600K CPU 3.50GHz (Skylake), 1 CPU, 4 logical and 4 physical cores
.NET SDK=6.0.101

【int】LINQ比 50.5倍高速化

Method Mean Error StdDev Allocated
ForEachMax1 1,026.31 μs 39.935 μs 2.189 μs 1 B
ForEachMax2 1,813.47 μs 107.801 μs 5.909 μs 1 B
VectorMax 97.94 μs 8.114 μs 0.445 μs -
LinqMax 4,949.57 μs 2,082.960 μs 114.174 μs 36 B

【long】LINQ比 9.8倍高速化

Method Mean Error StdDev Allocated
ForEachMax1 1,098.0 μs 93.00 μs 5.10 μs 1 B
ForEachMax2 2,073.5 μs 399.85 μs 21.92 μs 2 B
VectorMax 517.2 μs 318.72 μs 17.47 μs -
LinqMax 5,068.5 μs 1,307.75 μs 71.68 μs 36 B

【double】LINQ比 16.9倍高速化

Method Mean Error StdDev Allocated
ForEachMax1 2,658.5 μs 703.3 μs 38.55 μs 2 B
ForEachMax2 2,386.6 μs 634.7 μs 34.79 μs 3 B
VectorMax 383.8 μs 152.3 μs 8.35 μs -
LinqMax 6,499.0 μs 2,896.7 μs 158.78 μs 36 B

型によって結構差が出ています。longだと効果はいまいちでした。
ジェネリック型の比較は、Comparer<T>よりIComparable<T>を使用した方が概ね速いようですが、何故かdoubleだと逆転していますね。
LINQの遅さは、IEnumerableの走査の遅さによるものではないかなと思われます。

LINQは遅いからやめたほうがいい?

このようなサンプルを投稿しておいてなんですが、通常は汎用性とメンテナンス性の面から、記述が簡潔で済むLINQを使用した方が良いでしょう。通常のプログラムでは、このようなちょっとした計算より、ネットワークやファイル・DBアクセス等のIO周りがボトルネックになる事の方が遥かに多い筈です。(そもそもDBから集計済みの結果持ってくる事の方が多いし…)
ハッシュ計算・画像処理のように、データ量に対して単純計算が多い分野では、Vector<T>を使うと抽象化と高速化を両立出来るパターンがありそうなので、速度を重視したライブラリを作る際に役に立つかもしれません。
本当に最適化が必要な場面では、まずボトルネックの分析が重要になります。きちんと処理を切り分けして速度計測を行い、目標の性能を出すために何をするべきか、最適化する対象を見誤ってはいけません。

アセンブリコードを覗く

SharpLabで、実際のアセンブリコードを覗いてみました。

intのアセンブリコード
SpanExtentions.VectorMax[[System.Int32, System.Private.CoreLib]](System.Span`1<Int32>)
    L0000: sub rsp, 0x48
    L0004: vzeroupper
    L0007: xor eax, eax
    L0009: mov [rsp+0x28], rax
    L000e: vxorps xmm4, xmm4, xmm4
    L0012: vmovdqa [rsp+0x30], xmm4
    L0018: mov [rsp+0x40], rax
    L001d: mov rax, [rcx]
    L0020: mov edx, [rcx+8]
    L0023: xor ecx, ecx
    L0025: test edx, edx
    L0027: je L0100
    L002d: mov r8d, [rax]
    L0030: cmp edx, 8
    L0033: jl short L00a2
    L0035: mov ecx, edx
    L0037: shl rcx, 2
    L003b: shr rcx, 5
    L003f: cmp rcx, 0x7fffffff
    L0046: ja L00f5
    L004c: test ecx, ecx
    L004e: je L0100
    L0054: vmovupd ymm0, [rax]
    L0058: xor r9d, r9d
    L005b: test ecx, ecx
    L005d: jle short L0079
    L005f: movsxd r10, r9d
    L0062: shl r10, 5
    L0066: vmovupd ymm1, [rax+r10]
    L006c: vpmaxsd ymm0, ymm0, ymm1
    L0071: inc r9d
    L0074: cmp r9d, ecx
    L0077: jl short L005f
    L0079: xor r9d, r9d
    L007c: vmovupd [rsp+0x28], ymm0
    L0082: mov r10d, [rsp+r9*4+0x28]
    L0087: mov r11d, r10d
    L008a: cmp r8d, r11d
    L008d: jl short L0096
    L008f: cmp r8d, r11d
    L0092: jg short L0099
    L0094: jmp short L0099
    L0096: mov r8d, r10d
    L0099: inc r9d
    L009c: cmp r9d, 8
    L00a0: jl short L007c
    L00a2: mov ecx, ecx
    L00a4: shl rcx, 5
    L00a8: shr rcx, 2
    L00ac: cmp rcx, 0x7fffffff
    L00b3: ja short L00f5
    L00b5: cmp ecx, edx
    L00b7: ja short L00fa
    L00b9: sub edx, ecx
    L00bb: mov ecx, ecx
    L00bd: lea rax, [rax+rcx*4]
    L00c1: xor r9d, r9d
    L00c4: test edx, edx
    L00c6: jle short L00ea
    L00c8: cmp r9d, edx
    L00cb: jae short L0100
    L00cd: movsxd rcx, r9d
    L00d0: mov ecx, [rax+rcx*4]
    L00d3: cmp r8d, ecx
    L00d6: jl short L00df
    L00d8: cmp r8d, ecx
    L00db: jg short L00e2
    L00dd: jmp short L00e2
    L00df: mov r8d, ecx
    L00e2: inc r9d
    L00e5: cmp r9d, edx
    L00e8: jl short L00c8
    L00ea: mov eax, r8d
    L00ed: vzeroupper
    L00f0: add rsp, 0x48
    L00f4: ret
    L00f5: call 0x00007ff93bb8e7b0
    L00fa: call 0x00007ff8dc135778
    L00ff: int3
    L0100: call 0x00007ff93bb8ec10
    L0105: int3

なるほど判らん!重要そうな部分だけ摘んで見ていきます。
vpmaxsdとか、それっぽい名前の命令が出てきましたね。きちんとSIMD命令が出力されているようです。
PMAXSD - Packed MAXimum Signed Dword

longのアセンブリコード
SpanExtentions.VectorMax[[System.Int64, System.Private.CoreLib]](System.Span`1<Int64>)
    L0000: sub rsp, 0x28
    L0004: vzeroupper
    L0007: mov rax, [rcx]
    L000a: mov edx, [rcx+8]
    L000d: xor ecx, ecx
    L000f: test edx, edx
    L0011: je L0144
    L0017: mov r8, [rax]
    L001a: cmp edx, 4
    L001d: jl L00e5
    L0023: mov ecx, edx
    L0025: shl rcx, 3
    L0029: shr rcx, 5
    L002d: cmp rcx, 0x7fffffff
    L0034: ja L0139
    L003a: test ecx, ecx
    L003c: je L0144
    L0042: vmovupd ymm0, [rax]
    L0046: xor r9d, r9d
    L0049: test ecx, ecx
    L004b: jle short L0073
    L004d: movsxd r10, r9d
    L0050: shl r10, 5
    L0054: vmovupd ymm1, [rax+r10]
    L005a: vpcmpgtq ymm2, ymm0, ymm1
    L005f: vpand ymm0, ymm0, ymm2
    L0063: vpandn ymm1, ymm2, ymm1
    L0067: vpor ymm0, ymm0, ymm1
    L006b: inc r9d
    L006e: cmp r9d, ecx
    L0071: jl short L004d
    L0073: vmovaps xmm1, xmm0
    L0077: vmovq r9, xmm1
    L007c: mov r10, r9
    L007f: cmp r8, r10
    L0082: jl short L008b
    L0084: cmp r8, r10
    L0087: jle short L008e
    L0089: jmp short L008e
    L008b: mov r8, r9
    L008e: vmovaps xmm1, xmm0
    L0092: vpextrq r9, xmm1, 1
    L0098: mov r10, r9
    L009b: cmp r8, r10
    L009e: jl short L00a7
    L00a0: cmp r8, r10
    L00a3: jle short L00aa
    L00a5: jmp short L00aa
    L00a7: mov r8, r9
    L00aa: vextractf128 xmm1, ymm0, 1
    L00b0: vmovq r9, xmm1
    L00b5: mov r10, r9
    L00b8: cmp r8, r10
    L00bb: jl short L00c4
    L00bd: cmp r8, r10
    L00c0: jle short L00c7
    L00c2: jmp short L00c7
    L00c4: mov r8, r9
    L00c7: vextractf128 xmm0, ymm0, 1
    L00cd: vpextrq r9, xmm0, 1
    L00d3: mov r10, r9
    L00d6: cmp r8, r10
    L00d9: jl short L00e2
    L00db: cmp r8, r10
    L00de: jle short L00e5
    L00e0: jmp short L00e5
    L00e2: mov r8, r9
    L00e5: mov ecx, ecx
    L00e7: shl rcx, 5
    L00eb: shr rcx, 3
    L00ef: cmp rcx, 0x7fffffff
    L00f6: ja short L0139
    L00f8: cmp ecx, edx
    L00fa: ja short L013e
    L00fc: sub edx, ecx
    L00fe: mov ecx, ecx
    L0100: lea rax, [rax+rcx*8]
    L0104: xor r9d, r9d
    L0107: test edx, edx
    L0109: jle short L012e
    L010b: cmp r9d, edx
    L010e: jae short L0144
    L0110: movsxd rcx, r9d
    L0113: mov rcx, [rax+rcx*8]
    L0117: cmp r8, rcx
    L011a: jl short L0123
    L011c: cmp r8, rcx
    L011f: jle short L0126
    L0121: jmp short L0126
    L0123: mov r8, rcx
    L0126: inc r9d
    L0129: cmp r9d, edx
    L012c: jl short L010b
    L012e: mov rax, r8
    L0131: vzeroupper
    L0134: add rsp, 0x28
    L0138: ret
    L0139: call 0x00007ff93bb8e7b0
    L013e: call 0x00007ff8dc135778
    L0143: int3
    L0144: call 0x00007ff93bb8ec10
    L0149: int3

longだとvpmaxsdが消滅し、vpcmpgtqで比較して、いくつか論理演算をしてますね。intよりやや複雑なコードが出力されているのが伺えます。
PCMPGTQ - Packed CoMPare Greater Than Qword

7
8
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
8