LoginSignup
3

More than 3 years have passed since last update.

Metal Performance Shadersで行列乗算

Posted at

Metal Performance Shadersで行列乗算

Metal Performance Shadersというやつを使うと行列演算等をApple製品のGPUで行えるらしいです。というわけでObjective-Cから行列乗算を呼び出してみました。

行列はrow majorのみに対応しているようです(column majorのみに対応しているcuBLASとは対照的です)。浮動小数点数の精度は単精度と半精度のみで、倍精度には対応していないようです。

// clang -fobjc-arc mps-test.m -lobjc -framework Metal -framework MetalPerformanceShaders -framework CoreGraphics
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <stdio.h>

/// Multiplies a 3x3 matrix A by a 3x4 matrix B on the GPU using
/// MPSMatrixMultiplication, then prints A, B and the product C.
///
/// Each matrix is printed twice: once as the flat memory sequence and once
/// row by row, to show that the C arrays are laid out row-major.
///
/// @return 0 on success; 1 if no usable Metal device is available, a Metal
///         object could not be created, or the command buffer did not
///         complete successfully.
int main(void)
{
    enum { kRows = 3, kInner = 3, kCols = 4 };

    float A[kRows][kInner] = {
        {1.0, 2.0, 3.0},
        {2.0, 3.0, 4.0},
        {4.0, 5.0, 6.0},
    };
    float B[kInner][kCols] = {
        {1.0, 2.0, 3.0, 2.0},
        {2.0, -2.0, 4.0, 0.0},
        {1.0, 5.0, 6.0, -1.0},
    };
    float C[kRows][kCols];

    // Dump A: flat view first, then by row (element (i, j) is at index kInner * i + j).
    printf("A: {");
    for (size_t i = 0; i < kRows * kInner; ++i) {
        printf("%g, ", (&A[0][0])[i]);
    }
    puts("}");
    for (size_t i = 0; i < kRows; ++i) {
        for (size_t j = 0; j < kInner; ++j) {
            printf("%g, ", A[i][j]);
        }
        puts("");
    }
    puts("---");

    // Dump B the same way (element (j, k) is at index kCols * j + k).
    printf("sizeof(B[0]) = %zu\n", sizeof(B[0]));
    printf("B: {");
    for (size_t i = 0; i < kInner * kCols; ++i) {
        printf("%g, ", (&B[0][0])[i]);
    }
    puts("}");
    for (size_t j = 0; j < kInner; ++j) {
        for (size_t k = 0; k < kCols; ++k) {
            printf("%g, ", B[j][k]);
        }
        puts("");
    }
    puts("---");

    @autoreleasepool {
        // Note from the original author: this returns nil unless the program
        // is linked with -framework CoreGraphics.
        id <MTLDevice> device = MTLCreateSystemDefaultDevice();
        if (device == nil) {
            puts("MTLCreateSystemDefaultDevice failed");
            return 1;
        }
        printf("device name: %s\n", [[device name] UTF8String]);
        if (!MPSSupportsMTLDevice(device)) {
            puts("Metal Performance Shaders does not support this metal device.");
            return 1;
        }
        id <MTLCommandQueue> commandQueue = [device newCommandQueue];
        if (commandQueue == nil) {
            puts("newCommandQueue failed");
            return 1;
        }

        // Buffers: A and B carry the inputs from the CPU to the GPU; C is
        // read back by the CPU once the computation has completed.
        id <MTLBuffer> bufA = [device newBufferWithBytes:A
                                                  length:sizeof(A)
                                                 options:MTLResourceStorageModeManaged];
        id <MTLBuffer> bufB = [device newBufferWithBytes:B
                                                  length:sizeof(B)
                                                 options:MTLResourceStorageModeManaged];
        id <MTLBuffer> bufC = [device newBufferWithLength:sizeof(C)
                                                  options:MTLResourceStorageModeShared];
        if (bufA == nil || bufB == nil || bufC == nil) {
            puts("MTLBuffer allocation failed");
            return 1;
        }

        // Matrix views over the buffers; rowBytes is the size of one row of
        // the corresponding C array, so the layouts match exactly.
        MPSMatrixDescriptor *descA = [MPSMatrixDescriptor matrixDescriptorWithRows:kRows
                                                                           columns:kInner
                                                                          rowBytes:sizeof(A[0])
                                                                          dataType:MPSDataTypeFloat32];
        MPSMatrixDescriptor *descB = [MPSMatrixDescriptor matrixDescriptorWithRows:kInner
                                                                           columns:kCols
                                                                          rowBytes:sizeof(B[0])
                                                                          dataType:MPSDataTypeFloat32];
        MPSMatrixDescriptor *descC = [MPSMatrixDescriptor matrixDescriptorWithRows:kRows
                                                                           columns:kCols
                                                                          rowBytes:sizeof(C[0])
                                                                          dataType:MPSDataTypeFloat32];
        MPSMatrix *matA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
        MPSMatrix *matB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
        MPSMatrix *matC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];

        id <MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
        if (commandBuffer == nil) {
            puts("commandBuffer creation failed");
            return 1;
        }

        // A BLAS-like initializer also exists:
        // initWithDevice:transposeLeft:transposeRight:resultRows:resultColumns:interiorColumns:alpha:beta:
        MPSMatrixMultiplication *mmul = [[MPSMatrixMultiplication alloc] initWithDevice:device
                                                                             resultRows:kRows
                                                                          resultColumns:kCols
                                                                        interiorColumns:kInner];
        [mmul encodeToCommandBuffer:commandBuffer leftMatrix:matA rightMatrix:matB resultMatrix:matC];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];

        // Do not read the result unless the GPU work actually succeeded;
        // otherwise bufC holds undefined data.
        if ([commandBuffer status] != MTLCommandBufferStatusCompleted) {
            NSError *error = [commandBuffer error];
            printf("command buffer failed: %s\n",
                   error != nil ? [[error localizedDescription] UTF8String] : "unknown error");
            return 1;
        }

        // The result is now expected to be readable through [bufC contents].
        memcpy(C, [bufC contents], sizeof(C));
    }

    // Show the result (element (i, k) is at index kCols * i + k).
    puts("---");
    printf("C: {");
    for (size_t i = 0; i < kRows * kCols; ++i) {
        printf("%g, ", (&C[0][0])[i]);
    }
    puts("}");
    for (size_t i = 0; i < kRows; ++i) {
        for (size_t k = 0; k < kCols; ++k) {
            printf("%g, ", C[i][k]);
        }
        puts("");
    }
    return 0;
}

実行結果(Intel Mac):

A: {1, 2, 3, 2, 3, 4, 4, 5, 6, }
1, 2, 3, 
2, 3, 4, 
4, 5, 6, 
---
sizeof(B[0]) = 16
B: {1, 2, 3, 2, 2, -2, 4, 0, 1, 5, 6, -1, }
1, 2, 3, 2, 
2, -2, 4, 0, 
1, 5, 6, -1, 
---
device name: Intel(R) Iris(TM) Plus Graphics
---
C: {8, 13, 29, -1, 12, 18, 42, 0, 20, 28, 68, 2, }
8, 13, 29, -1, 
12, 18, 42, 0, 
20, 28, 68, 2, 

実行結果(Apple M1):

A: {1, 2, 3, 2, 3, 4, 4, 5, 6, }
1, 2, 3, 
2, 3, 4, 
4, 5, 6, 
---
sizeof(B[0]) = 16
B: {1, 2, 3, 2, 2, -2, 4, 0, 1, 5, 6, -1, }
1, 2, 3, 2, 
2, -2, 4, 0, 
1, 5, 6, -1, 
---
device name: Apple M1
---
C: {8, 13, 29, -1, 12, 18, 42, 0, 20, 28, 68, 2, }
8, 13, 29, -1, 
12, 18, 42, 0, 
20, 28, 68, 2, 

おまけ:BLASでの例

この記事及び「HaskellからBLASを叩く」の元ネタになった、CBLASを使うコードも供養しておきます。

#include <cblas_openblas.h>
#include <stdio.h>
#include <math.h>

/*
 * Prints a row-major matrix of doubles twice: first as "<label>: {…}" with
 * all elements in memory order, then one line per row.
 */
static void dump_matrix(const char *label, const double *data, size_t rows, size_t cols)
{
    printf("%s: {", label);
    for (size_t idx = 0; idx < rows * cols; ++idx) {
        printf("%g, ", data[idx]);
    }
    puts("}");
    for (size_t r = 0; r < rows; ++r) {
        for (size_t c = 0; c < cols; ++c) {
            printf("%g, ", data[cols * r + c]);
        }
        puts("");
    }
}

/*
 * Computes the 3x4 product of a 3x3 matrix and a 3x4 matrix twice, once with
 * a plain triple loop and once with cblas_dgemm, printing the operands and
 * both results so they can be compared.
 */
int main(void)
{
    double A[3][3] = {
        {1.0, 2.0, 3.0},
        {2.0, 3.0, 4.0},
        {4.0, 5.0, 6.0},
    };
    double B[3][4] = {
        {1.0, 2.0, 3.0, 2.0},
        {2.0, -2.0, 4.0, 0.0},
        {1.0, 5.0, 6.0, -1.0},
    };
    double C[3][4];

    dump_matrix("A", &A[0][0], 3, 3);
    puts("---");

    printf("sizeof(B[0]) = %zu\n", sizeof(B[0]));
    dump_matrix("B", &B[0][0], 3, 4);
    puts("---");

    /* Reference result: straightforward row-by-column dot products. */
    for (size_t row = 0; row < 3; ++row) {
        for (size_t col = 0; col < 4; ++col) {
            double sum = 0.0;
            for (size_t inner = 0; inner < 3; ++inner) {
                sum += A[row][inner] * B[inner][col];
            }
            C[row][col] = sum;
        }
    }
    dump_matrix("C (naive)", &C[0][0], 3, 4);
    puts("---");

    /* Overwrite C with NaN so the values printed next can only come from dgemm. */
    double *flat_c = &C[0][0];
    for (size_t idx = 0; idx < 3 * 4; ++idx) {
        flat_c[idx] = NAN;
    }

    cblas_dgemm(CblasRowMajor /* or CblasColMajor */,
                CblasNoTrans, CblasNoTrans,
                /* m */ 3, /* n */ 4, /* k */ 3,
                /* alpha */ 1.0,
                &A[0][0], /* lda */ 3,
                &B[0][0], /* ldb */ 4,
                /* beta */ 0.0,
                &C[0][0], /* ldc */ 4);
    dump_matrix("C (blas)", &C[0][0], 3, 4);
}

実行結果:

A: {1, 2, 3, 2, 3, 4, 4, 5, 6, }
1, 2, 3, 
2, 3, 4, 
4, 5, 6, 
---
sizeof(B[0]) = 32
B: {1, 2, 3, 2, 2, -2, 4, 0, 1, 5, 6, -1, }
1, 2, 3, 2, 
2, -2, 4, 0, 
1, 5, 6, -1, 
---
C (naive): {8, 13, 29, -1, 12, 18, 42, 0, 20, 28, 68, 2, }
8, 13, 29, -1, 
12, 18, 42, 0, 
20, 28, 68, 2, 
---
C (blas): {8, 13, 29, -1, 12, 18, 42, 0, 20, 28, 68, 2, }
8, 13, 29, -1, 
12, 18, 42, 0, 
20, 28, 68, 2, 

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3