无法获得张量值

时间:2017-08-25 16:15:12

标签: tensorflow mnist

在MNIST数据集上运行时,我想知道模型在每个训练批次中实际输出了什么。这是我的代码(尚未添加优化器和损失函数):

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data

# Network sizes and training-loop settings used by train() below.
INPUT_NODE  = 784 # number of input features per image (one per pixel)
OUTPUT_NODE = 10  # number of output classes (digits 0 to 9)
LAYER_NODE = 500  # width of the single hidden layer
BATCH_SIZE = 100  # examples requested per mnist.train.next_batch() call
TRAINING_STEPS = 10  # number of batches evaluated by the loop in train()

def inference(input_tensor, avg_class, weight1, biase1, weight2, biase2):
    """Build the forward pass of a network with one ReLU hidden layer.

    Args:
        input_tensor: Input that is matrix-multiplied with ``weight1``.
        avg_class: ``None`` to use the parameters exactly as passed in.
            Otherwise an object whose ``average(param)`` method supplies the
            value used in place of each parameter (presumably a
            ``tf.train.ExponentialMovingAverage`` -- confirm against callers).
        weight1: Weights of the hidden layer.
        biase1: Bias added before the ReLU of the hidden layer.
        weight2: Weights of the output layer.
        biase2: Bias added to the output layer.

    Returns:
        The output-layer tensor ``relu(input @ weight1 + biase1) @ weight2 +
        biase2``. No softmax or other activation is applied to it.
    """
    # Identity test against the None singleton (PEP 8); unlike ``== None`` it
    # can never be redirected through a custom ``__eq__`` on avg_class.
    if avg_class is None:
        hidden = tf.nn.relu(tf.matmul(input_tensor, weight1) + biase1)
        return tf.matmul(hidden, weight2) + biase2

    # Same computation, but every parameter is replaced by its averaged value.
    hidden = tf.nn.relu(
        tf.matmul(input_tensor, avg_class.average(weight1))
        + avg_class.average(biase1)
    )
    return tf.matmul(hidden, avg_class.average(weight2)) + avg_class.average(biase2)


def train(mnist):
    """Build the network and evaluate its output on a few training batches.

    No loss function or optimizer is defined here, so the variables are never
    updated: each loop step only runs the forward pass on one batch.

    Args:
        mnist: Dataset object that provides ``train.next_batch(n)`` as well as
            ``validation`` and ``test`` splits exposing ``images`` and
            ``labels``.
    """
    # Placeholders that receive a batch of images and their labels at run time.
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name = 'x-input')
    y = tf.placeholder(tf.float32, [None, OUTPUT_NODE],name = 'y-input')

    # Parameters of the hidden layer (weight1/biase1) and the output layer
    # (weight2/biase2).
    weight1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER_NODE], stddev = 0.1))
    biase1  = tf.Variable(tf.constant(0.1, shape = [LAYER_NODE]))
    weight2 = tf.Variable(tf.truncated_normal([LAYER_NODE, OUTPUT_NODE], stddev = 0.1))
    biase2  = tf.Variable(tf.constant(0.1, shape = [OUTPUT_NODE]))

    # Symbolic output of the network, built without parameter averaging.
    out = inference(x, None, weight1, biase1, weight2, biase2)

    with tf.Session() as sess:
        # Variables must be initialized before the graph can be evaluated.
        tf.global_variables_initializer().run()
        # Feed dicts for the validation and test splits; neither is used below.
        validate_feed = {x:mnist.validation.images, y:mnist.validation.labels}
        test_feed = {x:mnist.test.images, y:mnist.test.labels}

        for i in range(TRAINING_STEPS):

            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            # The values computed by this call are discarded because the
            # return value of sess.run() is not kept.
            sess.run(out, feed_dict= {x:xs, y:ys})
            # NOTE: `out` is the symbolic tensor, so this prints its
            # description (name, shape, dtype) rather than the batch output.
            # To see the values, print what sess.run() returns instead.
            print(out)

def main(arg = None):
    """Load MNIST with one-hot labels and pass the dataset to ``train``.

    ``arg`` is accepted so the function can be called with one positional
    argument, but it is not used.
    """
    data_dir = "/home/vincent/Tensorflow/MNIST/data/"
    train(input_data.read_data_sets(data_dir, one_hot = True))

# Run only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    # TensorFlow 1.x launcher; per its documentation it calls this module's
    # main() with the command-line arguments -- confirm for the TF version used.
    tf.app.run()

我尝试把 out 打印出来,得到的是:

  

Tensor("add_1:0", shape=(?, 10), dtype=float32)

如果我想知道 out 的值,该怎么办?我尝试了 print(out.eval()),但它引发了错误。

1 个答案:

答案 0 :(得分:2)

out 是一个张量对象。如果要获取它的值,请把下面的第一段代码替换为第二段代码:

sess.run(out, feed_dict= {x:xs, y:ys})
print(out)

res_out=sess.run(out, feed_dict= {x:xs, y:ys})
print(res_out)