Tensorflow:tf.gather不适用于大数组

时间:2017-10-10 10:00:55

标签: arrays numpy tensorflow

我的numpy数组存在问题,其大小(29000,200,1024)(7Go)。它是我数据集图像的特征。

加载后,我的函数接收索引以将当前批处理构建为张量。 不幸的是,使用:

tf.gather(array, indices) 

冻结。虽然打印例如array [0]立即工作。 我试图用convert_to_tensor转换我的numpy数组,所以我可以直接使用array_tensor(indice)但是convert_to_tensor会导致内存限制错误。

有什么解决方法吗?

非常感谢

1 个答案:

答案 0 :(得分:2)

将numpy数组直接传递给tf op构造API将其转换为tf.constant op,其中包含op定义中的数据,因此您将整个内容内联到GraphDef中,受到2GB GraphDef限制。

要避免这种情况,请创建var=tf.Variable(my_placeholder)并通过运行var.initializer, feed_dict={my_placeholder: np_array}初始化此变量。这会将numpy数组数据直接放入变量存储中。