## AFML Snipet 3.1
# 日次ボラティリティの計算
def get_daily_vol(close, span0=100):
df0 = close.index.searchsorted(close.index - pd.Timedelta(days=1))
df0 = df0[df0 > 0]
df0 = pd.Series(close.index[df0 - 1],
index=close.index[close.shape[0] - df0.shape[0]:])
#それぞれの時刻における1日前のリターンを算出
df0 = close.loc[df0.index] / close.loc[df0.values].values - 1
#ローリング指数加重標準偏差を計算
df0 = df0.ewm(span=span0).std()
return df0.rename('daily_vol')
#リターンと各時点における日次ボラティリティを計算
df0 = get_daily_vol(time_bar.cl)
pfm = pd.concat([returns, df0], axis=1, names=['r', 'daily_vol']).dropna()
#ラベリング
labels_dynamic = pfm.apply(lambda x: x['r'] / abs(x['r'])
if abs(x['r']) > x['daily_vol'] else 0,
axis=1)
plt.hist(labels_dynamic, bins=[-1.5, -0.5, 0.5, 1.5], align='mid', rwidth=0.5)
plt.xlabel('Label')
plt.ylabel('Frequency')
plt.xticks([-1, 0, 1])
plt.show()
#各静的閾値に対してグラフを用いて描画
thre_r = [0.5, 0.6, 0.7, 0.8, 0.9, 1]
fig = plt.figure(figsize=(20, 10))
ax1 = fig.add_subplot(2, 3, 1)
ax2 = fig.add_subplot(2, 3, 2)
ax3 = fig.add_subplot(2, 3, 3)
ax4 = fig.add_subplot(2, 3, 4)
ax5 = fig.add_subplot(2, 3, 5)
ax6 = fig.add_subplot(2, 3, 6)
axs = [ax1, ax2, ax3, ax4, ax5, ax6]
for r, ax in zip(thre_r, axs):
labels_static = returns.apply(lambda x: x / abs(x)
if abs(x) > pfm.daily_vol.mean() * r else 0)
ax.hist(np.array([labels_static, labels_dynamic], dtype='object'),
label=[f'static_{r}', 'dynamic'], bins=[-1.5, -0.5, 0.5, 1.5], align='mid', rwidth=0.5)
ax.legend(loc='upper left')
ax.set_xlabel('Label')
ax.set_ylabel('Frequency')
ax.set_xticks([-1, 0, 1])
plt.show()
# 静的閾値τを変更して再度ラベリング
labels_static = returns.apply(lambda x: x / abs(x)
if abs(x) > pfm.daily_vol.mean() * 0.7 else 0)
# グラフを描画
fig = plt.figure(figsize=(17, 6))
ax = fig.add_subplot(3, 1, 1)
ax.set_ylabel('l_static')
ax.plot(abs(labels_static))
ax = fig.add_subplot(3, 1, 2)
ax.set_ylabel('l_dynamic')
ax.plot(abs(labels_dynamic))
ax = fig.add_subplot(3, 1, 3)
ax.set_ylabel('daily_volatility')
ax.plot(pfm.daily_vol)
plt.show()
トリプルバリア
## AFML Snipet 3.2
def apply_triple_barrier(close, events, ptsl, molecule):
# t1前に行われた場合は,ストップロス/利食いを実施.
events_ = events.loc[molecule]
out = events_[['t1']].copy(deep=True)
if ptsl[0] > 0:pt=ptsl[0]*events_['trgt']
else:pt=pd.Series(index=events.index)
if ptsl[1] > 0:sl=-ptsl[1]*events_['trgt']
else:sl=pd.Series(index=events.index)
for loc, t1 in events_['t1'].fillna(close.index[-1]).items():
df0 = close[loc:t1] #価格経路
df0 = (df0/close[loc]-1)*events_.at[loc, 'side'] #リターン
#ストップロスの最短タイミング
out.loc[loc, 'sl'] = df0[df0 < sl[loc]].index.min()
#利食いの最短タイミング
out.loc[loc, 'pt'] = df0[df0 > pt[loc]].index.min()
return out
def calc_label(s: pd.Series, pfm):
if (min(s.t1, s.pt, s.sl) == s.pt or min(s.t1, s.pt, s.sl) == s.sl):
return (pfm.r[s.t1] / pfm.r[s.name] -
1) / abs(pfm.r[s.t1] / pfm.r[s.name] - 1)
else:
return 0
def get_ret_label(df, pfm):
"""
df's columns are as follows
- t1
- pt
- sl
pfm's columns are as follows
- r
- daily_vol
"""
return df.apply(lambda x: calc_label(x, pfm), axis=1)
warnings.simplefilter('ignore', category=RuntimeWarning)
t1 = pfm.index[h:] #垂直バリアの設定
trgt = pfm.daily_vol[:-h] #水平バリアの設定(ここではdailyのボラティリティから水平バリアを作成)
side = len(pfm[:-h]) * [1] #サイドは買いに設定
#トリプルバリア法
close = time_bar.cl
molecule = pfm.index[:-h]
events = pd.DataFrame({
't1': t1,
'trgt': trgt,
'side': side
}, index=molecule) #eventsの作成
ptsl = [1, 1]
#各バリアに触れた時のタイムスタンプを含むdataframeが返ってくる
df_triple = apply_triple_barrier(close, events, ptsl, molecule)
#最初にバリアを触れる時刻を更新(デフォルトは垂直バリアの時刻)
df_triple['t1'] = df_triple.dropna(how='all').min(axis=1)
#ラベリング(有効期間までに水平バリアに触れない場合は垂直バリアに触れた時点で損益計算(垂直バリアにあたったら0とラベリング)
labels_triple = get_ret_label(df_triple, pfm)
print("ラベルリングの結果")
print(labels_triple.value_counts())
labels_triple.plot(figsize=(17, 5))
plt.show()
print("---- トリプルバリア法 ----\n", labels_triple.value_counts())
print("---- 垂直バリアのみ ----\n", labels_dynamic.value_counts())
plt.hist(np.array([labels_triple, labels_dynamic], dtype='object'),
label=['triple', 'only_vertical'], bins=[-1.5, -0.5, 0.5, 1.5], align='mid', rwidth=0.5)
plt.legend(loc='upper left')
plt.xlabel('Label')
plt.ylabel('Frequency')
plt.xticks([-1, 0, 1])
plt.show()
#グラフを描画
fig = plt.figure(figsize=(17, 6))
ax = fig.add_subplot(3, 1, 1)
ax.set_ylabel('l_dynamic')
ax.plot(labels_dynamic)
ax = fig.add_subplot(3, 1, 2)
ax.set_ylabel('l_triple')
ax.plot(labels_triple)
ax = fig.add_subplot(3, 1, 3)
ax.set_ylabel('daily_volatility')
ax.plot(pfm.daily_vol)
plt.show()
サイドとサイズの楽手
## AFML Snipet 3.3
def get_events(close, tevents, ptsl, trgt, min_ret=0, t1=False):
#ターゲットの定義
trgt = trgt.loc[tevents]
trgt = trgt[trgt > min_ret] #より厳密にはmin_retは本来手数料分以上に設定する必要がある
#t1(最大保有期間)の定義
if t1 is False: t1 = pd.Series(pd.Nat, index=tevents)
#イベントオブジェクトを作成し,t1にストップロスを適用
side_ = pd.Series(1., index=trgt.index)
events = pd.concat({
't1': t1,
'trgt': trgt,
'side': side_
}, axis=1).dropna(subset=['trgt'])
df0 = apply_triple_barrier(close=close,
events=events,
ptsl=[ptsl, ptsl],
molecule=events.index)
events['t1'] = df0.dropna(how='all').min(axis=1)
events = events.drop('side', axis=1)
return events
## AFML Snipet 3.5
# 水地浴バリアに最初に触れた時に常に0を返すように変更する必要がある.
def get_bins(events, close):
events_ = events.dropna(subset=['t1'])
px = events_.index.union(events_['t1'].values).drop_duplicates()
px = close.reindex(px, method='bfill')
out = pd.DataFrame(index=events_.index)
out['ret'] = px.loc[events_['t1'].values].values / px.loc[
events_.index] - 1
out['bin'] = np.sign(out['ret'])
return out
tevents = pfm.index[:-h]
close = time_bar.cl
#水平バリアを設定
trgt = pfm.daily_vol[:-h]
#垂直バリアを設定
t1 = close.index.searchsorted(tevents+pd.Timedelta(hours=h))
t1 = t1[t1<close.shape[0]]
t1 = pd.Series(close.index[t1],index=tevents[:t1.shape[0]])
#トリプルバリア法に基づいてeventsオブジェクトを取得
events = get_events(close=close,tevents=tevents,ptsl=1,trgt=trgt,t1=t1)
df_label = get_bins(events,close)
print("ラベルリングの結果")
print(df_label['bin'].value_counts())
plt.hist(df_label['bin'], bins=[-1.5, -0.5, 0.5, 1.5], align='mid', rwidth=0.5)
plt.xlabel('Label')
plt.ylabel('Frequency')
plt.xticks([-1, 0, 1])
plt.show()
# ラベルの時系列変化
fig = plt.figure(figsize=(15, 6))
ax = fig.add_subplot(3, 1, 1)
ax.set_ylabel('bin')
ax.plot(df_label['bin'][:200])
ax = fig.add_subplot(3, 1, 2)
ax.set_ylabel('return')
ax.plot(df_label['ret'][:200])
ax = fig.add_subplot(3, 1, 3)
ax.set_ylabel('close')
ax.plot(close[df_label[:200].index])
plt.show()
メタラベ
## AFML Snipet 3.6
def get_events(close,
tevents,
ptsl,
trgt,
min_ret=0,
t1=False,
side=None):
#ターゲットの定義
trgt = trgt.loc[tevents]
trgt = trgt[trgt > min_ret] #min_ret:目標最小リターン)
#t1(最大保有期間の定義)
if t1 is False: t1 = pd.Series(pd.Nat, index=tevents)
#イベントオブジェクトを作成し,t1にストップロスを適用
if side is None:
side_, ptsl_ = pd.Series(1., index=trgt.index), [ptsl[0], ptsl[0]]
else:
side_, ptsl_ = side.loc[trgt.index], ptsl[:2]
events = pd.concat({
't1': t1,
'trgt': trgt,
'side': side_
}, axis=1).dropna(subset=['trgt'])
df0 = apply_triple_barrier(close=close,
events=events,
ptsl=ptsl_,
molecule=events.index)
events['t1'] = df0.loc[:, ['t1', 'pt', 'sl']].dropna(how='all').min(axis=1)
if side is None: events = events.drop('side', axis=1)
return events
## AFML Snipet 3.7
# get_bins関数の拡張
def get_bins(events,close):
#events発生時の価格
events_ = events.dropna(subset=['t1'])
px = events_.index.union(events_['t1'].values).drop_duplicates()
px = close.reindex(px,method='bfill')
#outオブジェクトを生成
out = pd.DataFrame(index=events_.index)
out['ret'] = px.loc[events_['t1'].values].values/px.loc[events_.index]-1
#メタラベリング
if 'side' in events_:
out['ret'] *= events_['side'] #sideをかけて,リターンの符号を決定
out['bin'] = np.sign(out['ret']) #符号をとって
if 'side' in events_:
out.loc[out['ret']<=0,'bin']=0 #ret<=0のケースはベットしないようにラベリング
return out
#event時間を用意
tevents = pfm.index[:-h]
close = time_bar.cl
#水平バリアの設定(リターン目標)
trgt = pfm.daily_vol[:-h]
ptsl = np.array([1, 1]) #サイドの学習は行わないので,下部バリアまたは上部バリアが0でも良い
#垂直バリアの設定
t1 = close.index.searchsorted(tevents + pd.Timedelta(hours=h))
t1 = t1[t1 < close.shape[0]]
t1 = pd.Series(close.index[t1], index=tevents[:t1.shape[0]])
#1次モデルの作成(sideの設定).ここでは各イベントに対してランダムにラベリング
side = pd.Series(np.random.choice([-1, 1], len(trgt)), index=trgt.index)
#トリプルバリアeventsの作成
events = get_events(close=close,
tevents=tevents,
ptsl=ptsl,
trgt=trgt,
t1=t1,
side=side)
#メタラベリングの実施
df_label = get_bins(events, close)
print("---- 1次モデルのサイド ----\n",side.value_counts())
print("---- メタラベル ----\n",df_label.bin.value_counts())
fig = plt.figure(figsize=(15, 8))
ax = fig.add_subplot(2, 1, 1)
ax.set_ylabel('side')
ax.plot(side[:200])
ax = fig.add_subplot(2, 1, 2)
ax.set_ylabel('meta_label')
ax.plot(df_label['bin'][:200])
ax = ax.twinx()
ax.set_ylim(-0.1, 0.1)
ax.set_ylabel('ret - trgt')
ax.plot(df_label['ret'][:200],color='orange')
plt.show()
メタラベ戦略
1時モデルの作成
def make_tech_idc(time_bar):
df = time_bar.copy(deep=True)
#移動平均線のクロスによりsideを決定
df["sma_short"] = df["cl"].rolling(window=period_short_sma).mean()
df["sma_long"] = df["cl"].rolling(window=period_long_sma).mean()
#ボリンジャーバンドを作成
df['bbstd'] = df['cl'].rolling(window=period_short_sma).std()
df['bbh1'] = df['sma_short'] + df['bbstd'] * 1
df['bbl1'] = df['sma_short'] - df['bbstd'] * 1
df['bbh2'] = df['sma_short'] + df['bbstd'] * 2
df['bbl2'] = df['sma_short'] - df['bbstd'] * 2
df['bbh3'] = df['sma_short'] + df['bbstd'] * 3
df['bbl3'] = df['sma_short'] - df['bbstd'] * 3
return df
def make_first_prediction(df):
#smaによるposを導出
diff = df["sma_short"] - df["sma_long"]
df['side_sma'] = np.where(
(np.sign(diff) - np.sign(diff.shift(1)) == 2),
1, 0) # diffの各値を直前のデータで引き,2ならゴールデンクロス(買い)なので1に,それ以外は0とおく
#ボリンジャーバンドによるposを導出
df['side_bbh1'] = np.where(df.bbh1 < df.hi, 1, 0)
df['side_bbh2'] = np.where(df.bbh2 < df.hi, 1, 0)
df['side_bbh3'] = np.where(df.bbh3 < df.hi, 1, 0)
df['side_bbl1'] = np.where(df.bbl1 > df.lo, 1, 0)
df['side_bbl2'] = np.where(df.bbl2 > df.lo, 1, 0)
df['side_bbl3'] = np.where(df.bbl3 > df.lo, 1, 0)
df['side_1st'] = df['side_sma'] | df['side_bbh1'] | df['side_bbh2'] | df[
'side_bbh3'] | df['side_bbl1'] | df['side_bbl2'] | df['side_bbl3']
return df
#描画のための関数を作成
#参考にしたテクニカル指標を返す関数
def y_long_point(s: pd.Series):
if s.side_sma == 1:
return s.sma_short
elif s.side_bbh3 == 1:
return s.bbh3
elif s.side_bbl3 == 1:
return s.bbl3
elif s.side_bbh2 == 1:
return s.bbh2
elif s.side_bbl2 == 1:
return s.bbl2
elif s.side_bbh1 == 1:
return s.bbh1
elif s.side_bbl1 == 1:
return s.bbl1
#一次モデルの予測結果を表示する関数
def show_first_prediction(df):
#適当な期間のテクニカル分析の結果を表示
num_candle_plot = len(df) // 30
df_tmp = df.iloc[num_candle_plot // 10:num_candle_plot, :]
# figを定義
fig = make_subplots(rows=2,
cols=1,
shared_xaxes=True,
vertical_spacing=0.05,
row_width=[0.2, 0.7],
x_title="Date")
# SMA(移動平均線)
fig.add_trace(go.Scatter(x=df_tmp.index,
y=df_tmp["sma_short"],
name="sma_short",
mode="lines"),
row=1,
col=1)
fig.add_trace(go.Scatter(x=df_tmp.index,
y=df_tmp["sma_long"],
name="sma_long",
mode="lines"),
row=1,
col=1)
# BB(ボリンジャーバンド)
fig.add_trace(go.Scatter(x=df_tmp.index,
y=df_tmp['bbh3'],
line_color='gray',
line={'dash': 'dash'},
fill='tonexty',
name='higher band3',
opacity=0.5),
row=1,
col=1)
fig.add_trace(go.Scatter(x=df_tmp.index,
y=df_tmp['bbl3'],
line_color='gray',
line={'dash': 'dash'},
fill='tonexty',
name='lower band3',
opacity=0.5),
row=1,
col=1)
fig.add_trace(go.Scatter(x=df_tmp.index,
y=df_tmp['bbh2'],
line_color='gray',
line={'dash': 'dash'},
name='higher band2'),
row=1,
col=1)
fig.add_trace(go.Scatter(x=df_tmp.index,
y=df_tmp['bbl2'],
line_color='gray',
line={'dash': 'dash'},
name='lower band2'),
row=1,
col=1)
fig.add_trace(go.Scatter(x=df_tmp.index,
y=df_tmp['bbh1'],
line_color='gray',
line={'dash': 'dash'},
name='higher band1'),
row=1,
col=1)
fig.add_trace(go.Scatter(x=df_tmp.index,
y=df_tmp['bbl1'],
line_color='gray',
line={'dash': 'dash'},
name='lower band1'),
row=1,
col=1)
fig.add_trace(go.Candlestick(x=df_tmp.index,
open=df_tmp['op'],
high=df_tmp['hi'],
low=df_tmp['lo'],
close=df_tmp['cl'],
name='candlestick'),
row=1,
col=1)
y_array_long_point = df_tmp[df_tmp.side_1st == 1].apply(
lambda x: y_long_point(x), axis=1)
df_long_point = df_tmp[df_tmp.side_1st == 1]
fig.add_trace(go.Scatter(x=df_long_point.index,
y=y_array_long_point,
name="long point",
mode="markers",
marker_symbol="triangle-up",
marker_size=7,
marker_color="black"),
row=1,
col=1)
# 出来高
fig.add_trace(go.Bar(x=df_tmp.index, y=df_tmp["volume"], name="volume"),
row=2,
col=1)
# y軸名を定義
fig.update_yaxes(title_text="株価", row=1, col=1)
fig.update_yaxes(title_text="出来高", row=2, col=1)
fig.update(layout_xaxis_rangeslider_visible=False)
fig.show()
period_short_sma = 5
period_long_sma = 15
df = make_tech_idc(time_bar)
df = make_first_prediction(df)
show_first_prediction(df)
df.side_1st.value_counts()
2次モデルの作成
# 入力の特徴量数:一次モデルの入力に使用したsma_short,sma_long,bbh1,bbh2,bbh3,bbl1,bbl2,bbl3と1次モデルの出力であるsideを使用
feature_num = 9
batch_size = 8
sequence_num = 96
n_epoch = 50
target_dim = 1
hidden_dim = 32
#対象とする期間を抽出:sma!=Nanかつtimesteps=96から最過去の時刻を抽出
tevents = tevents[tevents > df.index[0] +
pd.Timedelta(hours=(sequence_num + period_long_sma) / 4 + 1)]
test_idx_from = int(len(tevents) * 0.8)
val_idx_from = int(len(tevents) * 0.6)
#get_bins関数の拡張(retが日次ボラティリティを超えるときのみメタラベルを1にする)
def get_bins(events,close,trgt_r=1):
#events発生時の価格
events_ = events.dropna(subset=['t1'])
px = events_.index.union(events_['t1'].values).drop_duplicates()
px = close.reindex(px,method='bfill')
#outオブジェクトを生成
out = pd.DataFrame(index=events_.index)
out['ret'] = px.loc[events_['t1'].values].values/px.loc[events_.index]-1
#メタラベリング
if 'side' in events_:
out['ret'] *= events_['side'] #sideをかけて,リターンの符号を決定
out['ret'] -= trgt_r*events_['trgt']
out['bin'] = np.sign(out['ret']) #符号をとって
if 'side' in events_:
out.loc[out['ret']<=0,'bin']=0 #ret<=0のケースはベットしないようにラベリング
return out
close = time_bar.cl
#水平バリアの設定(リターン目標)
trgt = pfm.daily_vol[tevents]
ptsl = np.array([1, 1]) #サイドの学習は行わないので,下部バリアまたは上部バリアが0でも良い
#垂直バリアの設定
t1 = close.index.searchsorted(tevents + pd.Timedelta(hours=h))
t1 = t1[t1 < close.shape[0]]
t1 = pd.Series(close.index[t1], index=tevents[:t1.shape[0]])
#sideの設定
side = df.side_1st
#トリプルバリアeventsの作成
events = get_events(close=close,
tevents=tevents,
ptsl=ptsl,
trgt=trgt,
t1=t1,
side=side)
out = get_bins(events, close, trgt_r=0.1)
df['meta_label'] = np.sign(out.bin) #トリプルバリア法によってメタラベルを作成
out = get_bins(events.drop('side', axis=1), close,
trgt_r=1) #side列を除いてget_binsを呼び出し
df['ret'] = 1 * (out.bin == 1) #トリプルバリア法によって正解ラベルを作成
df
print(df.meta_label.value_counts())
plt.hist(df.meta_label, bins=[-0.5, 0.5, 1.5], align='mid', rwidth=0.5)
plt.xlabel('Label')
plt.ylabel('Frequency')
plt.xticks([0, 1])
plt.show()
def prep_X(X_df,tEvents):
feats = np.zeros((len(tEvents),sequence_num,feature_num))
for i,t in enumerate(tEvents):
feats[i, :, :] = X_df.loc[t-pd.Timedelta(hours=sequence_num/4):t-pd.Timedelta(seconds=1), :].values
return feats
#ローソク足データの最初のsequence_numに相当するデータは入力データがないので除く
y_data = df.loc[time_bar.index[sequence_num]:]
y_data = y_data.loc[tevents, 'meta_label']
#入力データの作成
feats_1st = [
'sma_short', 'sma_long', 'bbh1', 'bbh2', 'bbh3', 'bbl1', 'bbl2', 'bbl3'
]
X_df = df.loc[:, feats_1st + ['side_1st']].copy(deep=True)
#対数データは差分変換
X_df.loc[:, feats_1st] = np.log(X_df.loc[:, feats_1st])
X_df.loc[:, feats_1st] = X_df.loc[:, feats_1st].apply(lambda x: x.diff(periods=1))
#標準化
X_df = X_df.apply(lambda x: (x - x.mean()) / x.std(), axis=0)
X_data = prep_X(X_df, y_data.index)
X_train = X_data[:val_idx_from]
X_val = X_data[val_idx_from:test_idx_from]
X_test = X_data[test_idx_from:]
y_train = y_data[:val_idx_from].values
y_val = y_data[val_idx_from:test_idx_from].values
y_test = y_data[test_idx_from:].values
# Tensorに変更
X_train = torch.tensor(X_train, dtype=torch.float32)
X_val = torch.tensor(X_val, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
y_val = torch.tensor(y_val, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)
# Datasetを作成
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
val_dataset = torch.utils.data.TensorDataset(X_val, y_val)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
# dataloaderを作成
train_batch = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=False)
val_batch = torch.utils.data.DataLoader(val_dataset,
batch_size=batch_size,
shuffle=False)
test_batch = torch.utils.data.DataLoader(test_dataset,
batch_size=batch_size,
shuffle=False)
# ミニバッチデータセットの確認
for data, label in train_batch:
print("batch data size: {}".format(data.size())) # バッチの入力データサイズ
print("batch label size: {}".format(label.size())) # バッチのラベルサイズ
break
class MLP(nn.Module):
'''
Multilayer Perceptron.
'''
def __init__(self, sequence_num, feature_num, hidden_dim, target_dim):
super().__init__()
self.layers = nn.Sequential(
nn.Flatten(),
nn.Linear(sequence_num*feature_num, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, target_dim)
)
def forward(self, x):
'''Forward pass'''
# print('x.size:',x.size())
m = nn.Flatten()
l = nn.Linear(sequence_num*feature_num, hidden_dim)
# print(m(x).size())
# print(l(m(x)).size())
return self.layers(x)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1)
# Prepare for training
model = MLP(sequence_num, feature_num, hidden_dim,
target_dim).to(device)
criterion = nn.BCEWithLogitsLoss() #nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
train_loss_list = [] # 学習損失
val_loss_list = [] # 評価損失
best_model = model.state_dict()
thre = 0
early_stopping = 5
patience = 0
# 学習(エポック)の実行
for i in range(n_epoch):
# エポックの進行状況を表示
print('---------------------------------------------')
print("Epoch: {}/{}".format(i + 1, n_epoch))
# 損失の初期化
train_loss = 0 # 学習損失
val_loss = 0 # 評価損失
# ---------学習パート--------- #
# ニューラルネットワークを学習モードに設定
model.train()
# ミニバッチごとにデータをロードし学習
for data, label in train_batch:
# GPUにTensorを転送
data = data.to(device)
label = label.to(device)
# 勾配を初期化
optimizer.zero_grad()
# データを入力して予測値を計算(順伝播)
y_pred = model(data)
y_pred = torch.sigmoid(y_pred)
# 損失(誤差)を計算
loss = criterion(y_pred, label.view(-1, target_dim))
# 勾配の計算(逆伝搬)
loss.backward()
# パラメータ(重み)の更新
optimizer.step()
# ミニバッチごとの損失を蓄積
train_loss += loss.item()
# ミニバッチの平均の損失を計算
batch_train_loss = train_loss / len(train_batch)
# ---------学習パートはここまで--------- #
# ---------評価パート--------- #
# ニューラルネットワークを評価モードに設定
model.eval()
patience += 1
thre_tmp = 0
# 評価時の計算で自動微分機能をオフにする
with torch.no_grad():
for data, label in val_batch:
# GPUにTensorを転送
data = data.to(device)
label = label.to(device)
# データを入力して予測値を計算(順伝播)
y_pred = model(data)
y_pred = torch.sigmoid(y_pred)
thre_tmp += y_pred.mean()
# 損失(誤差)を計算
loss = criterion(y_pred, label.view(-1, target_dim))
# ミニバッチごとの損失を蓄積
val_loss += loss.item()
# ミニバッチの平均の損失を計算
batch_val_loss = val_loss / len(val_batch)
thre_tmp /= float(len(val_batch))
# ---------評価パートはここまで--------- #
# エポックごとに損失を表示
print("Train_Loss: {:.4E} Val_Loss: {:.4E}".format(batch_train_loss,
batch_val_loss))
# 損失をリスト化して保存
train_loss_list.append(batch_train_loss)
val_loss_list.append(batch_val_loss)
if batch_val_loss <= min(val_loss_list):
best_model = model.state_dict()
thre = thre_tmp
patience = 0
elif patience > early_stopping:
break
model.load_state_dict(best_model)
plt.plot(train_loss_list,label='train_loss')
plt.plot(val_loss_list,label='val_loss')
plt.legend()