What is the relationship between batch-size, sequence-length and hidden_size?

Date: 2017-12-11 11:46:38

Tags: dynamic tensorflow rnn

While reading the API documentation for dynamic_rnn, I have the following question:

Is there any constraint relating batch-size, sequence-length and (cell) hidden_size?

I was thinking either:

sequence-length <= (cell) hidden_size, or

batch-size * sequence-length <= (cell) hidden_size

Am I wrong? I have been reading many web pages but could not find an answer.

Thanks, everyone.

https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn

Example:

# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

# create a RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is a N-tuple where N is the number of LSTMCells containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                   inputs=data,
                                   dtype=tf.float32)

1 answer:

Answer 0 (score: 0)

As far as the API is concerned, there is no relationship. Fix any two of these parameters and the remaining one can still be any non-negative integer (or, in the case of sequence_length, any batch_size-length vector of non-negative integers).
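To see that the three sizes really are independent, here is a minimal NumPy sketch (not TensorFlow) of a vanilla RNN forward pass; all the sizes below are made up, and the sequence length deliberately exceeds the hidden size:

```python
import numpy as np

# Arbitrary, independent sizes: max_time > hidden_size is perfectly fine.
batch_size, max_time, input_size, hidden_size = 7, 50, 12, 3

x = np.zeros((batch_size, max_time, input_size))  # inputs
W = np.zeros((input_size, hidden_size))           # input-to-hidden weights
U = np.zeros((hidden_size, hidden_size))          # hidden-to-hidden weights
h = np.zeros((batch_size, hidden_size))           # initial state

outputs = []
for t in range(max_time):
    h = np.tanh(x[:, t, :] @ W + h @ U)           # one RNN step
    outputs.append(h)
outputs = np.stack(outputs, axis=1)

print(outputs.shape)  # (7, 50, 3) -> [batch_size, max_time, hidden_size]
print(h.shape)        # (7, 3)     -> [batch_size, hidden_size]
```

The shapes compose for any choice of the three values, which is why the API imposes no constraint among them.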

If hidden_size is large and there is very little training data, the resulting model may overfit badly, but it would still run. Note that training data is usually presented in mini-batches, so the total amount of training data is not the sum of the sequence lengths passed to dynamic_rnn.

There are also hardware constraints: for example, with dynamic_rnn the maximum sequence length affects memory usage.
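A rough back-of-envelope estimate (an assumption for illustration, not TensorFlow's exact memory accounting): the outputs tensor alone holds batch_size * max_time * hidden_size float32 values, so memory grows linearly in each of the three sizes. The helper name below is made up:

```python
# Hypothetical helper: bytes needed just for the [batch_size, max_time,
# hidden_size] outputs tensor, assuming float32 (4 bytes per value).
def output_tensor_bytes(batch_size, max_time, hidden_size, bytes_per_float=4):
    return batch_size * max_time * hidden_size * bytes_per_float

print(output_tensor_bytes(32, 10_000, 512) / 1e9)  # 0.65536 (GB)
```

Activations kept for backpropagation scale the same way, which is why very long sequences are often truncated or split into chunks.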