tensorflow如何解决one_hot和sign之间的dtype冲突

时间:2018-10-23 03:16:49

标签: python tensorflow neural-network

我的神经网络具有以下输出:

  • logits是tanh节点的输出,因此该值是(-1,1)内的浮点数。
  • actionlogits
  • 的标志
  • one_hotaction的one_hot版本,维度3表示-1、0和+1

问题是,我的损失函数与one_hot值相关,所以我建立了神经网络输出部分,如下所示:

logits = tf.contrib.layers.fully_connected(outputs, 1, activation_fn=tf.tanh)
action = tf.sign(logits)
one_hot = tf.one_hot(action+1, depth=3)

这给了我

的TypeError
  

TypeError:传递给参数'indices'的值具有DataType float32   不在允许的值列表中:uint8,int32,int64

然后我尝试将one_hot更改为:

one_hot = tf.one_hot(tf.cast(action, tf.int32)+1, depth=3)

还有另一个没有渐变的错误:

  

ValueError:没有为任何变量提供渐变,请检查您的图表   对于不支持渐变的操作,在变量之间[...]

是否可以使用任何解决方法来避免这两个错误。任何帮助表示赞赏。

0 个答案:

没有答案