Tensorflow:动态地进行字母预测

时间:2016-01-10 21:03:21

标签: python machine-learning tensorflow

我正在尝试使用Tensorflow中的LSTM模块在时间(t-1)使用单热字母预测作为时间(t)下一个状态的输入。我正在做一些事情:

one_hot_dictionary = {0:np.array([1.,0.,0.]),1:np.array([0.,1.,0.]),\
                         2:np.array([0.,0.,1.])}
state = init_state
for time in xrange(sequence_length):
    #run the cell
    output, state = rnn_cell.cell(input,state)

    #transform the output so they are of the one-hot letter dimension
    transformed_val = tf.nn.xw_plus_b(output, W_o, b_o)

    #take the softmax to normalize
    softmax_val = tf.nn.softmax(transformed_val)

    #then get the argmax of these to know what the predicted letter is
    argmax_val = tf.argmax(softmax_val,1)

    #finally, turn these back into one-hots with a number to numpy
    #   array dictionary
    input = [one_hot_dictionary[argmax_val[i]] for i in xrange(batch_size)]

然而,我收到错误:

input = [one_hot_dictionary[argmax_val[i]] for i in xrange(batch_size)]
KeyError: <tensorflow.python.framework.ops.Tensor object at 0x7f772991ce50>

有没有办法可以使用我的字典动态创建这些单热字母,从argmax值到单热字母编码?

1 个答案:

答案 0 :(得分:1)

有几种方法可以实现这一目标。

最直接的代码修改方法是使用tf.gather()操作从单位矩阵中选择行,如下所示:

# Build the matrix [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]].
identity_matrix = tf.diag(tf.ones([3]))

for ...:

  # Compute a vector of predicted letter indices.
  argmax_val = ...

  # For each element of `argmax_val`, select the corresponding row
  # from `identity_matrix` and concatenate them into matrix.
  input = tf.gather(identity_matrix, argmax_val)

对于您显示只有3个不同字母的情况,性能可能并不重要。但是,如果字母数量(以及identity_matrix的大小)与批量大小相比要大得多 - 您可以通过构建tf.SparseTensor并使用{{3}来提高内存效率}建立input