Tensorflow:如何在BasicLSTMCell.zero_state()的参数上设置`None`?

时间:2017-11-02 11:26:55

标签: tensorflow lstm

我正在创建一个图像字幕模型,这是代码的一部分:

.
.
.
self.img_features = tf.placeholder(tf.float32, [None, self.n_inputs])
self.caption = tf.placeholder(tf.int32, [None, self.n_steps])
self.mask = tf.placeholder(tf.float32, [None, self.n_steps])

# getting an initial LSTM embedding from our image_imbedding
self.embedded_img = tf.matmul(self.img_features, self.w_img_proj) + self.b_img_proj

# setting initial state of our LSTM
state = self.lstm.zero_state(self.batch_size, dtype=tf.float32)
.
.

如上所示,zero_state的arg设置为self.batch_size,这不灵活。我想将它设置为None,以便我可以传递任何长度的数据。

我该怎么做?

我做过像self.lstm.zero_state(self.img_features.get_shape()[0], dtype=tf.float32)self.lstm.zero_state(None, dtype=tf.float32)这样的事情,但它确实不起作用......

1 个答案:

答案 0 :(得分:0)

如果你不知道它的值,你可以扔一个占位符,它可以是int,float或单位Tensor 代表批量大小。

以下是获得张量形状的两种方法:

tensor.get_shape()[0].value
tf.shape(tensor)[0]

第一个导致值,但第二个导致单位Tensor 。您可以从the documentation验证:

  

返回:

     

类型为out_type的张量。