TensorFlow histogram error

Posted: 2018-01-16 12:22:30

Tags: python tensorflow histogram

I'm new to TensorFlow. My code is as follows:

import tensorflow as tf
logdir="/tmp/mnist_tutorial5/"
mnist = tf.contrib.learn.datasets.mnist.read_data_sets(train_dir=logdir+"data",one_hot = True)
tf.reset_default_graph()
sess = tf.Session()
writer = tf.summary.FileWriter(logdir)
def model(input):
    w = tf.Variable(tf.truncated_normal([784,10], stddev=0.1), name="W")
    b = tf.Variable(tf.constant(0.1, shape=[10]), name="B")
    act = tf.matmul(input,w) + b
    tf.summary.histogram("weights",w)
    tf.summary.histogram("biases",b)
    tf.summary.histogram("activations",act)
    return act

def train():
    x = tf.placeholder(tf.float32, shape=[None, 784], name="input_img")
    y = tf.placeholder(tf.float32, shape=[None, 10], name="labels")
    my_label = model(x)
    print("linear_regression is completed")
    mean_error = tf.reduce_mean(tf.reduce_sum(tf.square(my_label-y)))
    tf.summary.scalar("loss", mean_error)
    train_step=tf.train.GradientDescentOptimizer(0.0003).minimize(mean_error)
    correct_prediction = tf.equal(tf.argmax(my_label, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar("accuracy", accuracy)
    sess.run(tf.global_variables_initializer())
    summ = tf.summary.merge_all()
    for i in range(2000):
        batch = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch[0], y: batch[1]})  # train_step returns None
        print("%s th iteration"%i)
        if i%500==0:
            print("over 2")
            summarys = sess.run(summ, {x: batch[0], y: batch[1]})  # I'm getting the error here
            print("over 3")
            writer.add_summary(summarys, i)
    print("one over")
train()
writer.add_graph(sess.graph)

This is the error I get:

InvalidArgumentError (see above for traceback): Nan in summary histogram for: weights
     [[Node: weights = HistogramSummary[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](weights/tag, W/read)]]
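The error says the histogram summary for `weights` received a NaN. As a rough illustration of why a histogram cannot accept one, here is a hypothetical pure-Python analogue (the helper `histogram_counts` and its bucket edges are made up for this sketch; it is not TensorFlow's implementation): every comparison against a NaN is False, so a NaN falls into no bucket, and the helper refuses it with the same message instead of silently dropping it.

```python
import math


def histogram_counts(tag, values, edges):
    """Hypothetical analogue of a histogram summary: count values per bucket.

    Bucket i covers [edges[i], edges[i + 1]); values outside every bucket are ignored.
    """
    for v in values:
        if math.isnan(v):
            # NaN compares False with every edge, so it fits no bucket;
            # fail loudly, mirroring the message in the traceback above.
            raise ValueError("Nan in summary histogram for: " + tag)
    counts = [0] * (len(edges) - 1)
    for v in values:
        for i in range(len(edges) - 1):
            if edges[i] <= v < edges[i + 1]:
                counts[i] += 1
                break
    return counts


if __name__ == "__main__":
    print(histogram_counts("weights", [0.1, -0.2, 0.05], [-1.0, 0.0, 1.0]))  # [1, 2]
    try:
        histogram_counts("weights", [0.1, float("nan")], [-1.0, 0.0, 1.0])
    except ValueError as err:
        print(err)  # Nan in summary histogram for: weights
```

So the summary op is only the messenger: the real problem is that `W` itself has already become NaN during training.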

1 answer:

Answer 0 (score: 0)

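  • First, it's just a single-layer network with no hidden layer
  • You haven't applied any kind of activation function
  • No activation means your outputs are not squashed into a bounded range
  • Your gradients can explode because the weights blow up during the updates, which can lead to NaN
  • You're training for 2000 iterations, which means that after a number of them the weights can become NaN
  • Try an activation function such as sigmoid and add at least one hidden layer, and you should be fine
  • Also, reduce the number of iterations; for MNIST classification you don't need to train the model for 2000 iterations, it's a waste of time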
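The blow-up described in the bullets can be reproduced in a toy setting. This is a minimal pure-Python sketch with a single weight, not the question's MNIST network; the values x = 10, y = 1 and lr = 0.02 are hypothetical, picked so that lr * x**2 > 1. In that regime, plain gradient descent on the unsquashed linear output overshoots by a factor of 3 each step, until the float overflows to inf and the next update turns it into NaN. The sigmoid variant, run with the same step size, keeps the weight finite because each gradient is bounded: |2 * (pred - y) * pred * (1 - pred) * x| <= 2 * 1 * 0.25 * 10 = 5.

```python
import math


def sigmoid(z):
    """Numerically stable logistic sigmoid (avoids OverflowError in math.exp)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train(activation, steps=2000, lr=0.02, x=10.0, y=1.0):
    """Plain gradient descent on the squared error (pred - y)**2 for ONE weight w.

    activation="linear":  pred = w * x          (no squashing, like the question's model)
    activation="sigmoid": pred = sigmoid(w * x) (output squashed into (0, 1))
    lr, x and y are hypothetical values chosen so that lr * x**2 > 1.
    """
    w = 0.0
    for _ in range(steps):
        z = w * x
        if activation == "linear":
            pred, dpred_dz = z, 1.0
        else:
            pred = sigmoid(z)
            dpred_dz = pred * (1.0 - pred)  # never larger than 0.25
        grad = 2.0 * (pred - y) * dpred_dz * x  # d/dw of (pred - y)**2
        w -= lr * grad  # linear case: overflows to inf, then inf - inf gives nan
    return w


if __name__ == "__main__":
    print("linear :", train("linear"))   # nan
    print("sigmoid:", train("sigmoid"))  # a small finite weight
```

Running it as a script prints nan for the linear model and a small finite weight for the sigmoid model, which matches the answer's point that unsquashed outputs let the weights run away.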