tf.split输出具有无用维度的张量

时间:2018-06-04 14:05:34

标签: python tensorflow

我试图构建一个递归神经网络,输入来自mnist数据库,因此它们是28x28像素的图像。

这是我的网络模型:

def _pre_activation(self, input_values, layer):
   return tf.add(tf.matmul(input_values, layer['weights']), layer['biases'])

def network_model(self, x):
    layer = {'weights':tf.Variable(tf.random_normal([self.rnn_size, self.nb_of_outputs])),
             'biases': tf.Variable(tf.random_normal([self.nb_of_outputs]))}

    x = tf.transpose(x, [1,0,2])
    x = tf.reshape(x, [-1, self.chunk_size])
    x = tf.split(x, self.nb_of_chunks, 0)

    lstm_cell = rnn.BasicLSTMCell(self.rnn_size)
    lstm_outputs, states = rnn.static_rnn(lstm_cell, x_splitted, dtype=tf.float32)

    output = self._pre_activation(lstm_outputs, layer)

    return output

我收到以下错误:

  

形状必须是等级2,但对于' MatMul'是等级3。 (op:' MatMul')   输入形状:[28,1,128],[128,10]。

我的问题是,从形状张量(28, 28)tf.split输出28个形状(1,28)数组的列表,但我想要28个形状(28, )数组。

我尝试将其迭代到一个新列表并逐个重新整形每个数组,但问题是此时没有会话正在运行,并且这些张量中没有值。

我没有在互联网上找到任何有用的帮助,可能是因为我不知道如何制定我的问题。

1 个答案:

答案 0 :(得分:0)

您的错误表示输出的等级为3,乘法应为等级2。即使时间步长为1,LSTM的输出也是第3级。因此,在lstm输出上使用tf.squeeze来解决问题。

output = self._pre_activation(tf.squeeze(lstm_outputs), layer)