Tensorflow CNN Shape错误

时间:2017-05-25 08:23:49

标签: tensorflow neural-network tf-slim

我试图在某些数据上使用CNN,但我在丢失函数中出现错误,因为我的模型的输出是形状[1000,1000,4000]应该是[ 1000,4000]。在这种情况下,前1000是批量大小,而4000是我所拥有的类的数量,因为这是一个分类问题。

我想我可能需要在完全连接的图层后再次使用tf.reshape()函数来获得正确的输出,但我不太确定如何做到这一点。我已经尝试过tf.reshape(输出[-1,4000]),但仍然保持其他1000内部。

这是我的代码:

    cnn_input = tf.reshape(input, [-1, 1000, 1])
    net = slim.conv2d(cnn_input, 128, [3])
    net = slim.pool(net, [2], "MAX")
    output = slim.fully_connected(net, num_classes, activation_fn=tf.nn.softmax)
    return output

基本上,我的输出需要是等级2的形状,但由于某种原因,它的结果是3维。我需要输出形状为[1000,4000],批量大小为x num_classes。

非常感谢任何帮助。提前谢谢!

顺便说一下,我正在使用tf-slim库。

编辑:在完全连接的图层之前,tf.flatten会为此工作吗?

1 个答案:

答案 0 :(得分:0)

我遇到了同样的错误。 documentation (line 1609)(从here链接)表示'fully_connected'操作应该使输出变平,但事实并非如此。我刚刚按照你的建议在最后一对完全连接的操作之前使用了slim.flatten,但我没有具体的证据证明它有效。

有6个月没有评论,我认为有些东西比其他人更好,但如果有其他人有更多的见解,我们将不胜感激。