如何将numpy数组提供给TensorFlow的预取和缓冲管道

时间:2016-07-13 09:14:49

标签: numpy tensorflow

我试图关注Cifar10 example。但是,我想用Numpy数组替换文件读取。这样做有几个好处:

  • 更简单的代码(我想删除二进制文件解析)
  • 更简单的图形和可视化 - >更容易向其他观众解释
  • 小的性能提升(由于I / O和解析)?

这是一个简单的方法吗?

1 个答案:

答案 0 :(得分:0)

您需要通过以下任一方式获得张量reshape_image

  • 给它起个名字
  • 找到其默认名称,例如Tensorboard
reshaped_image = tf.cast(read_input.uint8image, tf.float32, name="float_image")

然后你可以使用feed_dict来提供你的numpy数组,如:

reshaped_image = tf.get_default_graph().get_tensor_by_name("float_image")
sess.run(loss, feed_dict={reshaped_image: your_numpy})

标签也是如此。