0
1

More than 1 year has passed since last update.

pythonでプロット用のデータを事前計算してバイナリでファイル出力しておく&読み込んでプロットする

Last updated at Posted at 2023-01-28

アイデア

プロット用のデータを毎回計算すると時間がかかりすぎるため、事前計算しておきたい。
そこで、バイナリファイルに書き出すようにしてみた。

  • 複素数関数の、引数の実部と虚部をそれぞれ等差的に変化させたときの格子点に対しての、関数の値の実部と虚部をdouble型で格納。
  • 複数の関数を組み合わせて追加計算したい場合を想定して、複数の関数を同じファイルに格納。
データ構造のメモ書き
ascii4 "PLOT"
uint2  format version info
uint2  number of functions
uint4  stepcount of x(real part) of numpy.linspace(,,*)
uint4  stepcount of y(imag part)
ascii128 description of func[0]
ascii128 description of func[1]
:
ascii128 description of func[n-1]
double start x of numpy.linspace(*,,)
double end x of numpy.linspace(,*,)
double start y
double end y
double func[0].Re[xindex=0,yindex=0]
double func[0].Im[xindex=0,yindex=0]
double func[1].Re[xindex=0,yindex=0]
double func[1].Im[xindex=0,yindex=0]
:
double func[n-1].Re[xindex=0,yindex=0]
double func[n-1].Im[xindex=0,yindex=0]
double func[0].Re[xindex=1,yindex=0]
double func[0].Im[xindex=1,yindex=0]
double func[1].Re[xindex=1,yindex=0]
double func[1].Im[xindex=1,yindex=0]
:
double func[0].Re[xindex=0,yindex=1]
double func[0].Im[xindex=0,yindex=1]
double func[1].Re[xindex=0,yindex=1]
double func[1].Im[xindex=0,yindex=1]
:

ソースコード

プログラム上の要点としては、struct.pack を使ってdouble等の値をバイナリに変換している。

今回使っている複素関数とプロットについては下記参照。

なお以下のコードは、グローバル変数使いまくりだったり、なぐり書きのソースなので作法的にはマネしない方がよいかも。

計算してファイルに出力

おおよそ5秒経過毎に進捗率を表示するようにしている。
for k in range(len(x[0])):の部分が長くなると5秒ごとに出力できなくなる可能性があるが、処理速度優先で外側のforのほうに入れた。

calc.py
import time
import struct
import numpy as np
import mpmath
mpmath.mp.dps = 15
# mpmath.mp.pretty = True

# -----------------------

def dumpFuncValues(f, x, y):
    global previous_time

    # z = np.zeros((len(x), len(x[0])))
    for i in range(len(x)):
        current_time = time.time()
        if previous_time+5 < current_time :
            previous_time = current_time
            print(str(round(100*i/len(x),3))+'%')
        
        for k in range(len(x[0])):
            t = x[i][k] + y[i][k]*1j
            ze    = mpmath.zeta(t)
            ze_d1 = mpmath.zeta(t,derivative=1)
            ga    = mpmath.gamma(t*0.5)
            diga  = mpmath.digamma(t*0.5)
            p = mpmath.power(mpmath.pi,-0.5*t)
            ect = ga*p*ze

            re = struct.pack('@d', ze.real)
            im = struct.pack('@d', ze.imag)
            f.write(re)
            f.write(im)
            
            re = struct.pack('@d', ze_d1.real)
            im = struct.pack('@d', ze_d1.imag)
            f.write(re)
            f.write(im)
            
            re = struct.pack('@d', ga.real)
            im = struct.pack('@d', ga.imag)
            f.write(re)
            f.write(im)
            
            re = struct.pack('@d', diga.real)
            im = struct.pack('@d', diga.imag)
            f.write(re)
            f.write(im)

            re = struct.pack('@d', p.real)
            im = struct.pack('@d', p.imag)
            f.write(re)
            f.write(im)

            re = struct.pack('@d', ect.real)
            im = struct.pack('@d', ect.imag)
            f.write(re)
            f.write(im)
    return
# -----------------------
def dump():
    x = np.linspace(X_START, X_END, X_COUNT)
    y = np.linspace(Y_START, Y_END, Y_COUNT)

    X, Y = np.meshgrid(x, y)

    func_comments = [
        'RiemannZeta(z)',
        'RiemannZeta(z) derivative',
        'Gamma(z/2)',
        'Digamma(z/2)',
        'pow(pi,-z/2)',
        'prod of (#1,#3,#5)'
    ]

    f = open('myfile.dat', 'wb')
    f.write('PLOT'.encode('utf-8'))
    f.write(struct.pack('@H', 0)) # version info
    f.write(struct.pack('@H', len(func_comments))) # number of functions
    f.write(struct.pack('@L', X_COUNT))
    f.write(struct.pack('@L', Y_COUNT))


    for fc in func_comments:
        f.write(fc.ljust(FUNC_COMMENT_LEN,' ').encode('utf-8'))
    f.write(struct.pack('@d', X_START))
    f.write(struct.pack('@d', X_END))
    f.write(struct.pack('@d', Y_START))
    f.write(struct.pack('@d', Y_END))
    dumpFuncValues(f, X, Y)
    f.close()
    return
# -----------------------

FUNC_COMMENT_LEN = 128

X_STEP = 1.0/32.0
Y_STEP = 1.0/2.0

X_START = 3/8
X_END   = 5/8

Y_START = 10
Y_END   = 200

X_COUNT = round(((X_END-X_START)/X_STEP)+1)
Y_COUNT = round(((Y_END-Y_START)/Y_STEP)+1)

previous_time = time.time()

if X_COUNT * Y_COUNT < 1000000:
    dump()
else:
    f = open('mesh_is_toobig.txt', 'wb')
    f.write('X:')
    f.write(str(X_COUNT))
    f.write(',Y:')
    f.write(str(Y_COUNT))
    f.close()

読み出しとプロット

load_and_plot.py
import sys
import os
import struct
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt

def load_header_info(f):
    global x_count
    global y_count
    global x_start
    global x_end
    global y_start
    global y_end
    global seekpos
    global count_of_funcs
    f.read(4) # skip 'PLOT'
    f.read(2) # skip version info
    count_of_funcs = struct.unpack('@H', f.read(2))[0]
    print(count_of_funcs)
    x_count = struct.unpack('@L', f.read(4))[0]
    y_count = struct.unpack('@L', f.read(4))[0]
    # func_comments = []
    for i in range(count_of_funcs):
        fc = f.read(128).decode()
        # func_comments.append(fc)
    x_start = struct.unpack('@d', f.read(8))[0]
    x_end   = struct.unpack('@d', f.read(8))[0]
    y_start = struct.unpack('@d', f.read(8))[0]
    y_end   = struct.unpack('@d', f.read(8))[0]
    seekpos = f.tell() # get current seek pos

    tmpx = np.linspace(x_start, x_end, x_count)
    tmpy = np.linspace(y_start, y_end, y_count)
    return tmpx,tmpy

def load_nth_funcs_real(f,index):
    global x
    global y
    global count_of_funcs
    z = np.zeros((len(y), len(x)))
    for i in range(len(y)):
        for k in range(len(x)):
            for funci in range(count_of_funcs):
                if funci == index:
                    tmp_re = struct.unpack('@d', f.read(8))[0]
                    f.read(8)
                    z[i][k] = tmp_re
                else:
                    f.read(8)
                    f.read(8)
    return z

def load_nth_funcs_imag(f,index):
    global x
    global y
    global count_of_funcs
    z = np.zeros((len(y), len(x)))
    for i in range(len(y)):
        for k in range(len(x)):
            for funci in range(count_of_funcs):
                if funci == index:
                    f.read(8)
                    tmp_im = struct.unpack('@d', f.read(8))[0]
                    z[i][k] = tmp_im
                else:
                    f.read(8)
                    f.read(8)
    return z

x_count=0
y_count=0
x_start=0
x_end=0
y_start=0
y_end=0
seekpos=0
count_of_funcs=0


args = sys.argv
path = None
if len(args)>1:
    tmp_path = args[1]
    dummy,ext = os.path.splitext(tmp_path)
    if ext == ".dat":
        path = tmp_path
else:
    path = 'myfile.dat'

f = open(path, 'rb')
x,y = load_header_info(f)
#Z = load_nth_funcs_real(f,0)
# f.seek(seekpos)
Z = load_nth_funcs_real(f,0)
f.close()
X, Y = np.meshgrid(x, y)

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.view_init(elev=20, azim=225)
# ax.set_zlim3d(top=Z_LIMIT)
# ax.set_ylim3d(top=50,bottom=10)
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='viridis', edgecolor='none')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')

plt.show()
0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1