0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Anacondaのプログラムで、このエラーの質問です。

Posted at

詳解 ディープラーニング片手に4.3.1.3のプログラムを以下のようなプログラムで試しています。

import numpy as np
import tensorflow as tf
import chainer as Variable

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn.datasets import fetch_mldata

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.optimizers import SGD

mnist=fetch_mldata('MNIST original', data_home=".")

np.random.seed(0)

n=len(mnist.data)
N=10000 #MNISTの一部データで実験
indices=np.random.permutation(range(n))[:N] #ランダムにN枚を選択
X=mnist.data[indices]
y=mnist.target[indices]
Y=np.eye(10)[y.astype(int)] #1-of-k表現に変換

X_train, X_test, Y_train, Y_test=train_test_split(X,Y,train_size=0.8)

n_in=len(X[0])#784
n_hidden=200
n_out=len(Y[0])#10

def prelu(X, alpha):
return tf.maximum(tf.zeros(tf.shape(X)), X)
+alpha*tf.minimum(tf.zeros(tf.shape(X), X))

#入力層^隠れ層
w0=tf.Variable(tf.truncated_normal([n_in, n_hidden], stddev=0.01))
b0=tf.Variable(tf.zeros([n_hidden]))
alpha0=tf.Variable(tf.zeros([n_hidden]))
h0=prelu(tf.matmul(X, w0)+b0, alpha0)

#隠れ層ー隠れ層

w1=tf.Variable(tf.truncated_normal([n_hidden, n_hidden], stddev=0.01))
b1=tf.Variable(tf.zeros([n_hidden]))
alpha1=tf.Variable(tf.zeros([n_hidden]))
h1=prelu(tf.matmul(h0, w1)+b1, alpha1)

w2=tf.Variable(tf.truncated_normal([n_hidden, n_hidden], stddev=0.01))
b2=tf.Variable(tf.zeros([n_hidden]))
alpha2=tf.Variable(tf.zeros([n_hidden]))
h2=prelu(tf.matmul(h1, w2)+b2, alpha2)

w3=tf.Variable(tf.truncated_normal([n_hidden, n_hidden], stddev=0.01))
b3=tf.Variable(tf.zeros([n_hidden]))
alpha3=tf.Variable(tf.zeros([n_hidden]))
h3=prelu(tf.matmul(h2, w3)+b3, alpha3)

#隠れ層ー出力層

w4=tf.Variable(tf.truncated_normal([n_hidden, n_out], stddev=0.01))
b4=tf.Variable(tf.zeros([n_out]))
y=prelu(tf.matmul(h3, w4)+b4)

from keras.layers.advanced_activations import PReLU

model=Sequential()
model.add(Dense(n_hidden, imput_dim=n_in))
model.add(PReLU())

model.add(Dense(n_hidden))
model.add(PReLU())

model.add(Dense(n_hidden))
model.add(PReLU())

model.add(Dense(n_hidden))
model.add(PReLU())

model.add(Dense(n_out))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
optimizer=SGD(lr=0.01),
metrics=['accuracy'])

epochs=1000
batch_size=100

model.fit(X_train, Y_train,epochs=epochs,batch_size=batch_size)

loss_and_metrics=model.evaluate(X_test,Y_test)
print(loss_and_metrics)

この時エラーが以下のように表示されます。


TypeError Traceback (most recent call last)
in ()
37 b0=tf.Variable(tf.zeros([n_hidden]))
38 alpha0=tf.Variable(tf.zeros([n_hidden]))
---> 39 h0=prelu(tf.matmul(X, w0)+b0, alpha0)
40
41 #隠れ層ー隠れ層

~\Anaconda3\envs\TK1\lib\site-packages\tensorflow\python\ops\math_ops.py in matmul(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, name)
1814 else:
1815 return gen_math_ops._mat_mul(
-> 1816 a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
1817
1818

~\Anaconda3\envs\TK1\lib\site-packages\tensorflow\python\ops\gen_math_ops.py in _mat_mul(a, b, transpose_a, transpose_b, name)
1215 """
1216 result = _op_def_lib.apply_op("MatMul", a=a, b=b, transpose_a=transpose_a,
-> 1217 transpose_b=transpose_b, name=name)
1218 return result
1219

~\Anaconda3\envs\TK1\lib\site-packages\tensorflow\python\framework\op_def_library.py in apply_op(self, op_type_name, name, **keywords)
587 _SatisfiesTypeConstraint(base_type,
588 _Attr(op_def, input_arg.type_attr),
--> 589 param_name=input_name)
590 attrs[input_arg.type_attr] = attr_value
591 inferred_from[input_arg.type_attr] = input_name

~\Anaconda3\envs\TK1\lib\site-packages\tensorflow\python\framework\op_def_library.py in _SatisfiesTypeConstraint(dtype, attr_def, param_name)
58 "allowed values: %s" %
59 (param_name, dtypes.as_dtype(dtype).name,
---> 60 ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
61
62

TypeError: Value passed to parameter 'a' has DataType uint8 not in list of allowed values: float16, float32, float64, int32, complex64, complex128

これらを見てそれぞれどのように改善するべきなのか分からず途方に暮れています・・・

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?