我想通过功能float32_ref
使用将word_embeddings dtype从float32
修改为tf.cast()
:
word_embeddings_modify=tf.cast(word_embeddings,dtype=tf.float32)
但是它没有按预期工作,并且word_embeddings_modify dtype仍然为tf.float32_ref。
word_embeddings = tf.scatter_nd_update(var_output, error_word_f,sum_all)
word_embeddings_modify=tf.cast(word_embeddings,dtype=tf.float32)
word_embeddings_dropout = tf.nn.dropout(word_embeddings_2, dropout_pl)