在以下this tutorial中,我收到以下错误:
ValueError: prefix tensor must be either a scalar or vector, but saw tensor: Tensor("Placeholder_2:0", dtype=int32)
错误源自以下几行:
# Take the output from the final convolutional layer and send it to a recurrent layer
# The input must be reshaped into [batch x trace x units] for rnn processing, and then returned to
# [batch x units] when sent through the upper levels
self.batch_size = tf.placeholder(dtype=tf.int32)
self.convFlat = tf.reshape(slim.flatten(self.conv4), [self.batch_size, self.trainLength, h_size])
# !!!!This is the line where error city happens!!!!
self.state_in = rnn_cell.zero_state(self.batch_size, tf.float32)
网络初始化后:
mainQN = Qnetwork(h_size, cell, 'main')
仅在python控制台中运行代码时仍然存在此错误,因此错误是一致的。
如果有帮助,我会发布更多代码
答案 0 :(得分:2)
还有另一种解决方案可以解决这个问题。
更改
self.batch_size = tf.placeholder(dtype=tf.int32)
要
self.batch_size = tf.placeholder(dtype=tf.int32, [])
答案 1 :(得分:1)
我遇到了同样的问题,张量流的版本是1.2。+。
当我将其更改为1.1.0时,问题已解决。
我认为是因为rnn_cell.zero_state的API使得arg batch_size必须是标量或向量,而不是张量。
因此,如果您将batch_size更改为标量,例如128,问题也可以解决。