About cryoDRGN
Visualizing learning curves
This section looks at the loss values recorded in the log file.
Log file (run.log)
(omitted)
(1): ReLU()
(2): ResidLinear(
(linear): Linear(in_features=256, out_features=256, bias=True)
)
(3): ReLU()
(4): ResidLinear(
(linear): Linear(in_features=256, out_features=256, bias=True)
)
(5): ReLU()
(6): ResidLinear(
(linear): Linear(in_features=256, out_features=256, bias=True)
)
(7): ReLU()
(8): Linear(in_features=256, out_features=2, bias=True)
)
)
)
)
2021-01-18 13:48:13 3784964 parameters in model
2021-01-18 13:54:34 # =====> Epoch: 1 Average gen loss = 1.02231, KLD = 0.924802, total loss = 1.022387; Finished in 0:06:20.245799
2021-01-18 14:01:29 # =====> Epoch: 2 Average gen loss = 1.01518, KLD = 1.224245, total loss = 1.015275; Finished in 0:06:22.065888
2021-01-18 14:08:19 # =====> Epoch: 3 Average gen loss = 1.01342, KLD = 1.943195, total loss = 1.013566; Finished in 0:06:20.048240
2021-01-18 14:15:07 # =====> Epoch: 4 Average gen loss = 1.01169, KLD = 2.337141, total loss = 1.011874; Finished in 0:06:18.268111
2021-01-18 14:22:03 # =====> Epoch: 5 Average gen loss = 1.01076, KLD = 2.467319, total loss = 1.010954; Finished in 0:06:23.279119
2021-01-18 14:28:54 # =====> Epoch: 6 Average gen loss = 1.01009, KLD = 2.548503, total loss = 1.010287; Finished in 0:06:20.104529
2021-01-18 14:35:46 # =====> Epoch: 7 Average gen loss = 1.00955, KLD = 2.615071, total loss = 1.009756; Finished in 0:06:21.803054
2021-01-18 14:42:35 # =====> Epoch: 8 Average gen loss = 1.00914, KLD = 2.653529, total loss = 1.009347; Finished in 0:06:19.304278
2021-01-18 14:49:26 # =====> Epoch: 9 Average gen loss = 1.00878, KLD = 2.707029, total loss = 1.008995; Finished in 0:06:20.044801
2021-01-18 14:56:19 # =====> Epoch: 10 Average gen loss = 1.00849, KLD = 2.747613, total loss = 1.008708; Finished in 0:06:22.251633
2021-01-18 15:03:08 # =====> Epoch: 11 Average gen loss = 1.00822, KLD = 2.802414, total loss = 1.008435; Finished in 0:06:19.600215
2021-01-18 15:09:58 # =====> Epoch: 12 Average gen loss = 1.00798, KLD = 2.830367, total loss = 1.008198; Finished in 0:06:19.171115
2021-01-18 15:16:50 # =====> Epoch: 13 Average gen loss = 1.00776, KLD = 2.878851, total loss = 1.007987; Finished in 0:06:20.732890
2021-01-18 15:23:39 # =====> Epoch: 14 Average gen loss = 1.00757, KLD = 2.906105, total loss = 1.007792; Finished in 0:06:19.150181
2021-01-18 15:30:32 # =====> Epoch: 15 Average gen loss = 1.00739, KLD = 2.945066, total loss = 1.007615; Finished in 0:06:19.663154
2021-01-18 15:37:27 # =====> Epoch: 16 Average gen loss = 1.00721, KLD = 2.989139, total loss = 1.007440; Finished in 0:06:20.898634
2021-01-18 15:44:17 # =====> Epoch: 17 Average gen loss = 1.00704, KLD = 3.023471, total loss = 1.007278; Finished in 0:06:19.637517
2021-01-18 15:51:09 # =====> Epoch: 18 Average gen loss = 1.0069, KLD = 3.056445, total loss = 1.007137; Finished in 0:06:21.252752
2021-01-18 15:57:58 # =====> Epoch: 19 Average gen loss = 1.00676, KLD = 3.087742, total loss = 1.007005; Finished in 0:06:19.174269
2021-01-18 16:04:50 # =====> Epoch: 20 Average gen loss = 1.00662, KLD = 3.121757, total loss = 1.006860; Finished in 0:06:22.003833
2021-01-18 16:11:40 # =====> Epoch: 21 Average gen loss = 1.0065, KLD = 3.140710, total loss = 1.006740; Finished in 0:06:19.149865
2021-01-18 16:18:42 # =====> Epoch: 22 Average gen loss = 1.00637, KLD = 3.176621, total loss = 1.006622; Finished in 0:06:31.558981
2021-01-18 16:25:47 # =====> Epoch: 23 Average gen loss = 1.00625, KLD = 3.207150, total loss = 1.006504; Finished in 0:06:29.674917
2021-01-18 16:32:36 # =====> Epoch: 24 Average gen loss = 1.00614, KLD = 3.234640, total loss = 1.006395; Finished in 0:06:18.424028
2021-01-18 16:39:25 # =====> Epoch: 25 Average gen loss = 1.00602, KLD = 3.270075, total loss = 1.006276; Finished in 0:06:18.568070
2021-01-18 16:46:15 # =====> Epoch: 26 Average gen loss = 1.00592, KLD = 3.293772, total loss = 1.006178; Finished in 0:06:19.931208
2021-01-18 16:53:06 # =====> Epoch: 27 Average gen loss = 1.00582, KLD = 3.317351, total loss = 1.006082; Finished in 0:06:20.290187
2021-01-18 16:59:59 # =====> Epoch: 28 Average gen loss = 1.00571, KLD = 3.340852, total loss = 1.005974; Finished in 0:06:20.850692
2021-01-18 17:06:51 # =====> Epoch: 29 Average gen loss = 1.00563, KLD = 3.358073, total loss = 1.005888; Finished in 0:06:19.800037
2021-01-18 17:13:41 # =====> Epoch: 30 Average gen loss = 1.00553, KLD = 3.394385, total loss = 1.005797; Finished in 0:06:19.344668
2021-01-18 17:20:30 # =====> Epoch: 31 Average gen loss = 1.00544, KLD = 3.412608, total loss = 1.005709; Finished in 0:06:19.537618
2021-01-18 17:27:21 # =====> Epoch: 32 Average gen loss = 1.00535, KLD = 3.426723, total loss = 1.005612; Finished in 0:06:20.501204
2021-01-18 17:34:13 # =====> Epoch: 33 Average gen loss = 1.00527, KLD = 3.448969, total loss = 1.005542; Finished in 0:06:19.104303
2021-01-18 17:41:04 # =====> Epoch: 34 Average gen loss = 1.00518, KLD = 3.468607, total loss = 1.005448; Finished in 0:06:20.574825
2021-01-18 17:47:53 # =====> Epoch: 35 Average gen loss = 1.0051, KLD = 3.487816, total loss = 1.005368; Finished in 0:06:19.266487
2021-01-18 17:54:41 # =====> Epoch: 36 Average gen loss = 1.00503, KLD = 3.502420, total loss = 1.005298; Finished in 0:06:16.710547
2021-01-18 18:01:28 # =====> Epoch: 37 Average gen loss = 1.00495, KLD = 3.525365, total loss = 1.005221; Finished in 0:06:17.309040
2021-01-18 18:08:21 # =====> Epoch: 38 Average gen loss = 1.00487, KLD = 3.541543, total loss = 1.005146; Finished in 0:06:23.100732
2021-01-18 18:15:10 # =====> Epoch: 39 Average gen loss = 1.0048, KLD = 3.555950, total loss = 1.005080; Finished in 0:06:18.102520
2021-01-18 18:21:55 # =====> Epoch: 40 Average gen loss = 1.00474, KLD = 3.566339, total loss = 1.005020; Finished in 0:06:16.129188
2021-01-18 18:28:42 # =====> Epoch: 41 Average gen loss = 1.00469, KLD = 3.590437, total loss = 1.004970; Finished in 0:06:17.461872
2021-01-18 18:35:29 # =====> Epoch: 42 Average gen loss = 1.00463, KLD = 3.600814, total loss = 1.004912; Finished in 0:06:18.362757
2021-01-18 18:42:16 # =====> Epoch: 43 Average gen loss = 1.00458, KLD = 3.621860, total loss = 1.004857; Finished in 0:06:17.235549
2021-01-18 18:49:04 # =====> Epoch: 44 Average gen loss = 1.0045, KLD = 3.633174, total loss = 1.004787; Finished in 0:06:17.204172
2021-01-18 18:55:49 # =====> Epoch: 45 Average gen loss = 1.00443, KLD = 3.666399, total loss = 1.004714; Finished in 0:06:14.774338
2021-01-18 19:02:36 # =====> Epoch: 46 Average gen loss = 1.0044, KLD = 3.671483, total loss = 1.004687; Finished in 0:06:18.165214
2021-01-18 19:09:21 # =====> Epoch: 47 Average gen loss = 1.00436, KLD = 3.678585, total loss = 1.004645; Finished in 0:06:15.194644
2021-01-18 19:16:09 # =====> Epoch: 48 Average gen loss = 1.00429, KLD = 3.694472, total loss = 1.004573; Finished in 0:06:19.251820
2021-01-18 19:22:56 # =====> Epoch: 49 Average gen loss = 1.00424, KLD = 3.703967, total loss = 1.004533; Finished in 0:06:16.262893
2021-01-18 19:29:44 # =====> Epoch: 50 Average gen loss = 1.00423, KLD = 3.705236, total loss = 1.004515; Finished in 0:06:17.885488
2021-01-18 19:30:43 Finsihed in 5:43:56.411010 (0:06:52.728220 per epoch)
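The "per epoch" value on the last log line can be reproduced by dividing the total run time by the 50 epochs. The sketch below checks this arithmetic; `parse_duration` is a hypothetical helper written for this note, not part of cryoDRGN. The result is longer than the individual epoch times (about 6:20 each), presumably because the total also covers work done outside the timed part of each epoch.

```python
from datetime import timedelta


def parse_duration(text):
    """Parse an 'H:MM:SS.ffffff' string, as printed in the log, into a timedelta.

    Hypothetical helper for illustration only.
    """
    hours, minutes, seconds = text.split(':')
    whole, _, frac = seconds.partition('.')
    # Pad or trim the fractional part to exactly six digits (microseconds).
    microseconds = int(frac.ljust(6, '0')[:6])
    return timedelta(hours=int(hours), minutes=int(minutes),
                     seconds=int(whole), microseconds=microseconds)


total = parse_duration('5:43:56.411010')  # total run time from the last log line
n_epochs = 50                             # number of epochs in this log
per_epoch = total / n_epochs
print(per_epoch)  # 0:06:52.728220
```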
Extracting loss values from the log file
The following function extracts the loss values from the log file. It uses pandas.
Function definition (Python)
import re
import pandas as pd

def parse_cryodrgn_log(log_file):
    # Collect one record per "Epoch: N Average gen loss = ..." line.
    log = []
    with open(log_file) as f:  # the file is closed once parsing is done
        for line in f:
            m = re.match(r'.+Epoch:\s*(\d*)\s*Average gen loss =\s*(\d*\.\d*), KLD =\s*(\d*\.\d*),\s*total loss =\s*(\d*\.\d*).*', line)
            if m is not None:
                log.append(dict(
                    epoch=int(m.group(1)),
                    average_gen_loss=float(m.group(2)),
                    kld=float(m.group(3)),
                    total_loss=float(m.group(4))
                ))
    df_log = pd.DataFrame(log)
    return df_log
| | epoch | average_gen_loss | kld | total_loss |
|---|---|---|---|---|
| 0 | 1 | 1.02231 | 0.924802 | 1.02239 |
| 1 | 2 | 1.01518 | 1.22425 | 1.01527 |
| 2 | 3 | 1.01342 | 1.9432 | 1.01357 |
| 3 | 4 | 1.01169 | 2.33714 | 1.01187 |
| 4 | 5 | 1.01076 | 2.46732 | 1.01095 |
| 5 | 6 | 1.01009 | 2.5485 | 1.01029 |
| 6 | 7 | 1.00955 | 2.61507 | 1.00976 |
(remaining rows omitted)
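To see what the regular expression in `parse_cryodrgn_log` captures, it can be tried on a single line. This is a minimal sketch using only the standard library; `EPOCH_PATTERN` and `parse_epoch_line` are names made up for this illustration, and the sample line is copied from the log above.

```python
import re

# Same pattern as in parse_cryodrgn_log, split over two lines for readability.
EPOCH_PATTERN = re.compile(
    r'.+Epoch:\s*(\d*)\s*Average gen loss =\s*(\d*\.\d*), KLD =\s*(\d*\.\d*),'
    r'\s*total loss =\s*(\d*\.\d*).*'
)


def parse_epoch_line(line):
    """Return the values of one epoch line as a dict, or None if it does not match."""
    m = EPOCH_PATTERN.match(line)
    if m is None:
        return None
    return {
        'epoch': int(m.group(1)),
        'average_gen_loss': float(m.group(2)),
        'kld': float(m.group(3)),
        'total_loss': float(m.group(4)),
    }


sample = ('2021-01-18 13:54:34 # =====> Epoch: 1 Average gen loss = 1.02231, '
          'KLD = 0.924802, total loss = 1.022387; Finished in 0:06:20.245799')
print(parse_epoch_line(sample))
# {'epoch': 1, 'average_gen_loss': 1.02231, 'kld': 0.924802, 'total_loss': 1.022387}

# Lines without the epoch summary (for example the parameter count) give None.
print(parse_epoch_line('2021-01-18 13:48:13 3784964 parameters in model'))  # None
```

The parsed values keep the full logged precision (1.022387), while the table shows them rounded (1.02239). Because each loss group requires a decimal point, a value written in exponent notation (for example 9.5e-05) would not match, and that line would be skipped.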
Visualizing a single log file
For a single log file, plot each curve in one figure. This uses the Plotly library.
Function definition (Python)
from plotly.subplots import make_subplots
import plotly.graph_objects as go

def plot_cryodrgn_learning_curve(log_file):
    df = parse_cryodrgn_log(log_file)
    # One plot with two y-axes: losses on the left, KL divergence on the right.
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(
        go.Scatter(x=df['epoch'], y=df['average_gen_loss'], name='Generator loss'),
        secondary_y=False
    )
    fig.add_trace(
        go.Scatter(x=df['epoch'], y=df['total_loss'], name='Total loss'),
        secondary_y=False
    )
    fig.add_trace(
        go.Scatter(x=df['epoch'], y=df['kld'], name='KL Divergence'),
        secondary_y=True
    )
    fig.update_layout(
        height=800,
        xaxis_title='Epoch'
    )
    # Label the left (primary) and right (secondary) y-axes.
    fig.update_yaxes(
        title_text='Loss',
        secondary_y=False
    )
    fig.update_yaxes(
        title_text='KL Divergence',
        secondary_y=True
    )
    return fig
Usage example (Python)
log_file = '00_vae128_zdim1_seed1/run.log'
fig = plot_cryodrgn_learning_curve(log_file)
fig.show()
Visualizing multiple log files
Code (Python)
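import os

# log_files: list of run.log paths, one per training run (must be defined beforehand)
fig = go.Figure()
for log_file in log_files:
    df = parse_cryodrgn_log(log_file)
    fig.add_trace(
        # Label each curve with the directory that contains its log file.
        go.Scatter(x=df['epoch'], y=df['total_loss'], name=os.path.dirname(log_file))
    )
fig.update_layout(
    height=800,
    xaxis_title='Epoch',
    yaxis_title='Total loss'
)
fig.show()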
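The loop above expects `log_files` to hold the paths of the log files to compare, but does not show how that list is built. One possible way is sketched below, assuming each training run writes its `run.log` into its own subdirectory of a common parent (as in `00_vae128_zdim1_seed1/run.log`). `find_run_logs` is a hypothetical helper, not part of cryoDRGN.

```python
import glob
import os


def find_run_logs(root):
    """Return the sorted paths of all run.log files exactly one directory below root.

    Hypothetical helper; the '<root>/<run directory>/run.log' layout is an assumption.
    """
    return sorted(glob.glob(os.path.join(root, '*', 'run.log')))


def trace_name(log_file):
    """Legend label for a log file: the directory that contains it."""
    return os.path.dirname(log_file)
```

For example, `log_files = find_run_logs('.')` collects the logs under the current directory. Sorting keeps the legend order stable between runs.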