我是 TensorFlow 的新手，我的代码如下：
import tensorflow as tf
# Directory that receives both the downloaded MNIST files and the
# TensorBoard event files.
logdir = "/tmp/mnist_tutorial5/"

# Load MNIST with one-hot encoded labels; the data lives under logdir + "data".
mnist = tf.contrib.learn.datasets.mnist.read_data_sets(
    train_dir=logdir + "data",
    one_hot=True,
)

# Start from an empty default graph before opening the session that the
# rest of the script shares.
tf.reset_default_graph()
sess = tf.Session()

# Event-file writer used for summaries and the graph definition.
writer = tf.summary.FileWriter(logdir)
def model(input):
    """Apply a single linear layer (784 -> 10) to ``input``.

    Creates the weight variable "W" and bias variable "B", records
    histogram summaries for the weights, biases and activations, and
    returns the un-normalised activations ``input @ W + B``.
    """
    initial_weights = tf.truncated_normal([784, 10], stddev=0.1)
    weights = tf.Variable(initial_weights, name="W")

    initial_bias = tf.constant(0.1, shape=[10])
    bias = tf.Variable(initial_bias, name="B")

    activations = tf.matmul(input, weights) + bias

    # Histogram summaries are picked up later by tf.summary.merge_all().
    for tag, tensor in (
        ("weights", weights),
        ("biases", bias),
        ("activations", activations),
    ):
        tf.summary.histogram(tag, tensor)

    return activations
def train():
    """Build the training graph, run 2000 gradient-descent steps, log summaries.

    Uses the module-level ``mnist`` dataset, ``sess`` session and ``writer``
    summary writer, and calls ``model`` for the forward pass. Scalar
    summaries "loss" and "accuracy" (plus the histograms registered by
    ``model``) are written every 500 iterations.
    """
    # NOTE(review): the pasted source had lost its indentation; the nesting
    # below (all summary work inside the ``if``) is a reconstruction.
    x = tf.placeholder(tf.float32, shape=[None, 784], name="input_img")
    y = tf.placeholder(tf.float32, shape=[None, 10], name="labels")
    my_label = model(x)
    print("linear_regression is completed")

    # Sum the squared error per example (axis=1) and only then average over
    # the batch. The original reduced over every axis at once, so the outer
    # reduce_mean was a no-op and the loss (and its gradient) grew with the
    # batch size; that can make plain gradient descent diverge and surface
    # as "Nan in summary histogram for: weights".
    per_example_error = tf.reduce_sum(tf.square(my_label - y), axis=1)
    mean_error = tf.reduce_mean(per_example_error)
    tf.summary.scalar("loss", mean_error)

    train_step = tf.train.GradientDescentOptimizer(0.0003).minimize(mean_error)

    correct_prediction = tf.equal(tf.argmax(my_label, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar("accuracy", accuracy)

    sess.run(tf.global_variables_initializer())
    # Merge only after every summary op above has been registered.
    summ = tf.summary.merge_all()

    for i in range(2000):
        batch = mnist.train.next_batch(100)
        feed = {x: batch[0], y: batch[1]}
        # Running the train op yields no value, so nothing is captured here.
        sess.run(train_step, feed_dict=feed)
        print("%s th iteration" % i)
        if i % 500 == 0:
            print("over 2")
            summarys = sess.run(summ, feed)
            print("over 3")
            writer.add_summary(summarys, i)
            print("one over")
# Run training, then always persist the graph and release resources.
try:
    train()
finally:
    # Written in ``finally`` so the graph is still available in TensorBoard
    # for debugging when training raises.
    writer.add_graph(sess.graph)
    # Closing flushes pending events to disk; without it buffered summaries
    # may never reach the event file.
    writer.close()
    sess.close()
这是我得到的错误:
InvalidArgumentError (see above for traceback): Nan in summary histogram for: weights
[[Node: weights = HistogramSummary[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](weights/tag, W/read)]]
答案 0 :(得分:0)