How to convert a tensor of dtype float32_ref to dtype float32?

Posted: 2019-01-10 01:26:06

Tags: tensorflow

word_embeddings = tf.nn.embedding_lookup(params=_word_embeddings,ids=self.word_ids)
word_embeddings_modify = tf.scatter_nd_update(word_embeddings, self.error_word, sum_all)

Error:
Tensor conversion requested dtype float32_ref for Tensor with dtype float32

Judging from the error, word_embeddings is actually a tensor of dtype float32, but tf.scatter_nd_update expects its first argument (the ref) to have dtype float32_ref.

How can I change the dtype of word_embeddings from float32 to float32_ref before passing it to tf.scatter_nd_update?
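To picture why the lookup result has no float32_ref, here is a NumPy analogy (illustration only: NumPy has no reference dtypes, and the arrays are hypothetical stand-ins for the question's tensors). tf.nn.embedding_lookup gathers rows into a new tensor, much as NumPy fancy indexing returns a new array. In TF1 graph mode, tf.scatter_nd_update needs a float32_ref, i.e. a handle to a variable's mutable storage, and a gathered value does not have one.

```python
import numpy as np

# Hypothetical stand-in for _word_embeddings: a 30 x 5 table.
embedding_table = np.arange(30 * 5, dtype=np.float32).reshape(30, 5)
word_ids = np.array([3, 6, 23])

# Fancy indexing gathers rows 3, 6 and 23 into a NEW array,
# loosely analogous to tf.nn.embedding_lookup returning a plain tensor.
looked_up = embedding_table[word_ids]

# Writing into the gathered copy does not touch the original table.
looked_up[1] = 0.0
print(embedding_table[6])  # row 6 of the table is unchanged
print(looked_up[1])        # only the copy was zeroed
```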

1 answer:

Answer (score: 0)

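You can wrap the tensor in tf.Variable(). This creates a new variable initialized from the tensor's value, and a variable provides the float32_ref that tf.scatter_nd_update expects. For example: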

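import tensorflow as tf  # TensorFlow 1.x (graph mode with tf.Session)

# Embedding table: a variable (in TF1 its ref dtype is float32_ref)
_word_embeddings = tf.get_variable(name='embedding', shape=[30, 5])

word_ids = [3, 6, 23]
# The lookup result is a plain float32 tensor holding rows 3, 6 and 23
word_embeddings = tf.nn.embedding_lookup(params=_word_embeddings, ids=word_ids)

error_word = [[1]]           # index of the row to overwrite in the lookup result
sum_all = [[0, 0, 0, 0, 0]]  # replacement values for that row
# Wrapping the tensor in tf.Variable gives scatter_nd_update a float32_ref to write into
word_embeddings_modify = tf.scatter_nd_update(tf.Variable(word_embeddings), error_word, sum_all)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(word_embeddings))         # original lookup result
    print(sess.run(word_embeddings_modify))  # row 1 replaced by zeros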

Output:

[[ 0.08698401 -0.15590087  0.00285593 -0.13804913 -0.12418613]
 [-0.25748074  0.32121882 -0.390212    0.24590132  0.3976703 ]
 [-0.3023583   0.00366881 -0.05178839 -0.20865369  0.2887713 ]]
[[ 0.08698401 -0.15590087  0.00285593 -0.13804913 -0.12418613]
 [ 0.          0.          0.          0.          0.        ]
 [-0.3023583   0.00366881 -0.05178839 -0.20865369  0.2887713 ]]
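The output follows from how tf.scatter_nd_update reads its arguments: error_word = [[1]] has shape [1, 1], so it holds a single index into the first dimension, and each entry of sum_all must therefore be a whole row of 5 values. The function below is a hypothetical NumPy emulation of those semantics, not TensorFlow code; unlike the TF op, it returns an updated copy instead of writing into a variable.

```python
import numpy as np

def scatter_nd_update_np(ref, indices, updates):
    """Hypothetical NumPy emulation of tf.scatter_nd_update semantics.

    Each row of `indices` addresses a slice of `ref` (here, a whole row),
    and the matching entry of `updates` replaces that slice. With duplicate
    indices TF does not guarantee an order; this loop applies them in sequence.
    """
    result = np.array(ref, copy=True)
    for idx, upd in zip(np.asarray(indices), np.asarray(updates)):
        result[tuple(idx)] = upd
    return result

word_embeddings = np.array([
    [0.1, 0.2, 0.3, 0.4, 0.5],
    [0.6, 0.7, 0.8, 0.9, 1.0],
    [1.1, 1.2, 1.3, 1.4, 1.5],
], dtype=np.float32)
error_word = [[1]]           # shape [1, 1]: one index into dimension 0
sum_all = [[0, 0, 0, 0, 0]]  # shape [1, 5]: one full replacement row

print(scatter_nd_update_np(word_embeddings, error_word, sum_all))
```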

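Note that the update is applied to the new variable created by tf.Variable(word_embeddings), not to _word_embeddings. I am curious, though, why you want to update the result of the embedding lookup at all.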
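If the real goal is to change the embedding table itself, one option is to map the batch position in error_word back to a vocabulary id through word_ids and scatter into the table. In TF1 that would mean passing the _word_embeddings variable, which already has a float32_ref, directly to tf.scatter_nd_update with remapped indices, built for example with tf.gather(word_ids, error_word); this assumes error_word holds positions into a 1-D word_ids, as in the answer's example. The NumPy sketch below only shows the index arithmetic, with hypothetical stand-in arrays.

```python
import numpy as np

embedding_table = np.ones((30, 5), dtype=np.float32)  # hypothetical stand-in for _word_embeddings
word_ids = np.array([3, 6, 23])
error_word = np.array([[1]])                  # position within the looked-up batch
sum_all = np.zeros((1, 5), dtype=np.float32)  # replacement row

# Translate batch positions into vocabulary ids: position 1 -> word id 6.
table_indices = word_ids[error_word]  # [[6]]

# Overwrite the table rows in place, as tf.scatter_nd_update would on the variable,
# so any later lookup of word id 6 sees the new values.
for idx, upd in zip(table_indices, sum_all):
    embedding_table[tuple(idx)] = upd

print(embedding_table[word_ids])  # the second looked-up row is now zeros
```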