如何在张量流中获得模型的输出和输入值?

时间:2017-05-24 14:17:39

标签: python tensorflow tensorboard

我正在研究GAN并决定使用HyperGAN实现我的算法。它是使用TensorFlow的DCGAN包装器。 HyperGAN使用TF的检查点方法保存输出。

后来,我尝试使用以下方式运行模型:

import tensorflow as tf
sess=tf.Session()    
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
sess.run(tf.global_variables_initializer())

然而,由于它是一个GAN,它需要一个输入Latent Vector并输出一个图像。这是使用

完成的
out_image = sess.run(last_node, feed_dict(input_node: value))

但是由于我加载了模型,我不知道最后一个节点的名称是什么以及输入节点占位符的名称是什么。如何获取首先用于创建图形的名称?我尝试使用TensorBoard进行可视化,但图表很大,因此卡住了。

1 个答案:

答案 0 :(得分:0)

您应该尝试在图表中打印张量列表:

with tf.Graph().as_default() as graph:
....

count = 0
for op in graph.get_operations():
    print op.values()
    count+=1
    if count == 50:
        assert False

为了查看图表的前50个节点,您将看到如下内容:

(<tf.Tensor 'import/Placeholder_only:0' shape=<unknown> dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_max:0' shape=() dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_min:0' shape=() dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_const:0' shape=(512,) dtype=quint8>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53:0' shape=(512,) dtype=float32>,)

我将计数放在那里,因为通常终端打印出很多张量,初始输入节点名称在终端中消失。

最后,只需注释掉计数使用的行:

#count = 0
for op in graph.get_operations():
    print op.values()
    #count+=1
    #if count == 50:
    #    assert False

打印出最后几个节点(即输出节点)。