More than 5 years have passed since last update.

nn.MSELoss()(a, b) is much faster than torch.mean(torch.pow(torch.sub(a, b), 2))


What I want to do

I want to compare the speed of computing the mean squared error (MSE) written out by hand against calling the PyTorch API.

What is MSE?

It is the mean of the squared errors.
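Without reduction (reduction='none'), the loss is the vector of element-wise squared errors:
$$\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = \left( x_n - y_n \right)^2$$
With the default reduction='mean', these terms are averaged: $\ell(x, y) = \operatorname{mean}(L) = \frac{1}{N}\sum_{n=1}^{N} l_n$.
The formulas above follow the nn.MSELoss() documentation.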
For example, the MSE of $[1, 2]$ and $[2, 3]$ is $((2-1)^{2} + (3-2)^{2}) \div 2 = 1$.
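To make the arithmetic concrete, here is a minimal pure-Python sketch of the formula. The function name mean_squared_error is my own choice for illustration, not a PyTorch API; it squares each element-wise difference and averages over the N elements.

```python
def mean_squared_error(x, y):
    # hypothetical helper: l_n = (x_n - y_n)^2 for each element, then the mean over N
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    squared_errors = [(xn - yn) ** 2 for xn, yn in zip(x, y)]
    return sum(squared_errors) / len(squared_errors)


print(mean_squared_error([1, 2], [2, 3]))  # 1.0
```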

Comparing the code

Written by hand
torch.mean(torch.pow(torch.sub(a, b), 2))
Using the API
nn.MSELoss()(a, b)
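Both expressions should produce the same value. Note that nn.MSELoss is a class: it is instantiated first and the resulting module is then called on (input, target), as mse.py does with mse = nn.MSELoss(). torch.nn.functional.mse_loss is the functional form of the same loss. The following is a minimal check, assuming a reasonably recent PyTorch, using tensors that match the [1, 2] / [2, 3] example above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

a = torch.tensor([1.0, 2.0])
b = torch.tensor([2.0, 3.0])

# Hand-written: element-wise difference, square, then mean over all elements
manual = torch.mean(torch.pow(torch.sub(a, b), 2))

# Module API: construct the loss object first, then call it (reduction='mean' by default)
criterion = nn.MSELoss()
api = criterion(a, b)

# Functional API: the same computation without creating a module
functional = F.mse_loss(a, b)

print(manual.item(), api.item(), functional.item())  # 1.0 1.0 1.0
```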

Results

mse.py
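from time import perf_counter

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn


# Two small random tensors of shape (1, 2); both variants use the same inputs
a = torch.randn([1, 2])
b = torch.randn([1, 2])
mse = nn.MSELoss()  # reduction='mean' by default


def measure_with_api(cnt):
    """Call nn.MSELoss cnt times; return the cumulative elapsed time after each call."""
    times = []
    since = perf_counter()
    for _ in range(cnt):
        mse(a, b)
        times.append(perf_counter() - since)
    return times


def measure_without_api(cnt):
    """Compute the MSE by hand cnt times; return the cumulative elapsed time after each call."""
    times = []
    since = perf_counter()
    for _ in range(cnt):
        torch.mean(torch.pow(torch.sub(a, b), 2))
        times.append(perf_counter() - since)
    return times


if __name__ == '__main__':
    sns.set()
    times_without = measure_without_api(1000)
    times_with = measure_with_api(1000)
    plt.plot(times_without, label="without")
    plt.plot(times_with, label="with")
    plt.title("measure")
    plt.xlabel("iteration")
    plt.ylabel("cumulative elapsed time [s]")
    plt.legend()
    plt.show()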

Running this should display a result like the following.
(Figure: measure.png, the plot titled "measure" with the curves "without" and "with")
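The script above plots the cumulative elapsed time of each loop. As a cross-check, the following sketch (my own, not the article's method) times each variant with Python's standard timeit module, makes one untimed warm-up call first, and reports the best of several repeats as time per call. The names manual, api, bench, NUMBER, and REPEAT are hypothetical. Per-call times depend on hardware, PyTorch version, and tensor size, so no output is shown.

```python
import timeit

import torch
import torch.nn as nn

NUMBER = 1000  # calls per timing run, same count as mse.py
REPEAT = 5     # independent timing runs; the fastest is reported

a = torch.randn([1, 2])
b = torch.randn([1, 2])
mse = nn.MSELoss()


def manual():
    # hypothetical helper: the hand-written MSE
    return torch.mean(torch.pow(torch.sub(a, b), 2))


def api():
    # hypothetical helper: the nn.MSELoss module
    return mse(a, b)


def bench(fn, number=NUMBER, repeat=REPEAT):
    # hypothetical helper: one untimed warm-up call, then the best-of-`repeat`
    # total seconds for `number` calls, divided to get seconds per call
    fn()
    return min(timeit.repeat(fn, number=number, repeat=repeat)) / number


if __name__ == "__main__":
    for name, fn in [("without", manual), ("with", api)]:
        print(f"{name}: {bench(fn) * 1e6:.2f} us per call")
```

The warm-up call is meant to keep one-time first-call costs out of either measurement, and taking the minimum over repeats reduces noise from other processes.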
