2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

DNN の量子化したモデルってオーバーフローしないの?

Posted at

DNN の計算効率を向上させるための技術として、量子化というものがあります。
量子化によって、モデルのウェイトなどを浮動小数点から整数に変換することで、メモリ使用量や計算速度が改善されます。
浮動小数点数から整数への変換方法は、https://tech.retrieva.jp/entry/20220128 などが参考になります。

これで量子化したモデルを作ることは出来そうですが、どうやって利用するのでしょうか?ところで低ビットの整数計算はオーバーフローしませんか?

TL;DR

CUDA には8bit整数を受け取り32bit整数を返す関数があり、オーバーフローは発生しない。

量子化済みモデルの計算

量子化とは浮動小数点数のウェイトなどを低ビットの整数で近似させます。
例えば 1

\begin{pmatrix}
0.050332898 &	-0.371847316 \\
0.03545045 & 0.851365483
\end{pmatrix}
\begin{pmatrix}
0.807183607 \\
-0.131675099
\end{pmatrix}

の代わりに

\frac{0.851365483}{127.0}
\frac{0.807183607}{127.0}
\bigg[
\begin{pmatrix}
8 &	-55 \\
5 & 127
\end{pmatrix}
\begin{pmatrix}
127 \\
-21
\end{pmatrix}
\bigg]

を計算しても近似した結果が得られるというわけです、オーバーフローが起きなければ

仮説

  • 計算するたびに浮動小数点数に戻してるよ
    • キャスト遅くない?
  • 高ビットの整数に昇格させてから計算してるよ
    • やっぱり遅くない?
  • 実はオーバーフロー発生してるけどモデル自身のロバスト性で何とかなってるよ
    • モデルにとって重要な、大きな数値ほど発生しやすいから無理では?

調査

巷に存在するランタイムはどのように計算しているのかを調べてみます。
今回はONNXランタイムを調べていました。
ONNX では行列積 (=全結合) は GEMM という名前なので検索してみたところ、どうやらそれっぽい関数がありました。

Status GemmInt8(int m, int n, int k,
                int32_t alpha, int32_t beta,
                const int8_t* a, int lda, const int8_t* b, int ldb, int32_t* c, int ldc,
                const CudaKernel* cuda_kernel, onnxruntime::Stream* ort_stream) {
  ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null");
  ORT_ENFORCE(cuda_kernel != nullptr, "kernel is null");
  ORT_ENFORCE(ort_stream != nullptr, "Cuda kernel must have the stream instance");

  cudaStream_t stream = static_cast<cudaStream_t>(ort_stream->GetHandle());

  // pad A and B to make their leading dimension be multiples of 32
  // because cublasGemmEx requires:
  // 1. leading dimension is multiples of 4
  // 2. A, B is 32-bit aligned

  constexpr int mask = 0x1F;
  int lda_aligned = lda;
  IAllocatorUniquePtr<int8_t> a_padded;
  if ((mask & lda_aligned) != 0) {
    lda_aligned = roundoff(lda, 32);
    a_padded = cuda_kernel->GetScratchBuffer<int8_t>(SafeInt<size_t>(m) * lda_aligned, ort_stream);
    cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, stream);
  }

  int ldb_aligned = ldb;
  IAllocatorUniquePtr<int8_t> b_padded;
  if ((mask & ldb_aligned) != 0) {
    ldb_aligned = roundoff(ldb, 32);
    b_padded = cuda_kernel->GetScratchBuffer<int8_t>(SafeInt<size_t>(k) * ldb_aligned, ort_stream);
    cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, stream);
  }

  auto* ort_cuda_stream = dynamic_cast<CudaStream*>(ort_stream);
  auto cublas = ort_cuda_stream->cublas_handle_;

  CUBLAS_RETURN_IF_ERROR(cublasGemmEx(
      cublas,
      CUBLAS_OP_N, CUBLAS_OP_N,
      n, m, k,
      &alpha,
      ldb_aligned == ldb ? b : b_padded.get(), CUDA_R_8I, ldb_aligned,
      lda_aligned == lda ? a : a_padded.get(), CUDA_R_8I, lda_aligned,
      &beta,
      c, CUDA_R_32I, ldc, CUDA_R_32I,
      CUBLAS_GEMM_DFALT));
  return Status::OK();
}

おや、入力 (a, b) は int8_t* (8ビット整数の配列) ですが、入出力 (c) は int32_t* (32ビット整数の配列) になっていますね。
そして cublasGemmEx が最終的に呼び出す CUDA の関数のようです。

cublasStatus_t cublasGemmEx(cublasHandle_t handle,
                           cublasOperation_t transa,
                           cublasOperation_t transb,
                           int m,
                           int n,
                           int k,
                           const void     *alpha,
                           const void     *A,
                           cudaDataType   Atype,
                           int lda,
                           const void     *B,
                           cudaDataType   Btype,
                           int ldb,
                           const void     *beta,
                           void           *C,
                           cudaDataType   Ctype,
                           int ldc,
                           cudaDataType   computeType,
                           cublasGemmAlgo_t algo)

どうやら cublasGemmEx は、入力8bit、出力32bitということが出来るようです。

今回調べることが出来たのはここまででした。
GPU内部に8bit整数を受け取り32bit整数を返す命令があれば、省メモリと速度を両立させることが出来るでしょう。

32bit なら大丈夫なの?

行列積の要素の値は、入力ベクトルの次元が $ N $ のとき $ [N \times 127 \times (-128), N \times (-128) \times (-128)] $ の範囲に収まります。$ N = 132,104 $ までは最悪ケースであっても大丈夫です。

  1. 量子化の手法にはいろいろあり、そのうちの一例です

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?