如何将字符串传递给tf.contrib.rnn.MultiRNNCell?

时间:2017-05-02 01:56:38

标签: tensorflow

我正在使用TensorFlow对LSTM进行编码,该LSTM基于长度为10的前一个字符串(恰好与方向对应)来预测布尔目标,即URDDLRUDUD。

当我将tf.string张量传递到tf.nn.dynamic_rnn时:

multi_rnn_cell = tf.contrib.rnn.MultiRNNCell(
  lstm_cells, state_is_tuple=True)
output_data, _ = tf.nn.dynamic_rnn(
  multi_rnn_cell,
  tf.reshape(input_layer, (batch_size, 10, 1)),
  initial_state=lstm_layers.zero_state(
    batch_size=batch_size,
    dtype=tf.string))

我收到错误:

ValueError: dtype must be convertible to float. dtype: <dtype: 
'string'>, column_name: sequence

似乎tf.nn.dynamic_rnn的输入应该是嵌入值 - 虽然我不想传递整个长度为10个字的嵌入,因为我希望LSTM是字符基

设计输入并将输入传递给基于字符的LSTM的正确方法是什么?

1 个答案:

答案 0 :(得分:3)

您需要将字符编码为int值

U 0
R 1
L 2
D 3

您的输入应该是:

[[0,1,2,3,2,2,3,1],[0,1,2,1,2,1,3,1]...]

添加嵌入图层(使用tf.contrib.layers.embed_sequence嵌入输入数据)并将编码数据提供给您的lstm单元格。