Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

Metal Performance Shadersで行列乗算

Metal Performance Shadersで行列乗算

Metal Performance Shadersというやつを使うと行列演算等をApple製品のGPUで行えるらしいです。というわけでObjective-Cから行列乗算を呼び出してみました。

行列はrow majorのみに対応しているようです(column majorのみに対応しているcuBLASとは対照的です)。浮動小数点数の精度は単精度と半精度のみで、倍精度には対応していないようです。

// clang -fobjc-arc mps-test.m -lobjc -framework Metal -framework MetalPerformanceShaders -framework CoreGraphics
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <stdio.h>
#include <string.h>

// Print a row-major float matrix twice: first as the flat array it is stored
// as in memory, then row by row. Element (r, c) is at index cols * r + c.
static void print_float_matrix(const char *label, const float *m, size_t rows, size_t cols)
{
    printf("%s: {", label);
    for (size_t i = 0; i < rows * cols; ++i) {
        printf("%g, ", m[i]);
    }
    puts("}");
    for (size_t r = 0; r < rows; ++r) {
        for (size_t c = 0; c < cols; ++c) {
            printf("%g, ", m[cols * r + c]);
        }
        puts("");
    }
}

// Multiply a 3x3 matrix A by a 3x4 matrix B on the GPU with Metal Performance
// Shaders and print A, B and the product C = A * B.
// Returns 0 on success, 1 if any Metal/MPS step fails.
int main(void)
{
    // Shapes: A is M x K, B is K x N, C is M x N (same names as BLAS gemm).
    enum { M = 3, K = 3, N = 4 };
    float A[M][K] = {
        {1.0, 2.0, 3.0},
        {2.0, 3.0, 4.0},
        {4.0, 5.0, 6.0},
    };
    float B[K][N] = {
        {1.0, 2.0, 3.0, 2.0},
        {2.0, -2.0, 4.0, 0.0},
        {1.0, 5.0, 6.0, -1.0},
    };
    float C[M][N];
    print_float_matrix("A", &A[0][0], M, K);
    puts("---");
    printf("sizeof(B[0]) = %zu\n", sizeof(B[0]));
    print_float_matrix("B", &B[0][0], K, N);
    puts("---");
    @autoreleasepool {
        // Note: without linking -framework CoreGraphics this returns nil.
        id <MTLDevice> device = MTLCreateSystemDefaultDevice();
        if (device == nil) {
            puts("MTLCreateSystemDefaultDevice failed");
            return 1;
        }
        printf("device name: %s\n", [[device name] UTF8String]);
        if (!MPSSupportsMTLDevice(device)) {
            puts("Metal Performance Shaders does not support this metal device.");
            return 1;
        }
        id <MTLCommandQueue> commandQueue = [device newCommandQueue];
        if (commandQueue == nil) {
            puts("newCommandQueue failed");
            return 1;
        }

        // Buffers. A and B are initialized from host memory at creation
        // (CPU -> GPU); C uses shared storage so the CPU can read the result
        // directly once the GPU is done (GPU -> CPU).
        id <MTLBuffer> bufA = [device newBufferWithBytes:A length:sizeof(A) options:MTLResourceStorageModeManaged];
        id <MTLBuffer> bufB = [device newBufferWithBytes:B length:sizeof(B) options:MTLResourceStorageModeManaged];
        id <MTLBuffer> bufC = [device newBufferWithLength:sizeof(C) options:MTLResourceStorageModeShared];
        if (bufA == nil || bufB == nil || bufC == nil) {
            puts("buffer allocation failed");
            return 1;
        }

        // Matrix views over the buffers. MPS matrices are row-major;
        // rowBytes is the byte size of one C row, e.g. sizeof(A[0]).
        MPSMatrix *matA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:[MPSMatrixDescriptor matrixDescriptorWithRows:M columns:K rowBytes:sizeof(A[0]) dataType:MPSDataTypeFloat32]];
        MPSMatrix *matB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:[MPSMatrixDescriptor matrixDescriptorWithRows:K columns:N rowBytes:sizeof(B[0]) dataType:MPSDataTypeFloat32]];
        MPSMatrix *matC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:[MPSMatrixDescriptor matrixDescriptorWithRows:M columns:N rowBytes:sizeof(C[0]) dataType:MPSDataTypeFloat32]];

        id <MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
        // BLAS-like alternative with explicit transpose/alpha/beta:
        // MPSMatrixMultiplication *mmul = [[MPSMatrixMultiplication alloc] initWithDevice:device transposeLeft:NO transposeRight:NO resultRows:M resultColumns:N interiorColumns:K alpha:1.0 beta:0.0];
        MPSMatrixMultiplication *mmul = [[MPSMatrixMultiplication alloc] initWithDevice:device resultRows:M resultColumns:N interiorColumns:K];
        [mmul encodeToCommandBuffer:commandBuffer leftMatrix:matA rightMatrix:matB resultMatrix:matC];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
        // Don't read bufC unless the GPU work actually succeeded; otherwise C
        // would silently contain uninitialized data.
        if (commandBuffer.status != MTLCommandBufferStatusCompleted) {
            NSString *desc = commandBuffer.error.localizedDescription;
            printf("command buffer failed: %s\n", desc ? desc.UTF8String : "(unknown error)");
            return 1;
        }
        // The command buffer has completed, so the result is readable through [bufC contents].
        memcpy(C, [bufC contents], sizeof(C));
    }
    // Print the result.
    puts("---");
    print_float_matrix("C", &C[0][0], M, N);
    return 0;
}

実行結果(Intel Mac):

A: {1, 2, 3, 2, 3, 4, 4, 5, 6, }
1, 2, 3, 
2, 3, 4, 
4, 5, 6, 
---
sizeof(B[0]) = 16
B: {1, 2, 3, 2, 2, -2, 4, 0, 1, 5, 6, -1, }
1, 2, 3, 2, 
2, -2, 4, 0, 
1, 5, 6, -1, 
---
device name: Intel(R) Iris(TM) Plus Graphics
---
C: {8, 13, 29, -1, 12, 18, 42, 0, 20, 28, 68, 2, }
8, 13, 29, -1, 
12, 18, 42, 0, 
20, 28, 68, 2, 

実行結果(Apple M1):

A: {1, 2, 3, 2, 3, 4, 4, 5, 6, }
1, 2, 3, 
2, 3, 4, 
4, 5, 6, 
---
sizeof(B[0]) = 16
B: {1, 2, 3, 2, 2, -2, 4, 0, 1, 5, 6, -1, }
1, 2, 3, 2, 
2, -2, 4, 0, 
1, 5, 6, -1, 
---
device name: Apple M1
---
C: {8, 13, 29, -1, 12, 18, 42, 0, 20, 28, 68, 2, }
8, 13, 29, -1, 
12, 18, 42, 0, 
20, 28, 68, 2, 

おまけ:BLASでの例

この記事および「HaskellからBLASを叩く」の元ネタになった、CBLASを使うコードも供養しておきます。

#include <cblas_openblas.h>
#include <stdio.h>
#include <math.h>

// Print a row-major double matrix twice: first as the flat array it is stored
// as in memory, then row by row. Element (r, c) is at index cols * r + c.
static void print_double_matrix(const char *label, const double *m, size_t rows, size_t cols)
{
    printf("%s: {", label);
    for (size_t i = 0; i < rows * cols; ++i) {
        printf("%g, ", m[i]);
    }
    puts("}");
    for (size_t r = 0; r < rows; ++r) {
        for (size_t c = 0; c < cols; ++c) {
            printf("%g, ", m[cols * r + c]);
        }
        puts("");
    }
}

// Compute C = A * B for a 3x3 A and a 3x4 B twice, once with a naive triple
// loop and once with cblas_dgemm, and print both results for comparison.
int main(void)
{
    // Shapes: A is M x K, B is K x N, C is M x N (the m/n/k of dgemm).
    enum { M = 3, K = 3, N = 4 };
    double A[M][K] = {
        {1.0, 2.0, 3.0},
        {2.0, 3.0, 4.0},
        {4.0, 5.0, 6.0},
    };
    double B[K][N] = {
        {1.0, 2.0, 3.0, 2.0},
        {2.0, -2.0, 4.0, 0.0},
        {1.0, 5.0, 6.0, -1.0},
    };
    double C[M][N];
    print_double_matrix("A", &A[0][0], M, K);
    puts("---");
    printf("sizeof(B[0]) = %zu\n", sizeof(B[0]));
    print_double_matrix("B", &B[0][0], K, N);
    puts("---");

    // Naive reference: C[i][k] = sum_j A[i][j] * B[j][k].
    for (size_t i = 0; i < M; ++i) {
        for (size_t k = 0; k < N; ++k) {
            double x = 0.0;
            for (size_t j = 0; j < K; ++j) {
                x += A[i][j] * B[j][k];
            }
            C[i][k] = x;
        }
    }
    print_double_matrix("C (naive)", &C[0][0], M, N);
    puts("---");

    // C still holds the naive result; overwrite it with NaN so that whatever
    // is printed next really comes from BLAS.
    for (size_t i = 0; i < M; ++i) {
        for (size_t k = 0; k < N; ++k) {
            C[i][k] = NAN;
        }
    }
    // Row-major, no transposes: each leading dimension is the row length
    // (column count) of its array, i.e. lda = K, ldb = N, ldc = N.
    cblas_dgemm(CblasRowMajor /* or CblasColMajor */, CblasNoTrans, CblasNoTrans, /* m */ M, /* n */ N, /* k */ K, /* alpha */ 1.0, &A[0][0], /* lda */ K, &B[0][0], /* ldb */ N, /* beta */ 0.0, &C[0][0], /* ldc */ N);
    print_double_matrix("C (blas)", &C[0][0], M, N);
    return 0;
}

実行結果:

A: {1, 2, 3, 2, 3, 4, 4, 5, 6, }
1, 2, 3, 
2, 3, 4, 
4, 5, 6, 
---
sizeof(B[0]) = 32
B: {1, 2, 3, 2, 2, -2, 4, 0, 1, 5, 6, -1, }
1, 2, 3, 2, 
2, -2, 4, 0, 
1, 5, 6, -1, 
---
C (naive): {8, 13, 29, -1, 12, 18, 42, 0, 20, 28, 68, 2, }
8, 13, 29, -1, 
12, 18, 42, 0, 
20, 28, 68, 2, 
---
C (blas): {8, 13, 29, -1, 12, 18, 42, 0, 20, 28, 68, 2, }
8, 13, 29, -1, 
12, 18, 42, 0, 
20, 28, 68, 2, 
mod_poppo
最近は浮動小数点数オタクをやっています。
https://blog.miz-ar.info/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up on information in the technical fields you are interested in
  2. You can efficiently read useful information later
    By "stocking" the articles you like, you can search right away