Tensorflow Estimators API批处理错误?

时间:2018-10-18 12:23:22

标签: python tensorflow

我想在https://github.com/WeiTang114/MVCNN-TensorFlow上使用Estimator API。我已经转换了代码,因此可以将其输入tf-records数据集并适合estimator API。 但是,它仅适用于batch-size = view-count 如果我使用更高的批处理尺寸,则将无法使用,因为将针对一种补给特征计算净值,而将一个样本的计算结果与整个标签张量进行比较。我尝试使用tf.gather并在估算器的标签函数(带有for循环)中手动计算每个进纸标签的损耗。这无法解决,因为它会在100次迭代后使会话崩溃。

如果我仅将1张标签传递给20张图像,它将正常工作。但是,当我增加批处理大小时,它不会启动并在损失计算中将其归咎于形状不匹配(我正在使用tf.nn.sparse_softmax_cross_entropy_with_logits。并且恰好在此函数调用之后出现错误:Label和Logits的形状不匹配张量:

Label_Tensor.shape() = [Batch_size / 20]   # e.g. [200/20] = [10]
Logit_Tensor.shape() = [1, Number_Of_Classes]  # e.g. [1, 17]

现在Tensorflow希望我的第一维尺寸相同。但这意味着一次只训练一个样本。我已经在input_function中以及直接在损失计算之前跟踪了标签的Tensor的大小。一切似乎都正确。我无法跟踪logit的形状,仅在进行损耗计算时它会向我指出给定的错误,那么我该如何调试/解决此问题?

0 个答案:

没有答案