我正在移植Tensorflow Cifar-10教程文件用于我自己的目的,并且由于Tensorflow的图形和会话架构而遇到了一个我无法轻易概念化的有趣问题。
问题在于我的输入数据集是高度不平衡的,因此我需要在输入管道中对其标签进行“过采样”(和扩充)条件。在普通的Python环境中,我可以设置一个形式为if label then duplicate
的简单控制流语句,但由于在正在运行的会话之外存在的控制流操作而且label
,我无法在Tensorflow中编写相同的语法。 1}}在这种情况下不会返回值。
我的问题是,在Tensorflow队列中过度采样张量的最简单方法是什么?
我知道我可以在输入操作之前简单地复制感兴趣的数据,但这显然会消除运行期间过采样所带来的任何存储节省。
我想要做的是评估Tensor的标签(在Cifar-10情况下,通过检查1D image.label属性)然后用固定因子复制该Tensor(例如,如果标签是“dog”,则复制4x) )并将所有Tensors发送到批处理操作。我最初的方法是在Reader操作之后和批处理操作之前尝试复制步骤,但这也是在正在运行的会话之外。我正在考虑使用TF的while
控制流语句,但我不确定这个功能是否能够执行除修改输入Tensor之外的任何操作。你觉得怎么样?
更新#1
基本上我试图创建一个py_func(),它接受展平的图像字节和标签字节,并根据标签的值垂直堆叠相同的图像字节N次,然后将其作为返回(N x image_bytes张量(py_func()自动将输入张量转换为numpy和back)。我试图从变量高度张量创建一个input_queue,其形状报告为(?,image_bytes),然后实例化一个读取器以撕掉image_byte大小的记录。好吧,你似乎无法建立未知数据大小的队列,所以这种方法对我来说不起作用,事后才有意义,但我仍然无法概念化识别队列中记录的方法,并重复该记录特定次数。
更新#2
48小时后,我终于想出了一个解决方法,感谢this SO thread我能够挖掘出来。该线程中概述的解决方案仅假设2类数据,因此tf.cond()
函数足以在pred
为True时对一个类进行过采样,如果pred
为False则对其他类进行过采样。为了获得n路有条件,我尝试建立导致tf.case()
的{{1}}函数。事实证明ValueError: Cannot infer Tensor's rank
函数不保留tf.case()
属性,并且图形构造失败,因为输入管道末端的任何批处理操作必须采用形状参数,或采用定义形状的张量,按照这句话在documentation:
N.B。:您必须确保(i)传递形状参数,或(ii)张量中的所有张量必须具有完全定义的形状。如果这些条件都不成立,则会引发ValueError。
进一步挖掘表明这是一个known issue shape
,截至2016年12月尚未解决。这只是Tensorflow众多控制流程中的一个。无论如何,我对n路过采样问题的精简解决方案是:
tf.case()
答案 0 :(得分:1)
我的问题的解决方案是一种解决方法,并在“更新2'在最初的问题中。