
More than 5 years have passed since last update.

I tried reproducing the blog posts on matrix computation in Common Lisp

Posted at 2016-02-29

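While I was at it, I also tried auto-tuning the loop unrolling.
Update: the unrolling macro had unnecessary overhead when unroll = 1, which made the auto-tuned (1 1 8) slower than the hand-written one. I also changed the search to consider all parents of a candidate, not only the parent that generated it. Re-running the experiment gave different results.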

http://d.hatena.ne.jp/masatoi/20160204/1454519281
http://keens.github.io/blog/2016/02/04/common_lispdekousokugyouretsuenzan/
http://gos-k.hatenablog.com/entry/2016/02/27/135958

Replication

Benchmarking utilities

(ql:quickload :alexandria)
(use-package :alexandria)

(defun make-matrix (rows cols)
  (make-array (list rows cols) :element-type 'single-float))

(deftype matrix (&optional a b)
  `(simple-array single-float (,a ,b)))

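;; Run BODY TIMES times as an untimed warm-up, then TIMES2 more times under
;; TIME; print the elapsed run time in seconds and return it.
(defmacro benchmark ((times &optional (times2 times)) &body body)
  (once-only (times times2)
    (with-gensyms (i start end)
      `(progn
         ;; warm-up runs (not measured)
         (dotimes (,i ,times)
           ,@body)
         (let ((,start (get-internal-run-time)))
           (time
            (dotimes (,i ,times2)
              ,@body))
           (let ((,end (get-internal-run-time)))
             ;; PRINT returns its argument, so the seconds are also the return value
             (print (float (/ (- ,end ,start)
                              internal-time-units-per-second)))))))))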
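The benchmark macro does an untimed warm-up pass, then a timed pass, and returns the CPU run time in seconds. The same idea can be sketched as a plain function that takes a closure, which avoids the Alexandria macros. `measure-seconds` and its keyword arguments are hypothetical names, not part of the original code.

```lisp
;; Hypothetical function version of BENCHMARK (no Alexandria needed).
;; Measures run time via GET-INTERNAL-RUN-TIME, like the macro above.
(defun measure-seconds (thunk &key (warmup 1) (runs 1))
  "Call THUNK WARMUP times untimed, then RUNS times, and return the
run time of the timed calls in seconds as a float."
  (dotimes (i warmup)
    (funcall thunk))
  (let ((start (get-internal-run-time)))
    (dotimes (i runs)
      (funcall thunk))
    (float (/ (- (get-internal-run-time) start)
              internal-time-units-per-second))))

;; Example: time 100 sums of a 1000-element list after 10 warm-up calls
(measure-seconds (lambda () (reduce #'+ (make-list 1000 :initial-element 1)))
                 :warmup 10 :runs 100)
```

Taking a closure instead of a body keeps it an ordinary function, at the cost of one extra FUNCALL per run, which is negligible next to a 500x500 GEMM.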

Data

(defparameter *ma* (make-matrix 500 500))
(defparameter *mb* (make-matrix 500 500))
(defparameter *mc* (make-matrix 500 500))

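simple-gemm. I didn't want to include array allocation in the timing, so the result is written destructively into the argument mc (accumulated with incf).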

(defun simple-gemm (ma mb mc)
  (declare (optimize (speed 3) (debug 0) (safety 0)))
  (declare (type matrix ma mb mc))
  (let ((rows (array-dimension ma 0))
        (cols (array-dimension mb 1)))
    (declare (type fixnum rows cols))
    (dotimes (row rows)
      (dotimes (col cols)
        (dotimes (k cols)       ; assumes square matrices (k should range over ma's columns)
          (incf (aref mc row col)
                (* (aref ma row k) (aref mb k col))))))
    mc))

(benchmark (10)
  (simple-gemm *ma* *mb* *mc*)
  ;; Evaluation took:
  ;;   7.020 seconds of real time
  ;;   7.020000 seconds of total run time (7.020000 user, 0.000000 system)
  ;;   100.00% CPU
  ;;   21,061,255,674 processor cycles
  ;;   65,536 bytes consed
  )
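Because simple-gemm uses incf on mc, it computes mc + ma·mb rather than overwriting mc, so repeated benchmark runs keep adding into the same array. The small sketch below shows this on non-square matrices; `naive-gemm-accumulate` is a hypothetical name, and it bounds k by the inner dimension `(array-dimension ma 1)`.

```lisp
;; Hypothetical helper: MC <- MC + MA * MB, the same update SIMPLE-GEMM does,
;; with K bounded by the inner dimension so non-square shapes also work.
(defun naive-gemm-accumulate (ma mb mc)
  (let ((rows (array-dimension ma 0))
        (inner (array-dimension ma 1))
        (cols (array-dimension mb 1)))
    (dotimes (row rows mc)
      (dotimes (col cols)
        (dotimes (k inner)
          (incf (aref mc row col)
                (* (aref ma row k) (aref mb k col))))))))

(defparameter *a* (make-array '(2 3) :element-type 'single-float
                                     :initial-contents '((1.0 2.0 3.0) (4.0 5.0 6.0))))
(defparameter *b* (make-array '(3 2) :element-type 'single-float
                                     :initial-contents '((7.0 8.0) (9.0 10.0) (11.0 12.0))))
(defparameter *c* (make-array '(2 2) :element-type 'single-float :initial-element 0.0))

(naive-gemm-accumulate *a* *b* *c*) ; *c* is now ((58 64) (139 154))
(naive-gemm-accumulate *a* *b* *c*) ; *c* is now ((116 128) (278 308)): accumulated
```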

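gos-k's row-major-gemm, tuned by reading the disassemble output.
It also already keeps the running sum in a register (the local variable cell) instead of writing to mc on every iteration.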

(defun row-major-gemm (ma mb mc)
  (declare (optimize (speed 3) (debug 0) (safety 0) (space 0)))
  (declare (type matrix ma mb mc))
  (let ((rows (array-dimension ma 0))
        (cols (array-dimension mb 1)))
    (declare (type fixnum rows cols))
    (dotimes (row rows)
      (dotimes (col cols)
        (let ((cell (aref mc row col))
              (ma-index (array-row-major-index ma row 0))
              (mb-index (array-row-major-index mb 0 col)))
          (declare (type (single-float) cell))
          (declare (type fixnum ma-index mb-index))
          (dotimes (k cols)
            (incf cell (* (row-major-aref ma ma-index)
                          (row-major-aref mb mb-index)))
            (incf ma-index)
            (incf mb-index cols))
          (setf (aref mc row col) cell))))
    mc))

(benchmark (10)
  ;; Evaluation took:
  ;;   3.115 seconds of real time
  ;;   3.116000 seconds of total run time (3.116000 user, 0.000000 system)
  ;;   100.03% CPU
  ;;   9,347,123,164 processor cycles
  ;;   55,072 bytes consed
  (row-major-gemm *ma* *mb* *mc*))

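The array size is fixed, so I put it into the type declaration as a constant, (matrix 500 500). That makes it a bit faster (from 3.115 s to 2.773 s).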

(defun rm-gemm+static-size (ma mb mc)
  (declare (optimize (speed 3) (debug 0) (safety 0) (space 0)))
  (declare (type (matrix 500 500) ma mb mc))
  (let ((rows (array-dimension ma 0))
        (cols (array-dimension mb 1)))
    (declare (type fixnum rows cols))
    (dotimes (row rows)
      (dotimes (col cols)
        (let ((cell (aref mc row col))
              (ma-index (array-row-major-index ma row 0))
              (mb-index (array-row-major-index mb 0 col)))
          (declare (type (single-float) cell))
          (declare (type fixnum ma-index mb-index))
          (dotimes (k cols)
            (incf cell (* (row-major-aref ma ma-index)
                          (row-major-aref mb mb-index)))
            (incf ma-index)
            (incf mb-index cols))
          (setf (aref mc row col) cell))))
    mc))

(benchmark (10)
  (rm-gemm+static-size *ma* *mb* *mc*)
  ;; Evaluation took:
  ;;   2.773 seconds of real time
  ;;   2.772000 seconds of total run time (2.772000 user, 0.000000 system)
  ;;   99.96% CPU
  ;;   8,318,453,469 processor cycles
  ;;   0 bytes consed
  )
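The speedup relies on the parameterized matrix type. Unsupplied &optional parameters of DEFTYPE default to `*`, so a bare `(matrix)` means "any 2D simple single-float array", while `(matrix 500 500)` pins both dimensions, which can let the compiler treat them as constants. The sketch below checks this with a copy of the type under the hypothetical name `sf-matrix`.

```lisp
;; Same definition as MATRIX above, under a hypothetical name.
;; Unsupplied &optional DEFTYPE parameters default to *, not NIL.
(deftype sf-matrix (&optional rows cols)
  `(simple-array single-float (,rows ,cols)))

;; true: dimensions match exactly
(typep (make-array '(500 500) :element-type 'single-float) '(sf-matrix 500 500))
;; false: a 400x400 array is not a (sf-matrix 500 500)
(typep (make-array '(400 400) :element-type 'single-float) '(sf-matrix 500 500))
;; true: (sf-matrix) is (simple-array single-float (* *))
(typep (make-array '(400 400) :element-type 'single-float) '(sf-matrix))
```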

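Loop unrolling. The macro below repeats the body unroll times per end test and handles the leftover iterations in a separate loop.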

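;; Like DOTIMES, but the main loop contains UNROLL copies of BODY, so the end
;; test runs once per UNROLL iterations; a remainder loop handles the last
;; (mod n unroll) iterations.
(defmacro dotimes-unroll ((i n unroll) &body body)
  (check-type i symbol)
  ;; UNROLL must be a literal positive integer (0 would loop forever)
  (assert (typep unroll '(integer 1)))
  (once-only (n)
    `(locally
       (declare (fixnum ,n))
       (do ((,i 0))
           ;; leave the unrolled loop when fewer than UNROLL iterations remain
           ((< ,n (the fixnum (+ ,unroll ,i)))
            ;; remainder loop: one iteration at a time until I reaches N
            (do ((,i ,i (the fixnum (1+ ,i))))
                ((< ,n (the fixnum (1+ ,i))))
              ,@body))
         (declare (fixnum ,i))
         ;; UNROLL copies of BODY, each followed by (incf i)
         ,@(loop :repeat unroll :append (append body `((incf ,i))))))))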

(dotimes-unroll (i 10 3)
  (format t "~%~a" i))
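To make the macro concrete, here is a hand-written sketch of roughly what `(dotimes-unroll (i n 3) ...)` expands into, with the body replaced by collecting i. The gensyms, fixnum declarations and THE forms are left out for readability, and `unrolled-indices-3` is a hypothetical name.

```lisp
;; Hand-expanded sketch of (dotimes-unroll (i n 3) (push i acc)).
;; UNROLLED-INDICES-3 is a hypothetical name used only for illustration.
(defun unrolled-indices-3 (n)
  (let ((acc '()))
    (do ((i 0))
        ;; end test: runs once per 3 iterations
        ((< n (+ 3 i))
         ;; remainder loop for the last (mod n 3) indices
         (do ((i i (1+ i)))
             ((< n (1+ i)))
           (push i acc)))
      ;; the body copied 3 times, each copy followed by (incf i)
      (push i acc) (incf i)
      (push i acc) (incf i)
      (push i acc) (incf i))
    (nreverse acc)))

(unrolled-indices-3 10) ; => (0 1 2 3 4 5 6 7 8 9)
(unrolled-indices-3 4)  ; => (0 1 2 3)
```

Every index is visited exactly once: for n = 10 the main loop covers 0 to 8 in three passes and the remainder loop covers 9.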

Next I applied the unrolling to the k loop of row-major-gemm and experimented. An unroll factor of 8 worked best.

(defun rm-gemm+static-size+unroll (ma mb mc)
  (declare (optimize (speed 3) (debug 0) (safety 0) (space 0)))
  (declare (type (matrix 500 500) ma mb mc))
  (let ((rows (array-dimension ma 0))
        (cols (array-dimension mb 1)))
    (declare (type fixnum rows cols))
    (dotimes (row rows)
      (dotimes (col cols)
        (let ((cell (aref mc row col))
              (ma-index (array-row-major-index ma row 0))
              (mb-index (array-row-major-index mb 0 col)))
          (declare (type (single-float) cell))
          (declare (type fixnum ma-index mb-index))
          (dotimes-unroll (k cols 8)
            ;; best so far, on AMD Phenom(tm) II X6 1075T Processor
            (incf cell (* (row-major-aref ma ma-index)
                          (row-major-aref mb mb-index)))
            (incf ma-index)
            (incf mb-index cols))
          (setf (aref mc row col) cell))))
    mc))

(benchmark (10)
  (rm-gemm+static-size+unroll *ma* *mb* *mc*)
  ;; Evaluation took:
  ;;   2.203 seconds of real time
  ;;   2.200000 seconds of total run time (2.200000 user, 0.000000 system)
  ;;   99.86% CPU
  ;;   6,609,926,012 processor cycles
  ;;   0 bytes consed
  )

Auto-tuning

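Now it's my turn. make-unroll-gemm generates code with unrolling at all three loop levels (x for rows, y for columns, z for the inner k loop), compiles it, and returns the function object. The unrolling code might be wrong.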

(ql:quickload :priority-queue)
(use-package :priority-queue)
(ql:quickload :iterate)
(use-package :iterate)

(defun make-unroll-gemm (x y z)
  (check-type x integer)
  (check-type y integer)
  (check-type z integer)
  (compile nil
           `(lambda (ma mb mc)
              (declare (optimize (speed 3) (debug 0) (safety 0) (space 0)))
              (declare (type (matrix 500 500) ma mb mc))
              (let ((rows (array-dimension ma 0))
                    (cols (array-dimension mb 1)))
                (declare (type fixnum rows cols))
                (dotimes-unroll (row rows ,x)
                  (dotimes-unroll (col cols ,y)
                    (let ((cell (aref mc row col))
                          (ma-index (array-row-major-index ma row 0))
                          (mb-index (array-row-major-index mb 0 col)))
                      (declare (type (single-float) cell))
                      (declare (type fixnum ma-index mb-index))
                      (dotimes-unroll (k cols ,z)
                        (incf cell (* (row-major-aref ma ma-index)
                                      (row-major-aref mb mb-index)))
                        (incf ma-index)
                        (incf mb-index cols))
                      (setf (aref mc row col) cell))))
                mc))))
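The key trick in make-unroll-gemm is building a lambda form with backquote and passing it to `(compile nil ...)`, so the unroll factor becomes a literal in freshly compiled code. Here is a smaller, self-contained sketch of the same technique on a vector sum; `make-unrolled-sum` is a hypothetical name and the double-float vector is an assumption chosen so the sums are exact.

```lisp
;; Hypothetical example: compile, at run time, a summing function whose
;; main loop is unrolled UNROLL times (same COMPILE NIL approach as above).
(defun make-unrolled-sum (unroll)
  (check-type unroll (integer 1))
  (compile nil
           `(lambda (v)
              (declare (type (simple-array double-float (*)) v))
              (let ((n (length v))
                    (i 0)
                    (acc 0d0))
                (declare (type fixnum n i) (type double-float acc))
                ;; main loop: UNROLL copies of the body per end test
                (loop while (<= (+ i ,unroll) n)
                      do (progn
                           ,@(loop repeat unroll
                                   append (list '(incf acc (aref v i))
                                                '(incf i)))))
                ;; remainder loop: the last (mod n unroll) elements
                (loop while (< i n)
                      do (incf acc (aref v i))
                         (incf i))
                acc))))

(let ((v (make-array 10 :element-type 'double-float
                        :initial-contents (loop for k from 1 to 10
                                                collect (float k 1d0)))))
  (funcall (make-unrolled-sum 4) v)) ; 1 + 2 + ... + 10 = 55
```

Each call to `compile` costs some time, so it pays off only when the generated function is run many times, as in the benchmark loop.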

Performance evaluation. Returns the elapsed time in seconds as a float.

(defun evaluate-unrolling (x y z)
  (let ((f (make-unroll-gemm x y z)))
    (benchmark (20) (funcall f *ma* *mb* *mc*))))

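Search the parameter space best-first: the priority queue is ordered by run time, so the fastest parameters are popped first. For each popped parameter set,
double the unroll factor of the first, second, or third loop level. If the result improves on its parent, push it onto the queue.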

(defun search-best-unrolling ()
  (let ((q (make-pqueue #'< :key-type 'float :value-type 'list))
        (close nil))
    (format t "~&testing ~a ..." '(1 1 1))
    (let* ((f (make-unroll-gemm 1 1 1))
           (basetime (benchmark (10) (funcall f *ma* *mb* *mc*))))
      (format t " ~a (sec). " basetime)
      (pqueue-push '(1 1 1) basetime q))
    (iter (until (pqueue-empty-p q))
          (for (values parameters time) = (pqueue-pop q))
          (finding (cons parameters time) minimizing time)
          (for (x y z) = parameters)
          (iter (for new-parameters in (list (list x y (* 2 z))
                                             (list x (* 2 y) z)
                                             (list (* 2 x) y z)))
                (when (member parameters close :test #'equal)
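                  ;; duplicate detection (NB: this tests PARAMETERS, not NEW-PARAMETERS,
                  ;; so candidates already in CLOSE are re-tested, e.g. (1 2 16) below)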
                  (next-iteration))
                (format t "~&testing ~a ..." new-parameters)
                (for newtime = (apply #'evaluate-unrolling new-parameters))
                (format t " ~a (sec). " newtime)
                (if (< newtime time)
                    (progn
                      (format t "Improved from the results of ~a: ~a." parameters time)
                      (pqueue-push new-parameters newtime q))
                    (push new-parameters close))))))
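The search strategy itself can be tried without any benchmarking. The sketch below is a hypothetical stand-in: a made-up cost function replaces real timings, a list kept sorted with MERGE replaces the priority-queue library, and every measured candidate goes into a `seen` list so nothing is tested twice. All function names here are assumptions for illustration.

```lisp
;; Hypothetical sketch of the best-first unroll search with a toy cost.
(defun log2-exact (n)
  "Exact log2 of a power of two N."
  (1- (integer-length n)))

(defun toy-unroll-cost (params)
  "Made-up cost whose minimum (0) is at (1 2 8)."
  (destructuring-bind (x y z) params
    (+ (log2-exact x)
       (abs (- (log2-exact y) 1))
       (abs (- (log2-exact z) 3)))))

(defun double-each-axis (params)
  "Candidates obtained by doubling one unroll factor at a time."
  (destructuring-bind (x y z) params
    (list (list x y (* 2 z)) (list x (* 2 y) z) (list (* 2 x) y z))))

(defun best-first-search (start cost-fn)
  "Expand the cheapest open candidate first; keep a child only if it beats
its parent. Returns (PARAMS . COST) of the best candidate seen."
  (let* ((start-cost (funcall cost-fn start))
         (open (list (cons start-cost start))) ; (cost . params), ascending cost
         (seen (list start))                   ; every candidate already measured
         (best (cons start start-cost)))
    (loop while open
          do (destructuring-bind (cost . params) (pop open)
               (when (< cost (cdr best))
                 (setf best (cons params cost)))
               (dolist (child (double-each-axis params))
                 (unless (member child seen :test #'equal)
                   (push child seen)
                   (let ((child-cost (funcall cost-fn child)))
                     (when (< child-cost cost)
                       ;; MERGE keeps OPEN sorted by ascending cost
                       (setf open (merge 'list (list (cons child-cost child))
                                         open #'< :key #'car))))))))
    best))

(best-first-search '(1 1 1) #'toy-unroll-cost) ; => ((1 2 8) . 0)
```

The search stops because a child is queued only when it is strictly cheaper than its parent; with a cost that kept improving forever it would not terminate.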

Results (invalid)

About what I expected.

CL-USER> (search-best-unrolling)

testing (1 1 1) ... 2.792 (sec). 
testing (1 1 2) ... 2.632 (sec). Improved from the results of (1 1 1): 2.792.
testing (1 2 1) ... 2.772 (sec). Improved from the results of (1 1 1): 2.792.
testing (2 1 1) ... 2.768 (sec). Improved from the results of (1 1 1): 2.792.
testing (1 1 4) ... 2.548 (sec). Improved from the results of (1 1 2): 2.632.
testing (1 2 2) ... 2.632 (sec). 
testing (2 1 2) ... 2.636 (sec). 
testing (1 1 8) ... 2.496 (sec). Improved from the results of (1 1 4): 2.548.
testing (1 2 4) ... 2.544 (sec). Improved from the results of (1 1 4): 2.548.
testing (2 1 4) ... 2.548 (sec). 
testing (1 1 16) ... 2.5 (sec). 
testing (1 2 8) ... 2.492 (sec). Improved from the results of (1 1 8): 2.496.
testing (2 1 8) ... 2.496 (sec). 
testing (1 2 16) ... 2.496 (sec). 
testing (1 4 8) ... 2.496 (sec). 
testing (2 2 8) ... 2.496 (sec). 
testing (1 2 8) ... 2.496 (sec). Improved from the results of (1 2 4): 2.544.
testing (1 4 4) ... 2.548 (sec). 
testing (2 2 4) ... 2.548 (sec). 
testing (1 2 16) ... 2.504 (sec). 
testing (1 4 8) ... 2.492 (sec). Improved from the results of (1 2 8): 2.496.
testing (2 2 8) ... 2.496 (sec). 
testing (2 1 2) ... 2.632 (sec). Improved from the results of (2 1 1): 2.768.
testing (2 2 1) ... 2.792 (sec). 
testing (4 1 1) ... 2.784 (sec). 
testing (1 2 2) ... 2.636 (sec). Improved from the results of (1 2 1): 2.772.
testing (1 4 1) ... 2.784 (sec). 
testing (2 2 1) ... 2.796 (sec). 
((1 2 8) . 2.492)

Found a mistake

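The unrolling macro had unnecessary overhead when unroll = 1.
The hand-written (1 1 8) took 2.203 sec, but (1 1 8) in the experiment above took 2.496 sec (about 13% slower). So I changed the macro to fall back to a plain dotimes when the factor is 1.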

(defmacro dotimes-unroll ((i n unroll) &body body)
  (check-type i symbol)
  (assert (typep unroll '(integer 1)))   ; literal positive integer only
  (if (= 1 unroll)
      `(dotimes (,i ,n)
         ,@body)
      (once-only (n)
        `(locally
             (declare (fixnum ,n))
           (do ((,i 0))
               ((< ,n (the fixnum (+ ,unroll ,i)))
                (do ((,i ,i (the fixnum (1+ ,i))))
                    ((< ,n (the fixnum (1+ ,i))))
                  ,@body))
             (declare (fixnum ,i))
             ,@(loop :repeat unroll :append (append body `((incf ,i)))))))))

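I also strengthened the pruning by considering all parents, not only the one that directly generated the candidate. For example, when (2 2 4) is generated from (1 2 4), it now has to improve not only on the score of (1 2 4) but also on the scores of (2 1 4) and (2 2 2) (only parents that have already been measured count).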

(defun children (parameters)
  (destructuring-bind (x y z) parameters
    (list (list x y (* z 2))
          (list x (* y 2) z)
          (list (* x 2) y z))))

(defun parents (parameters)
  (destructuring-bind (x y z) parameters
    (remove-if-not (lambda (parameters)
                     (every #'integerp parameters))
                   (list (list x y (/ z 2))
                         (list x (/ y 2) z)
                         (list (/ x 2) y z)))))
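The "best parent" check can be tried on its own. The sketch below copies the parents logic under the hypothetical name `halve-each-axis` and picks the fastest parent that has already been measured from an alist of (params . seconds); `best-measured-parent` is also hypothetical, and the two timings are taken from the second run's log.

```lisp
;; Hypothetical helpers illustrating the best-parent pruning check.
(defun halve-each-axis (params)
  "Parents of PARAMS: halve one factor at a time; drop non-integer results."
  (destructuring-bind (x y z) params
    (remove-if-not (lambda (p) (every #'integerp p))
                   (list (list x y (/ z 2))
                         (list x (/ y 2) z)
                         (list (/ x 2) y z)))))

(defun best-measured-parent (params scores)
  "Return (SECONDS . PARENT) for the fastest parent of PARAMS found in the
alist SCORES of (PARAMS . SECONDS), or NIL if no parent has been measured."
  (let ((best nil))
    (dolist (parent (halve-each-axis params) best)
      (let ((entry (assoc parent scores :test #'equal)))
        (when (and entry (or (null best) (< (cdr entry) (car best))))
          (setf best (cons (cdr entry) parent)))))))

;; Parents of (1 2 8) are (1 2 4) and (1 1 8); (1/2 2 8) is dropped.
(best-measured-parent '(1 2 8) '(((1 1 8) . 2.212) ((1 2 4) . 2.248)))
;; => (2.212 1 1 8), so (1 2 8) at 2.204 s counts as an improvement
```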

(defun search-best-unrolling ()
  (let ((q (make-pqueue #'< :key-type 'float :value-type 'list))
        (close nil))
    (format t "~&testing ~a ..." '(1 1 1))
    (let* ((f (make-unroll-gemm 1 1 1))
           (basetime (benchmark (10) (funcall f *ma* *mb* *mc*))))
      (format t " ~a (sec). " basetime)
      (push (cons '(1 1 1) basetime) close)
      (pqueue-push '(1 1 1) basetime q))
    (iter (until (pqueue-empty-p q))
          (for (values parameters time) = (pqueue-pop q))
          (finding (cons parameters time) minimizing time)
          (iter (for new-parameters in (children parameters)) ; PARAMETERS is the direct parent
                (when (member new-parameters close :key #'car :test #'equal)
                  ;; duplicate detection
                  (next-iteration))
                (format t "~&testing ~a ..." new-parameters)
                (for newtime = (apply #'evaluate-unrolling new-parameters))
                (format t " ~a (sec). " newtime)
                (push (cons new-parameters newtime) close)
                (for (time . best-parent) = ; compare against the fastest measured parent
                     (iter (for parent in (parents new-parameters))
                           (for time = (cdr (assoc parent close :test #'equal)))
                           (when time
                             (finding (cons time parent) minimizing time))))
                (when (< newtime time)
                  (format t "Improved from the best result by parent ~a: ~a." best-parent time)
                  (pqueue-push new-parameters newtime q))))))

Results

It got faster, and the best unroll parameters were the same, (1 2 8).

CL-USER> (search-best-unrolling)
testing (1 1 1) ... 2.528 (sec). 
testing (1 1 2) ... 2.348 (sec). Improved from the best result by parent (1 1 1): 2.528.
testing (1 2 1) ... 2.524 (sec). Improved from the best result by parent (1 1 1): 2.528.
testing (2 1 1) ... 2.532 (sec). 
testing (1 1 4) ... 2.248 (sec). Improved from the best result by parent (1 1 2): 2.348.
testing (1 2 2) ... 2.352 (sec). 
testing (2 1 2) ... 2.352 (sec). 
testing (1 1 8) ... 2.212 (sec). Improved from the best result by parent (1 1 4): 2.248.
testing (1 2 4) ... 2.248 (sec). 
testing (2 1 4) ... 2.248 (sec). 
testing (1 1 16) ... 2.228 (sec). 
testing (1 2 8) ... 2.204 (sec). Improved from the best result by parent (1 1 8): 2.212.
testing (2 1 8) ... 2.208 (sec). Improved from the best result by parent (1 1 8): 2.212.
testing (1 2 16) ... 2.228 (sec). 
testing (1 4 8) ... 2.224 (sec). 
testing (2 2 8) ... 2.28 (sec). 
testing (2 1 16) ... 2.328 (sec). 
testing (4 1 8) ... 2.24 (sec). 
testing (1 4 1) ... 2.576 (sec). 
testing (2 2 1) ... 2.58 (sec). 
((1 2 8) . 2.204)