何をしたいか
平均二乗誤差(Mean Squared Error)を手で生で書いた場合と、pytorchのAPIを叩いた場合で、速さを比較したい。
MSEってなに
誤差を二乗したものの平均です。
$$\ell(x, y) = L = {l_1,\dots,l_N}^\top, \quad
l_n = \left( x_n - y_n \right)^2$$
↑はnn.MSELoss()より。
例えば$[1, 2]$と$[2, 3]$のMSEは、$((2-1)^{2} + (3-2)^{2}) \div 2 = 1$です。
コードの比較
生で書いた場合
torch.mean(torch.pow(torch.sub(a, b), 2))
API
nn.MSELoss(a, b)
結果
mse.py
from time import time
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
a = torch.randn([1, 2])
b = torch.randn([1, 2])
mse = nn.MSELoss()
def measure_with_api(cnt):
times = []
since = time()
for i in range(cnt):
mse(a, b)
times.append(time() - since)
return times
def measure_without_api(cnt):
times = []
since = time()
for i in range(cnt):
torch.mean(torch.pow(torch.sub(a, b), 2))
times.append(time() - since)
return times
if __name__ == '__main__':
sns.set()
times_without = measure_without_api(1000)
times_with = measure_with_api(1000)
plt.plot(times_without, label="without")
plt.plot(times_with, label="with")
plt.title("measure")
plt.legend()
plt.show()