如何以Tensorflow队列中的张量属性(“过采样”)为条件复制输入张量?

时间:2016-12-01 12:34:46

标签: tensorflow

我正在移植Tensorflow Cifar-10教程文件用于我自己的目的,并且由于Tensorflow的图形和会话架构而遇到了一个我无法轻易概念化的有趣问题。

问题在于我的输入数据集是高度不平衡的,因此我需要在输入管道中对其标签进行“过采样”(和扩充)条件。在普通的Python环境中,我可以设置一个形式为if label then duplicate的简单控制流语句,但由于在正在运行的会话之外存在的控制流操作而且label,我无法在Tensorflow中编写相同的语法。 1}}在这种情况下不会返回值。

我的问题是,在Tensorflow队列中过度采样张量的最简单方法是什么?

我知道我可以在输入操作之前简单地复制感兴趣的数据,但这显然会消除运行期间过采样所带来的任何存储节省。

我想要做的是评估Tensor的标签(在Cifar-10情况下,通过检查1D image.label属性)然后用固定因子复制该Tensor(例如,如果标签是“dog”,则复制4x) )并将所有Tensors发送到批处理操作。我最初的方法是在Reader操作之后和批处理操作之前尝试复制步骤,但这也是在正在运行的会话之外。我正在考虑使用TF的while控制流语句,但我不确定这个功能是否能够执行除修改输入Tensor之外的任何操作。你觉得怎么样?

更新#1

基本上我试图创建一个py_func(),它接受展平的图像字节和标签字节,并根据标签的值垂直堆叠相同的图像字节N次,然后将其作为返回(N x image_bytes张量(py_func()自动将输入张量转换为numpy和back)。我试图从变量高度张量创建一个input_queue,其形状报告为(?,image_bytes),然后实例化一个读取器以撕掉image_byte大小的记录。好吧,你似乎无法建立未知数据大小的队列,所以这种方法对我来说不起作用,事后才有意义,但我仍然无法概念化识别队列中记录的方法,并重复该记录特定次数。

更新#2

48小时后,我终于想出了一个解决方法,感谢this SO thread我能够挖掘出来。该线程中概述的解决方案仅假设2类数据,因此tf.cond()函数足以在pred为True时对一个类进行过采样,如果pred为False则对其他类进行过采样。为了获得n路有条件,我尝试建立导致tf.case()的{​​{1}}函数。事实证明ValueError: Cannot infer Tensor's rank函数不保留tf.case()属性,并且图形构造失败,因为输入管道末端的任何批处理操作必须采用形状参数,或采用定义形状的张量,按照这句话在documentation

  

N.B。:您必须确保(i)传递形状参数,或(ii)张量中的所有张量必须具有完全定义的形状。如果这些条件都不成立,则会引发ValueError。

进一步挖掘表明这是一个known issue shape,截至2016年12月尚未解决。这只是Tensorflow众多控制流程中的一个。无论如何,我对n路过采样问题的精简解决方案是:

tf.case()

1 个答案:

答案 0 :(得分:1)

我的问题的解决方案是一种解决方法,并在“更新2'在最初的问题中。