tensorflow - tf.where TypeError

时间:2017-03-06 12:31:26

标签: python tensorflow

我编写了以下测试代码(更大代码的一部分)

import tensorflow as tf

update_boolean = [True, False, True, False, True, True]


with tf.Session() as sess:
    op = tf.where(update_boolean, lambda: tf.train.AdamOptimizer(0.1), lambda: tf.no_op())

我收到以下错误

TypeError: Expected binary or unicode string, got <function <lambda> at 0x000000000118E400>

我该如何解决这个问题?

我想要做的是创建一个list \ tensor运算符(优化器\什么都不做)给定一定条件(不使用tf.cond因为我想将它应用于批处理所以我有一个列表\ booleans张量) *在tensorflow 1.0中工作

1 个答案:

答案 0 :(得分:0)

tf.train.AdamOptimizer(0.1)返回一个Optimizer对象,不清楚它与no_op的关系。我建议首先创建优化器然后调整梯度更新,例如通过将所需批次的部分的梯度贡献归零。 tf.where不适用于操作,它提供了访问张量的索引。