LoginSignup
0
0

はじめに

  • numpy.dotでm×n行列と1×m行列の積を計算しようとしたら失敗したので,共有できればと思い記事を書きました

numpyを用いた行列の積の計算方法

  • 以下のようにすれば行列を定義可能

    import numpy as np
    A = np.array([[1,2,3],[3,3,5]])
    B = np.array([[1,2],[2,2],[1,1]])
    print(A)
    print(B)
    

  • 実行画面

    [[1 2 3]
     [3 3 5]]
    [[1 2]
     [2 2]
     [1 1]]
    

  • 積の計算をnp.dotで行う

    print(np.dot(A, B))
    

  • 実行結果

    • 正しく計算できている
    [[ 8  9]
     [14 17]]
    

m×nの行列と1×mの行列の積

  • 以下のコードを実行してみる

    import numpy as np
    
    A = np.array([1,2,3])
    B = np.array([[1],[2],[3]])
    print(A)
    print(B)
    print(np.dot(B,A))
    

  • 実行結果

    [1 2 3]
    [[1]
     [2]
     [3]]
    Traceback (most recent call last):
      File "xx\yy.py", line 9, in <module>
        print(np.dot(B,A))
    ValueError: shapes (3,1) and (3,) not aligned: 1 (dim 1) != 3 (dim 0)
    

  • np.dot(B, A)がうまくいっていない

    • 3×3の行列になるはず
  • 第2項が1行のとき,積の計算はpythonでは不可能...?

解決方法

  • 1行の行列は以下のように定義すれば解決

    import numpy as np
    A = np.array([[1,2]])
    print(A)
    

  • 出力

    [[1 2]]
    

1行であることを明示的に定義する

  • 参考に以下のコードを実行してみる

    import numpy as np
    
    A = np.array([[1,2,3]])
    B = np.array([[1],[2],[3]])
    
    print(np.dot(B,A))
    

  • 実行結果

    [[1 2 3]
     [2 4 6]
     [3 6 9]]
    

直感的な定義でも動くような行列積の自作関数

  • 上記のように定義すれば計算可能

    • 直感的ではないのが難点
  • 直感的な定義でも動くような設計にする

    import numpy as np
    
    def original_dot(l1, l2):
      if len(list(l1.shape)) == 1: # 適切な形への変形処理
        l1 = np.array([l1])
    
      if len(list(l2.shape)) == 1: # 適切な形への変形処理
        l2 = np.array([l2])
    
      m1 = l1.shape[0]
      n1 = l1.shape[1]
      m2 = l2.shape[0]
      n2 = l2.shape[1]
    
      if n1 != m2: #計算不可能な場合
        print("ERROR n1 = "+str(n1)+" and m2 = "+str(m2))
        return
    
      else:
        ans = np.zeros((m1,n2))
        for i in range(m1):
          for j in range(n2):
            ans[i][j] = sum([l1[i][k]*l2[k][j] for k in range(n1)])
        return ans
    
    A = np.array([1,2,3]) #(1 2 3)
    B = np.array([[1],[2],[3]]) # (1 2 3)^T
    
    print(original_dot(B, A))
    

  • 実行結果

    • 正しい出力が得られた
    [[1. 2. 3.]
     [2. 4. 6.]
     [3. 6. 9.]]
    

1行m列の行列Aが適切でない定義だと,A.shape = (m, )となることを利用

0
0
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0