tensorflow をインストールした直後にインストールが出来ているか簡単に確認するプログラム
元のプログラムは、こちら
http://qiita.com/mochizukikotaro/items/7624b81af498317a0865
このプログラムを、アップデートし、python3 にも対応しました。
tf_sample_square_error.py
#! /usr/bin/python
#
# tf_sample_square_error.py
#
# Aug/07/2023
# ------------------------------------------------------------------
import sys
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
sys.stderr.write("*** start ***\n")
input_x = [[1.],[5.]]
input_y = [[4.],[2.]]
x = tf.placeholder("float", [None, 1])
y_ = tf.placeholder("float", [None, 1])
a = tf.Variable([1.], name="slope")
b = tf.Variable([0.], name="y-intercept")
y = tf.multiply(a, x) + b
init = tf.global_variables_initializer()
# 誤差関数
loss = tf.reduce_sum(tf.square(y_ - y))
# トレーニング方法は、勾配降下法を選択
train_step = tf.train.GradientDescentOptimizer(0.03).minimize(loss)
with tf.Session() as sess:
sess.run(init)
print('初期状態')
print('誤差' + str(sess.run(loss, feed_dict={x: input_x, y_: input_y})))
print("slope: %f, y-intercept: %f" % (sess.run(a), sess.run(b)))
for step in range(100):
sess.run(train_step, feed_dict={x: input_x, y_: input_y})
if (step+1) % 20 == 0:
print('\nStep: %s' % (step+1))
print('誤差' + str(sess.run(loss, feed_dict={x: input_x, y_: input_y})))
print("slope: %f, y-intercept: %f" % (sess.run(a), sess.run(b)))
#
sys.stderr.write("*** end ***\n")
# ------------------------------------------------------------------
警告を出さないように実行するには、
export TF_CPP_MIN_LOG_LEVEL=2
./tf_sample_square_error.py
実行結果
tensorflow がインストールしてない時の実行結果
$ ./tf_sample_square_error.py
Traceback (most recent call last):
File "./tf_sample_square_error.py", line 9, in <module>
import tensorflow as tf
ModuleNotFoundError: No module named 'tensorflow'
Arch Linux での tensorflow のインストール方法
sudo pacman -S python-tensorflow
確認した環境
$ uname -a
Linux shimizu 6.4.8-arch1-1 #1 SMP PREEMPT_DYNAMIC Thu, 03 Aug 2023 16:02:01 +0000 x86_64 GNU/Linux
$ python --version
Python 3.11.3
$ python
Python 3.11.3 (main, Jun 5 2023, 09:32:32) [GCC 13.1.1 20230429] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow
>>> tensorflow.__version__
'2.13.0'