TensorFlow embedding_rnn_decoder' Tensor'对象不可迭代

时间:2017-08-15 17:53:35

标签: python tensorflow google-cloud-ml-engine

我正在尝试为我的ML Engine包构建一个自定义估算器,我似乎无法以正确的格式正确构建我的解码器输入序列。考虑以下内容,其中label1,label2应该是一系列标签。

label1, label2 = tf.decode_csv(rows, record_defaults=[[""], [""]])
labels = tf.stack([label1, label2], axis=1)
label_index = tf.contrib.lookup.index_table_from_file(
    vocabulary_file = label_file)
label_idx = label_index.lookup(labels)
features = dict(zip(['decoder_input'], [label_idx]))

这些"功能"然后作为解码器输入传递,如下所示。当我使用decoder_input作为我的自定义估算工具的输入时,我遇到了一个错误' TypeError:' Tensor'对象不可迭代。'这里:

outputs, state = tf.contrib.legacy_seq2seq.embedding_rnn_decoder(
    decoder_inputs = features['decoder_input'],
    initial_state = curr_layer,
    cell = tf.contrib.rnn.GRUCell(hidden_units),
    num_symbols = n_labels,
    embedding_size = embedding_dims, # should not be hard-coded
    feed_previous = False)

完整的堆栈跟踪(下面)表明导致问题的代码部分是'for i in decoder_inputs' from line 296所以我觉得很明显问题在于如何在input_fn()中构造我的decoder_input。但是,我似乎无法弄清楚如何使Tensor对象成为可迭代的序列列表。

堆栈跟踪:

File "/Users/user/anaconda/envs/tensorflow-

  cloud/lib/python2.7/sitepackages/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py", line 296, in embedding_rnn_decoder
    for i in decoder_inputs)
  File "/Users/user/anaconda/envs/tensorflow-cloud/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 541, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

有人可以帮助发现我应该如何正确格式化我的标签,以便它们可以迭代吗?文档说,decoder_inputs应该是" 1D批量大小的int32 Tensors(解码器输入)的列表。"所以我认为通过staIs生成标签序列比tf.stack()更合适吗?

1 个答案:

答案 0 :(得分:1)

label_idx值不是列表,因此您遇到此问题:

下面的例子应该更好地澄清:

label_idx = 1

features = dict(zip(['decoder_input'], [label_idx]))

features['decoder_input']

# 1 output

好像我将label_idx更改为列表:

label_idx = [1]

features = dict(zip(['decoder_input'], [label_idx]))

features['decoder_input']

# [1] output

您还可以简化创建字典的方式:

features = {'decoder_input': [label_idx]} # if label_idx is a value
features = {'decoder_input': label_idx} # if label_idx is a list