word_embeddings = tf.nn.embedding_lookup(params=_word_embeddings,ids=self.word_ids)
word_embeddings_modify = tf.scatter_nd_update(word_embeddings, self.error_word, sum_all)
Error:
Tensor conversion requested dtype float32_ref for Tensor with dtype float32
Judging from the error, word_embeddings as passed to scatter_nd_update actually has dtype tf.float32, but scatter_nd_update expects a tensor with dtype tf.float32_ref.
How can I change the dtype of word_embeddings from tf.float32 to tf.float32_ref before passing it to tf.scatter_nd_update?
Answer:
You can wrap the lookup result in tf.Variable(), which gives it the ref dtype directly. An example:
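import tensorflow as tf  # TensorFlow 1.x graph-mode API

_word_embeddings = tf.get_variable(name='embedding', shape=[30, 5])
word_ids = [3, 6, 23]
word_embeddings = tf.nn.embedding_lookup(params=_word_embeddings, ids=word_ids)

error_word = [[1]]           # index of the row to overwrite
sum_all = [[0, 0, 0, 0, 0]]  # new value for that row

# tf.Variable creates a new variable (dtype float32_ref) initialized with a
# copy of the looked-up rows, so scatter_nd_update accepts it.
# _word_embeddings itself is not modified.
word_embeddings_modify = tf.scatter_nd_update(tf.Variable(word_embeddings), error_word, sum_all)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(word_embeddings))
    print(sess.run(word_embeddings_modify))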
Output:
[[ 0.08698401 -0.15590087 0.00285593 -0.13804913 -0.12418613]
[-0.25748074 0.32121882 -0.390212 0.24590132 0.3976703 ]
[-0.3023583 0.00366881 -0.05178839 -0.20865369 0.2887713 ]]
[[ 0.08698401 -0.15590087 0.00285593 -0.13804913 -0.12418613]
[ 0. 0. 0. 0. 0. ]
[-0.3023583 0.00366881 -0.05178839 -0.20865369 0.2887713 ]]
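For readers without a TensorFlow 1.x setup, the row-update semantics of scatter_nd_update with indices of shape [N, 1] can be sketched in plain Python. The helper scatter_rows below is hypothetical (not a TensorFlow API): it copies the matrix first, mirroring how tf.Variable(word_embeddings) holds a copy of the lookup result, then overwrites each indexed row. The embedding values are the first printed result above, rounded to three decimals.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def scatter_rows(matrix: Sequence[Sequence[float]],
                 indices: Sequence[Sequence[int]],
                 updates: Sequence[Sequence[float]]) -> Matrix:
    """Hypothetical helper mimicking scatter_nd_update for indices of
    shape [N, 1]: each index selects a whole row of `matrix`, which is
    replaced by the matching row of `updates`. Returns a new matrix;
    the input is left untouched."""
    if len(indices) != len(updates):
        raise ValueError("indices and updates must have the same length")
    # Copy first, like wrapping the lookup result in tf.Variable.
    result = [list(row) for row in matrix]
    for (row_index,), new_row in zip(indices, updates):
        if len(new_row) != len(result[row_index]):
            raise ValueError("update row has the wrong width")
        result[row_index] = list(new_row)
    return result


# Lookup result from the first print above, rounded.
word_embeddings = [
    [0.087, -0.156, 0.003, -0.138, -0.124],
    [-0.257, 0.321, -0.390, 0.246, 0.398],
    [-0.302, 0.004, -0.052, -0.209, 0.289],
]
error_word = [[1]]
sum_all = [[0, 0, 0, 0, 0]]

modified = scatter_rows(word_embeddings, error_word, sum_all)
print(modified[1])         # row 1 replaced by zeros
print(word_embeddings[1])  # original row unchanged
```

As in the TensorFlow output, only row 1 changes, and the source rows are not affected because the update is applied to a copy.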
I'm curious, though, why you want to update the word-embedding lookup result in the first place.