LoginSignup
7
7

More than 3 years have passed since last update.

Juliaの多重ディスパッチで遊んでみる:2次元1次元配列

Last updated at Posted at 2020-04-17

Juliaでは多重ディスパッチを利用することができます。
これを使って、少し面白いものができましたので、ここに書いておきます。
結論としては、1次元配列に引数二つの場合の多重ディスパッチをすると周期境界条件こみの2次元系が扱える、ということです。

バージョン

Julia 1.4

不思議?なコード

まず、以下のコードを見てみてください。

using .MatVec2d

function test()
    Lx = 10
    Ly = 10
    N = Lx*Ly
    ψ = Vector2d{Float64}(Lx,Ly)
    ψ .= 0
    for i=1:N
        ψ[i] = rand()
    end
    H = Matrix2d{Float64}(Lx,Ly)
    H .= 0
    μ = -1.5
    for ix=1:Lx
        for iy=1:Ly
            jx = ix
            jy = iy

            H[ix,iy,jx,jy] = -μ

            jx = ix + 1
            jy = iy
            H[ix,iy,jx,jy] = -1

            jx = ix - 1
            jy = iy
            H[ix,iy,jx,jy] = -1

            jx = ix 
            jy = iy + 1
            H[ix,iy,jx,jy] = -1

            jx = ix 
            jy = iy -1
            H[ix,iy,jx,jy] = -1
        end
    end
    b = H \ ψ
    b  =Vector2d(b,Lx,Ly,ψ.periodic)
    println(typeof(b))
    println(b)

end

test()

これをみると、足が四つの配列Hを作っているように見えると思います。
差分化された微分方程式の行列のようなものです。Lx x Lyの二次元格子を定義し、そこでの行列の行列要素を計算しているように見えると思います。
しかも、ixは1からLxまで動いているのに、jxはix+1だったりするので、Lxを超えているように見えます。配列外に行っているのでしょうか?

また、なぜか、その四次元配列に対して連立方程式Hψ = b:

b = H \ ψ

を解いているように見えます。しかし、足が四つあるのになぜ行列みたいなことができるのでしょうか?

正解

ここでVector2dMatrix2dなる型があることに気がつかれた方もいると思います。
これは、

module MatVec2d
    export Vector2d,Matrix2d

    mutable struct Vector2d{T} <: AbstractVector{T}
        values::Array{T,1}
        Lx::Int64
        Ly::Int64
        N::Int64
        periodic::Bool

        function Vector2d{T}(Lx,Ly) where T<:Number
            periodic = true
            return Vector2d{T}(Lx,Ly,periodic)
        end

        function Vector2d{T}(Lx,Ly,periodic) where T<:Number
            N = Lx*Ly
            values = Array{T}(undef,N)
            return new(values,Lx,Ly,N,periodic)
        end

        function Vector2d(a::Vector,Lx,Ly)
            periodic = true
            return Vector2d(a,Lx,Ly,periodic)
        end

        function Vector2d(a::Vector,Lx,Ly,periodic)
            a2d = Vector2d{eltype(a)}(Lx,Ly,periodic)
            for ix=1:Lx
                for iy=1:Ly
                    i = xy2i(ix,iy,Lx,Ly,periodic)
                    a2d[ix,iy] = a[i] 
                end
            end
            return a2d
        end
    end



    function xy2i(ix,iy,Lx,Ly,periodic)
        if periodic

            ixx = (ix+Lx-1)%Lx+1
            iyy = (iy+Ly-1)%Ly+1
            return (iyy-1)*Lx + ixx
        else
            return (iy-1)*Lx + ix
        end
    end

    Base.getindex(v::Vector2d,i::Int) = v.values[i]

    function Base.getindex(v::Vector2d,ix::Int,iy::Int) 
        i = xy2i(ix,iy,v.Lx,v.Ly,v.periodic)
        return v.values[i]
    end
    Base.size(vector::Vector2d) = (vector.N,)

    function Base.setindex!(vector::Vector2d, v, i::Int) 
        vector.values[i] = v
    end


    function Base.setindex!(vector::Vector2d, v, ix::Int,iy::Int) 
        i = xy2i(ix,iy,vector.Lx,vector.Ly,vector.periodic)
        vector.values[i] = v
    end

    mutable struct Matrix2d{T} <: AbstractMatrix{T}
        values::Array{T,2}
        Lx::Int64
        Ly::Int64
        N::Int64
        periodic::Bool

        function Matrix2d{T}(Lx,Ly) where T<:Number
            periodic = true
            return Matrix2d{T}(Lx,Ly,periodic)
        end

        function Matrix2d{T}(Lx,Ly,periodic) where T<:Number
            N = Lx*Ly
            values = Array{T}(undef,N,N)
            return new(values,Lx,Ly,N,periodic)
        end
    end

    Base.getindex(v::Matrix2d,i::Int,j::Int) = v.values[i,j]

    function Base.getindex(v::Matrix2d,ix::Int,iy::Int,jx::Int,jy::Int) 
        i = xy2i(ix,iy,v.Lx,v.Ly,v.periodic)
        j = xy2i(jx,jy,v.Lx,v.Ly,v.periodic)
        return v.values[i,j]
    end
    Base.size(A::Matrix2d) = (A.N,A.N)

    function Base.setindex!(A::Matrix2d, v, i::Int,j::Int) 
        A.values[i,j] = v
    end

    function Base.setindex!(A::Matrix2d, v, ix::Int,iy::Int,jx::Int,jy::Int) 
        i = xy2i(ix,iy,A.Lx,A.Ly,A.periodic)
        j = xy2i(jx,jy,A.Lx,A.Ly,A.periodic)
        A.values[i,j] = v
    end

end

と定義しています。

Vector2d{T} <: AbstractVector{T}

の部分は、このTypeがAbstractVectorを親に持つことを意味しています。これにより、AbstractVector型に対して定義されているfunctionは好きなだけ適用できるようになっています。

また、

Base.getindex(v::Vector2d,i::Int) = v.values[i]

ですが、これは

v[i]

とした時の挙動を多重ディスパッチで定義しているものです。これはAbstractVector、つまり1次元配列ですので、引数をひとつでアクセスした時に値が出てくるようにしたいですよね。

また、

function Base.setindex!(vector::Vector2d, v, i::Int) 
   vector.values[i] = v
end

は、配列として代入をしたい場合、つまり、

v[i] = 3

みたいな演算を定義しているものです。この二つを定義しておけば、独自型であるVector2dがまるで1次元配列のように扱えるわけです。

これを応用してみます。1次元配列として考えていますが、実は引数が二つ入ってしまった場合の定義を書いてしまっても構いません。つまり、

    function Base.getindex(v::Vector2d,ix::Int,iy::Int) 
        i = xy2i(ix,iy,v.Lx,v.Ly,v.periodic)
        return v.values[i]
    end

のように、ixとiyを与えた時に、値が返るような関数を作っても構わないわけです。
ここで、xy2i(ix,iy,v.Lx,v.Ly,v.periodic)というものを呼んでいますが、これは周期境界条件を考慮したものです。つまり、ixとかiyがLxやLyを超えてしまっても、ちゃんと配列に代入できるってことですね。上のコードは実は周期境界条件ではない場合には配列外参照してしまいます。直さなければなりません。

というわけで、同様に2次元配列も定義し、さらに、引数が四つの時、ix,iy,jx,jyの時の挙動を定義しておけば、2次元系の問題を簡単に定義できることになります。

7
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
7