我有一个字符串Tensor(命名为句子),我想在其中获取其词的嵌入内容:
sentence = tf.map_fn(lambda x: tf.string_split([x], delimiter=' ').values, sentence, dtype=tf.string)
我使用上面的代码将字符串拆分应用于批处理中的所有句子。然后,在我的单词表中进行查找,以获取这些张量中每个单词的单词索引:
sentence = tf.map_fn(lambda x: tf.cast(word_table.lookup(x), tf.int32), sentence, dtype=tf.int32)
以 1 的批量大小运行时,我没有问题。但是,如果批处理大小大于1,我总是会收到以下错误,该错误指向上面的第一个代码段。
InvalidArgumentError(请参见上面的回溯):TensorArray statement_splitter / map / TensorArray_1_1:无法写入TensorArray索引10,因为值的形状为[4],与TensorArray的推断元素形状不兼容:[6](考虑设置infer_shape = False)。
我不明白Tensorflow试图说出这个错误!如果有人可以解释这个错误,那就太好了。谢谢!
答案 0 :(得分:1)
如果您的批处理大小大于1,则在此代码之后
sentence = tf.map_fn(lambda x: tf.string_split([x], delimiter=' ').values, sentence, dtype=tf.string)
tf.string_split()函数对不同的句子产生不同数量的拆分结果。每个维度的不兼容使得最终结果无法存储到张量中,因此会发生错误。这个清楚吗?