LoginSignup
5
4

More than 5 years have passed since last update.

Common Lisp で行列計算するブログを追試してみた (その2)

Last updated at Posted at 2016-02-29

http://qiita.com/guicho271828/items/60236e32871b9eb610c7 の続き。

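そもそも論として、たしか内積型 (ijk 順) のループは、最内ループで mb を列方向にとびとびにアクセスするため、投機的メモリ転送云々が効かずに遅いんだった (p.46「ikj, jkiループによる実現」)。最内ループで連続したメモリにアクセスする順序にしないといけないような気がするのでやってみた (これってはたして Lisp でも正しいのか? → Common Lisp の多次元配列は行優先 (row-major) で格納されるので、C と同様に ikj 順にして最内ループで列の添字を動かせば連続アクセスになるはず)。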

一応、元の

(defun simple-gemm (ma mb mc)
  "Accumulate the matrix product MA * MB into MC (MC <- MC + MA MB) and return MC.
Naive i-j-k (inner-product) loop order.  MA is ROWS x INNER, MB is INNER x COLS
and MC is ROWS x COLS.  MC is added into, not cleared first.  Compiled with
SAFETY 0, so the dimensions are not checked."
  (declare (optimize (speed 3) (debug 0) (safety 0)))
  (declare (type matrix ma mb mc))
  (let ((rows (array-dimension ma 0))
        ;; Shared dimension: columns of MA = rows of MB.  Using COLS here
        ;; is only correct for square matrices.
        (inner (array-dimension ma 1))
        (cols (array-dimension mb 1)))
    (declare (type fixnum rows inner cols))
    (dotimes (row rows)
      (dotimes (col cols)
        (dotimes (k inner)
          (incf (aref mc row col)
                (* (aref ma row k) (aref mb k col))))))
    mc))
;; Benchmark SIMPLE-GEMM on the globals *MA*, *MB*, *MC* (not defined in this
;; listing).  BENCHMARK is not defined here either; (10) presumably sets the
;; number of runs -- confirm.  The ";;" lines are the recorded timing report.
(benchmark (10)
  (simple-gemm *ma* *mb* *mc*)
  ;; Evaluation took:
  ;;   7.020 seconds of real time
  ;;   7.020000 seconds of total run time (7.020000 user, 0.000000 system)
  ;;   100.00% CPU
  ;;   21,061,255,674 processor cycles
  ;;   65,536 bytes consed
  )

に対して、ループを入れ替えただけの以下は

(defun simple-gemm-k (ma mb mc)
  "Accumulate MA * MB into MC (MC <- MC + MA MB) and return MC, using the
i-k-j loop order.  MA is ROWS x INNER, MB is INNER x COLS and MC is ROWS x COLS.
MC is added into, not cleared first.  SAFETY 0: dimensions are not checked."
  (declare (optimize (speed 3) (debug 0) (safety 0)))
  (declare (type matrix ma mb mc))
  (let ((rows (array-dimension ma 0))
        ;; Shared dimension: columns of MA = rows of MB.
        (inner (array-dimension ma 1))
        (cols (array-dimension mb 1)))
    (declare (type fixnum rows inner cols))
    (dotimes (row rows)
      (dotimes (k inner)
        ;; K stays fixed while COL runs, so (aref ma row k) does not change in
        ;; the innermost loop and MB / MC are walked along a single row.
        (dotimes (col cols)
          (incf (aref mc row col)
                (* (aref ma row k) (aref mb k col))))))
    mc))
;; Benchmark SIMPLE-GEMM-K (loop order swapped to i-k-j) on *MA*, *MB*, *MC*.
;; BENCHMARK is not defined in this listing.  The ";;" lines are the recorded
;; timing report; here they are written before the call form.
(benchmark (10)
  ;; Evaluation took:
  ;;   5.920 seconds of real time
  ;;   5.920000 seconds of total run time (5.920000 user, 0.000000 system)
  ;;   100.00% CPU
  ;;   17,760,421,320 processor cycles
  ;;   35,712 bytes consed
  (simple-gemm-k *ma* *mb* *mc*))

だった。どうやらこっちベースにしないとだめそうだ。そういうわけでいろいろ書き直す。

cache-gemm-k

(defun cache-gemm-k (ma mb mc)
  "Accumulate MA * MB into MC (MC <- MC + MA MB) and return MC, using the
i-k-j loop order with MA[row,k] cached in a local.  MA is ROWS x INNER, MB is
INNER x COLS and MC is ROWS x COLS.  SAFETY 0: dimensions are not checked."
  (declare (optimize (speed 3) (debug 0) (safety 0) (space 0)))
  (declare (type matrix ma mb mc))
  (let ((rows (array-dimension ma 0))
        ;; Shared dimension: columns of MA = rows of MB.
        (inner (array-dimension ma 1))
        (cols (array-dimension mb 1)))
    (declare (type fixnum rows inner cols))
    (dotimes (row rows)
      (dotimes (k inner)
        ;; MA[row,k] is invariant in the innermost loop: read it once.
        (let ((cell (aref ma row k)))
          (dotimes (col cols)
            (incf (aref mc row col)
                  (* cell (aref mb k col)))))))
    mc))

;; Benchmark CACHE-GEMM-K on *MA*, *MB*, *MC* (BENCHMARK is not defined in
;; this listing).  The ";;" lines are the recorded timing report.
(benchmark (10)
  ;; Evaluation took:
  ;;   5.084 seconds of real time
  ;;   5.084000 seconds of total run time (5.080000 user, 0.004000 system)
  ;;   100.00% CPU
  ;;   15,252,688,222 processor cycles
  ;;   30,224 bytes consed
  (cache-gemm-k *ma* *mb* *mc*))

ok.

row-major-gemm-k

(defun rm-gemm-k (ma mb mc)
  "Accumulate MA * MB into MC (MC <- MC + MA MB) and return MC, using the
i-k-j loop order and explicit row-major cursors into MB and MC.  MA is
ROWS x INNER, MB is INNER x COLS and MC is ROWS x COLS.  SAFETY 0: dimensions
are not checked."
  (declare (optimize (speed 3) (debug 0) (safety 0) (space 0)))
  (declare (type matrix ma mb mc))
  (let ((rows (array-dimension ma 0))
        ;; Shared dimension: columns of MA = rows of MB.
        (inner (array-dimension ma 1))
        (cols (array-dimension mb 1)))
    (declare (type fixnum rows inner cols))
    (dotimes (row rows)
      (dotimes (k inner)
        (let ((cell (aref ma row k))
              (mb-index (array-row-major-index mb k 0))    ; start of row K of MB
              (mc-index (array-row-major-index mc row 0))) ; start of row ROW of MC
          ;; Both cursors stay within the arrays' total sizes, so FIXNUM is
          ;; safe; declaring it keeps the INCFs below in fixnum arithmetic.
          (declare (type fixnum mb-index mc-index))
          (dotimes (col cols)
            (incf (row-major-aref mc mc-index)
                  (* cell (row-major-aref mb mb-index)))
            (incf mb-index)
            (incf mc-index)))))
    mc))

;; Benchmark RM-GEMM-K on *MA*, *MB*, *MC* (BENCHMARK is not defined in this
;; listing).  The ";;" lines are the recorded timing report.
(benchmark (10)
  ;; Evaluation took:
  ;;   2.701 seconds of real time
  ;;   2.708000 seconds of total run time (2.708000 user, 0.000000 system)
  ;;   100.26% CPU
  ;;   8,102,781,155 processor cycles
  ;;   487,504 bytes consed
  (rm-gemm-k *ma* *mb* *mc*))

ok. ちなみに、ループを入れ替える前のもとの row-major-gemm (前回の記事の実装) は

;; Reference timing for ROW-MAJOR-GEMM, which is not defined in this listing
;; (BENCHMARK is not defined here either).  The ";;" lines are the recorded
;; timing report.
(benchmark (10)
  ;; Evaluation took:
  ;;   3.115 seconds of real time
  ;;   3.116000 seconds of total run time (3.116000 user, 0.000000 system)
  ;;   100.03% CPU
  ;;   9,347,123,164 processor cycles
  ;;   55,072 bytes consed
  (row-major-gemm *ma* *mb* *mc*))

static size で宣言

ここらへんから雲行きが怪しくなってきた。

(defun rm-gemm+static-size-k (ma mb mc)
  "MC <- MC + MA * MB for arguments declared as (MATRIX 500 500); returns MC.
Uses the i-p-j loop order: A[i,p] is held in a local while two row-major
cursors sweep row P of MB and row I of MC.  At SAFETY 0 the declared
500x500 shape is trusted, not checked."
  (declare (optimize (speed 3) (debug 0) (safety 0) (space 0)))
  (declare (type (matrix 500 500) ma mb mc))
  (let ((n-cols (array-dimension mb 1))
        (n-rows (array-dimension ma 0)))
    (declare (type fixnum n-rows n-cols))
    (dotimes (i n-rows)
      (dotimes (p n-cols)
        (let ((a-ip (aref ma i p))
              (c-pos (array-row-major-index mc i 0))
              (b-pos (array-row-major-index mb p 0)))
          ;; The column number itself is never needed: just step both cursors.
          (loop repeat n-cols
                do (incf (row-major-aref mc c-pos)
                         (* a-ip (row-major-aref mb b-pos)))
                   (incf b-pos)
                   (incf c-pos)))))
    mc))

;; Benchmark RM-GEMM+STATIC-SIZE-K on *MA*, *MB*, *MC*, which must really be
;; 500x500 to match its declaration.  BENCHMARK is not defined in this listing.
;; The ";;" lines are the recorded timing report.
(benchmark (10)
  ;; Evaluation took:
  ;;   2.696 seconds of real time
  ;;   2.696000 seconds of total run time (2.696000 user, 0.000000 system)
  ;;   100.00% CPU
  ;;   8,088,876,210 processor cycles
  ;;   33,056 bytes consed
  (rm-gemm+static-size-k *ma* *mb* *mc*))

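これはあんまり変わらないな。rows/cols はもともと関数の最初に array-dimension で一度だけ取り出して fixnum 宣言していたので、静的サイズ宣言で変わるのはそれが定数 (500) になる程度で、ほとんど差が出ないのかな。この時点のコードの逆アセンブル結果は以下のとおり (ループ上限との比較が定数 1000、すなわちタグ付き fixnum 表現の 500 になっている):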

; disassembly for RM-GEMM+STATIC-SIZE-K
; Size: 170 bytes. Origin: #x10030631A8
; 1A8:       31FF             XOR EDI, EDI                    ; no-arg-parsing entry point
; 1AA:       E98D000000       JMP L5
; 1AF:       90               NOP
; 1B0: L0:   31C9             XOR ECX, ECX
; 1B2:       E978000000       JMP L4
; 1B7:       660F1F840000000000 NOP
; 1C0: L1:   4869C7F4010000   IMUL RAX, RDI, 500
; 1C7:       4801C8           ADD RAX, RCX
; 1CA:       498B5211         MOV RDX, [R10+17]
; 1CE:       F30F105C4201     MOVSS XMM3, [RDX+RAX*2+1]
; 1D4:       4869F1F4010000   IMUL RSI, RCX, 500
; 1DB:       4869C7F4010000   IMUL RAX, RDI, 500
; 1E2:       31DB             XOR EBX, EBX
; 1E4:       EB3C             JMP L3
; 1E6:       660F1F840000000000 NOP
; 1EF:       90               NOP
; 1F0: L2:   498B5111         MOV RDX, [R9+17]
; 1F4:       F30F104C7201     MOVSS XMM1, [RDX+RSI*2+1]
; 1FA:       F30F59CB         MULSS XMM1, XMM3
; 1FE:       498B5011         MOV RDX, [R8+17]
; 202:       F30F10544201     MOVSS XMM2, [RDX+RAX*2+1]
; 208:       F30F58D1         ADDSS XMM2, XMM1
; 20C:       498B5011         MOV RDX, [R8+17]
; 210:       F30F11544201     MOVSS [RDX+RAX*2+1], XMM2
; 216:       4883C602         ADD RSI, 2
; 21A:       4883C002         ADD RAX, 2
; 21E:       4883C302         ADD RBX, 2
; 222: L3:   4881FBE8030000   CMP RBX, 1000
; 229:       7CC5             JL L2
; 22B:       4883C102         ADD RCX, 2
; 22F: L4:   4881F9E8030000   CMP RCX, 1000
; 236:       7C88             JL L1
; 238:       4883C702         ADD RDI, 2
; 23C: L5:   4881FFE8030000   CMP RDI, 1000
; 243:       0F8C67FFFFFF     JL L0
; 249:       498BD0           MOV RDX, R8
; 24C:       488BE5           MOV RSP, RBP
; 24F:       F8               CLC
; 250:       5D               POP RBP
; 251:       C3               RET

loop unrolling

アンローリングの段数を手でチューニングするのはもう馬鹿らしいので、最初から段数を探索させた。
ところが、まったく改善されなかった。なぜだ。手で書いてもだめだった。

;; First unrolled variant: identical to RM-GEMM+STATIC-SIZE-K except that the
;; innermost loop uses DOTIMES-UNROLL, which is not defined in this listing
;; (from the call, (COL COLS 2) presumably means COLS iterations unrolled by
;; 2 -- confirm against its definition).
;; NOTE(review): a later DEFUN with the same name (built on DOTIMES-UNROLL2)
;; replaces this definition when both are evaluated.
(defun rm-gemm+static-size+unroll-k (ma mb mc)
  (declare (optimize (speed 3) (debug 0) (safety 0) (space 0)))
  (declare (type (matrix 500 500) ma mb mc))
  (let ((rows (array-dimension ma 0))
        (cols (array-dimension mb 1)))
    (declare (type fixnum rows cols))
    (dotimes (row rows)
      (dotimes (k cols)
        ;; CELL = MA[row,k] is fixed for the whole inner loop; MB-INDEX and
        ;; MC-INDEX start at row K of MB and row ROW of MC.
        (let ((cell (aref ma row k))
              (mb-index (array-row-major-index mb k 0))
              (mc-index (array-row-major-index mc row 0)))
          (dotimes-unroll (col cols 2)
            (incf (row-major-aref mc mc-index)
                  (* cell (row-major-aref mb mb-index)))
            ;; advance both row-major cursors by one element
            (incf mb-index)
            (incf mc-index)))))
    mc))

;; Benchmark the DOTIMES-UNROLL variant of RM-GEMM+STATIC-SIZE+UNROLL-K
;; defined just above (BENCHMARK is not defined in this listing).
;; The ";;" lines are the recorded timing report.
(benchmark (10)
  ;; Evaluation took:
  ;;   2.699 seconds of real time
  ;;   2.704000 seconds of total run time (2.704000 user, 0.000000 system)
  ;;   100.19% CPU
  ;;   8,098,862,660 processor cycles
  ;;   0 bytes consed
  (rm-gemm+static-size+unroll-k *ma* *mb* *mc*))

dotimes-unroll の定義がだめに違いない。

dotimes-unroll2

dotimes-unroll は、ループ本体 (インデックス変数の incf も含む) をそのままいくつかコピーしてアンローリングを行っていた。コピーごとに incf が残るのはなんとなく嫌な感じがする。

そこで symbol-macrolet を用いて、各コピーのオフセット (+0, +1, …) を定数として埋め込む、本当に定数伝播させた unroll を作ってみた。結構骨が折れた。
まずインタフェースを変えて、ループごとにインクリメントする必要のある変数をすべてマクロの引数として明示的に渡し、マクロ側で捕捉するようにした。
これらの変数はすべて symbol-macrolet で隠蔽され、gensym された実際のカウンタ変数 (アンロール部分ではそれに定数オフセットを足した式) に置き換えられる。

(ql:quickload :alexandria)
(use-package :alexandria)
(ql:quickload :priority-queue)
(use-package :priority-queue)
(ql:quickload :iterate)
(use-package :iterate)
(ql:quickload :trivia)
(use-package :trivia)

(defmacro dotimes-unroll2 ((var n unroll) parameters &body body)
  "DOTIMES-like loop of VAR from 0 below N whose body is copied UNROLL times.

PARAMETERS is a list of extra counters, each X (starting at 0) or (X START),
that advance by 1 per logical iteration together with VAR.  Inside BODY, VAR
and every X are SYMBOL-MACROLET aliases for hidden gensym counters.  In the
UNROLL copies of the main loop they expand to (THE FIXNUM (+ COUNTER J)) for
the constant offset J of that copy, so BODY must not assign them.  The main
loop advances each counter by UNROLL while a full group fits.  The remaining
iterations (fewer than UNROLL) run one at a time in a second DO loop.
N is evaluated once, before the START forms.  UNROLL must be a literal
positive integer."
  (check-type var symbol)
  ;; UNROLL is spliced in as a literal step and copy count.  0 or a negative
  ;; value would never reach the end test, and a non-integer would break the
  ;; FIXNUM declarations on the counters.
  (check-type unroll (integer 1))
  (with-gensyms (limit1 i)              ; I is the real counter behind VAR
    (let (varlist bindings)
      ;; Each parameter X gets a hidden counter Y; BODY only sees X through
      ;; SYMBOL-MACROLET.  VARLIST holds DO clauses, BINDINGS the (X Y) pairs.
      (dolist (p parameters)
        (ematch (ensure-list p)
          ((list x)
           (with-gensyms (y)
             (push (list x y) bindings)
             (push (list y 0 `(the fixnum (+ ,y ,unroll))) varlist)))
          ((list x start)
           (with-gensyms (y)
             (push (list x y) bindings)
             (push (list y start `(the fixnum (+ ,y ,unroll))) varlist)))))
      (push (list var i) bindings)
      (push (list i 0 `(the fixnum (+ ,i ,unroll))) varlist)
      (let ((counters (mapcar #'first varlist)))
        (once-only (n)
          `(locally (declare (fixnum ,n))
             (let ((,limit1 (the fixnum (- ,n ,unroll))))
               (declare (fixnum ,limit1))
               ;; Main loop: continue while I + UNROLL <= N.
               (do (,@varlist)
                   ((< ,limit1 ,i)
                    ;; Remainder loop: fewer than UNROLL iterations are left.
                    (do (,@(mapcar (lambda-match
                                     ((list y _ _) `(,y ,y (the fixnum (1+ ,y)))))
                                   varlist))
                        ((<= ,n ,i))
                      (declare (fixnum ,@counters))
                      (symbol-macrolet (,@bindings)
                        ,@body)))
                 (declare (fixnum ,@counters))
                 ,@(iter (for j below unroll)
                         (collect
                             `(symbol-macrolet (,@(mapcar (lambda-match
                                                            ((list x y) `(,x (the fixnum (+ ,y ,j)))))
                                                          bindings))
                                ,@body)))))))))))

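性能計測: だめだった。変わらない。なんでだろう? と書くとだれかがやってくれるはず。(下の逆アセンブル結果を見ると、アンロールされた部分でも 1 要素ごとにスカラー命令 MOVSS/MULSS/ADDSS の並びがそのまま 2 回複製されているだけで、配列のデータベクタのアドレス ([R8+17], [RAX+17]) も要素ごとに読み直している。減ったのはループの比較・分岐の回数くらいなので、効果が出ないのも無理はなさそうだ。)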

(defun rm-gemm+static-size+unroll-k (ma mb mc)
  "MC <- MC + MA * MB for arguments declared as (MATRIX 500 500); returns MC.
The innermost loop is unrolled twice by DOTIMES-UNROLL2.  This DEFUN
redefines the earlier DOTIMES-UNROLL variant of the same name."
  (declare (optimize (speed 3) (debug 0) (safety 0) (space 0)))
  (declare (type (matrix 500 500) ma mb mc))
  (let ((n-rows (array-dimension ma 0))
        (n-cols (array-dimension mb 1)))
    (declare (type fixnum n-rows n-cols))
    (dotimes (i n-rows)
      (dotimes (p n-cols)
        (let ((a-ip (aref ma i p)))
          ;; C-POS / B-POS start at row I of MC and row P of MB and advance by
          ;; one element per iteration of J.
          (dotimes-unroll2 (j n-cols 2)
              ((c-pos (array-row-major-index mc i 0))
               (b-pos (array-row-major-index mb p 0)))
            (incf (row-major-aref mc c-pos)
                  (* a-ip (row-major-aref mb b-pos)))))))
    mc))

;; Benchmark the DOTIMES-UNROLL2 variant of RM-GEMM+STATIC-SIZE+UNROLL-K
;; (BENCHMARK is not defined in this listing).  The ";;" lines are the
;; recorded timing report.
(benchmark (10)
  ;; Evaluation took:
  ;;   2.753 seconds of real time
  ;;   2.752000 seconds of total run time (2.752000 user, 0.000000 system)
  ;;   99.96% CPU
  ;;   8,259,540,639 processor cycles
  ;;   33,248 bytes consed
  (rm-gemm+static-size+unroll-k *ma* *mb* *mc*))

disassemble:

; disassembly for RM-GEMM+STATIC-SIZE+UNROLL-K
; Size: 287 bytes. Origin: #x101075E738
; 738:       4D31C9           XOR R9, R9                      ; no-arg-parsing entry point
; 73B:       E901010000       JMP L7
; 740: L0:   31F6             XOR ESI, ESI
; 742:       E9E9000000       JMP L6
; 747:       660F1F840000000000 NOP
; 750: L1:   4969C9F4010000   IMUL RCX, R9, 500
; 757:       4801F1           ADD RCX, RSI
; 75A:       498B5611         MOV RDX, [R14+17]
; 75E:       F30F104C4A01     MOVSS XMM1, [RDX+RCX*2+1]
; 764:       4969C9F4010000   IMUL RCX, R9, 500
; 76B:       4869DEF4010000   IMUL RBX, RSI, 500
; 772:       31FF             XOR EDI, EDI
; 774:       EB6D             JMP L3
; 776:       660F1F840000000000 NOP
; 77F:       90               NOP
; 780: L2:   498B5011         MOV RDX, [R8+17]
; 784:       F30F10545A01     MOVSS XMM2, [RDX+RBX*2+1]
; 78A:       F30F59D1         MULSS XMM2, XMM1
; 78E:       488B5011         MOV RDX, [RAX+17]
; 792:       F30F105C4A01     MOVSS XMM3, [RDX+RCX*2+1]
; 798:       F30F58DA         ADDSS XMM3, XMM2
; 79C:       488B5011         MOV RDX, [RAX+17]
; 7A0:       F30F115C4A01     MOVSS [RDX+RCX*2+1], XMM3
; 7A6:       488D5102         LEA RDX, [RCX+2]
; 7AA:       4C8D5302         LEA R10, [RBX+2]
; 7AE:       4D8B6811         MOV R13, [R8+17]
; 7B2:       F3430F10545501   MOVSS XMM2, [R13+R10*2+1]
; 7B9:       F30F59D1         MULSS XMM2, XMM1
; 7BD:       4C8B5011         MOV R10, [RAX+17]
; 7C1:       F3410F105C5201   MOVSS XMM3, [R10+RDX*2+1]
; 7C8:       F30F58DA         ADDSS XMM3, XMM2
; 7CC:       4C8B5011         MOV R10, [RAX+17]
; 7D0:       F3410F115C5201   MOVSS [R10+RDX*2+1], XMM3
; 7D7:       4883C704         ADD RDI, 4
; 7DB:       4883C104         ADD RCX, 4
; 7DF:       4883C304         ADD RBX, 4
; 7E3: L3:   4881FFE4030000   CMP RDI, 996
; 7EA:       7E94             JLE L2
; 7EC:       EB35             JMP L5
; 7EE:       6690             NOP
; 7F0: L4:   498B5011         MOV RDX, [R8+17]
; 7F4:       F30F10545A01     MOVSS XMM2, [RDX+RBX*2+1]
; 7FA:       F30F59D1         MULSS XMM2, XMM1
; 7FE:       488B5011         MOV RDX, [RAX+17]
; 802:       F30F105C4A01     MOVSS XMM3, [RDX+RCX*2+1]
; 808:       F30F58DA         ADDSS XMM3, XMM2
; 80C:       488B5011         MOV RDX, [RAX+17]
; 810:       F30F115C4A01     MOVSS [RDX+RCX*2+1], XMM3
; 816:       4883C102         ADD RCX, 2
; 81A:       4883C302         ADD RBX, 2
; 81E:       BFE8030000       MOV EDI, 1000
; 823: L5:   4881FFE8030000   CMP RDI, 1000
; 82A:       7CC4             JL L4
; 82C:       4883C602         ADD RSI, 2
; 830: L6:   4881FEE8030000   CMP RSI, 1000
; 837:       0F8C13FFFFFF     JL L1
; 83D:       4983C102         ADD R9, 2
; 841: L7:   4981F9E8030000   CMP R9, 1000
; 848:       0F8CF2FEFFFF     JL L0
; 84E:       488BD0           MOV RDX, RAX
; 851:       488BE5           MOV RSP, RBP
; 854:       F8               CLC
; 855:       5D               POP RBP
; 856:       C3               RET

macroexpansion:

;; Example call whose macroexpansion is listed next.  It is the inner loop of
;; the DOTIMES-UNROLL2 version of RM-GEMM+STATIC-SIZE+UNROLL-K; COLS, MB, MC,
;; K, ROW and CELL are that function's locals.
(dotimes-unroll2 (col cols 2) ((mb-index (array-row-major-index mb k 0))
                               (mc-index (array-row-major-index mc row 0)))
   (incf (row-major-aref mc mc-index)
         (* cell (row-major-aref mb mb-index))))

;; Macroexpansion of the call above (gensym names differ per expansion).
;; The outer DO runs two unrolled copies per step with constant offsets +0 and
;; +1; its end-test clause holds the one-at-a-time remainder loop.
;; NOTE(review): this listing declares the remainder-loop counters FIXNUM and
;; wraps the parameter steps in THE FIXNUM, which the DOTIMES-UNROLL2 source
;; shown earlier does not emit -- it may come from a different macro version.
(LET ((#:N849 COLS))
  (LOCALLY
   (DECLARE (FIXNUM #:N849))
   (LET ((#:LIMIT1845 (THE FIXNUM (- #:N849 2))))
     (DECLARE (FIXNUM #:LIMIT1845))
     (DO ((#:I846 0 (+ #:I846 2))
          (#:Y848 (ARRAY-ROW-MAJOR-INDEX MC ROW 0) (THE FIXNUM (+ #:Y848 2)))
          (#:Y847 (ARRAY-ROW-MAJOR-INDEX MB K 0) (THE FIXNUM (+ #:Y847 2))))
         ((< #:LIMIT1845 #:I846)
          (DO ((#:I846 #:I846 (THE FIXNUM (1+ #:I846)))
               (#:Y848 #:Y848 (THE FIXNUM (1+ #:Y848)))
               (#:Y847 #:Y847 (THE FIXNUM (1+ #:Y847))))
              ((<= #:N849 #:I846))
            (DECLARE (FIXNUM #:I846 #:Y848 #:Y847))
            (SYMBOL-MACROLET ((COL #:I846) (MC-INDEX #:Y848) (MB-INDEX #:Y847))
              (INCF (ROW-MAJOR-AREF MC MC-INDEX)
                    (* CELL (ROW-MAJOR-AREF MB MB-INDEX))))))
       (DECLARE (FIXNUM #:I846 #:Y848 #:Y847))
       (SYMBOL-MACROLET ((COL (THE FIXNUM (+ #:I846 0)))
                         (MC-INDEX (THE FIXNUM (+ #:Y848 0)))
                         (MB-INDEX (THE FIXNUM (+ #:Y847 0))))
         (INCF (ROW-MAJOR-AREF MC MC-INDEX)
               (* CELL (ROW-MAJOR-AREF MB MB-INDEX))))
       (SYMBOL-MACROLET ((COL (THE FIXNUM (+ #:I846 1)))
                         (MC-INDEX (THE FIXNUM (+ #:Y848 1)))
                         (MB-INDEX (THE FIXNUM (+ #:Y847 1))))
         (INCF (ROW-MAJOR-AREF MC MC-INDEX)
               (* CELL (ROW-MAJOR-AREF MB MB-INDEX))))))))

; equivalent to

;; The expansion above with every SYMBOL-MACROLET alias replaced by hand by
;; the form it stands for.  COL is not used by the body, so only the two
;; index counters remain.
(LET ((#:N849 COLS))
  (LOCALLY
   (DECLARE (FIXNUM #:N849))
   (LET ((#:LIMIT1845 (THE FIXNUM (- #:N849 2))))
     (DECLARE (FIXNUM #:LIMIT1845))
     (DO ((#:I846 0 (+ #:I846 2))
          (#:Y848 (ARRAY-ROW-MAJOR-INDEX MC ROW 0) (THE FIXNUM (+ #:Y848 2)))
          (#:Y847 (ARRAY-ROW-MAJOR-INDEX MB K 0) (THE FIXNUM (+ #:Y847 2))))
         ((< #:LIMIT1845 #:I846)
          (DO ((#:I846 #:I846 (THE FIXNUM (1+ #:I846)))
               (#:Y848 #:Y848 (THE FIXNUM (1+ #:Y848)))
               (#:Y847 #:Y847 (THE FIXNUM (1+ #:Y847))))
              ((<= #:N849 #:I846))
            (DECLARE (FIXNUM #:I846 #:Y848 #:Y847))
            (INCF (ROW-MAJOR-AREF MC #:Y848)
                  (* CELL (ROW-MAJOR-AREF MB #:Y847)))))
       (DECLARE (FIXNUM #:I846 #:Y848 #:Y847))
       (INCF (ROW-MAJOR-AREF MC (THE FIXNUM (+ #:Y848 0)))
             (* CELL (ROW-MAJOR-AREF MB (THE FIXNUM (+ #:Y847 0)))))
       (INCF (ROW-MAJOR-AREF MC (THE FIXNUM (+ #:Y848 1)))
             (* CELL (ROW-MAJOR-AREF MB (THE FIXNUM (+ #:Y847 1)))))))))

メモ: VOP を使おうということになったが、data-vector-set-with-offset が VOP に変換されずに full call になってしまい、しかも always-translatable 属性が付いているのでエラーになる。どうするか。

5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4