我正在尝试使用tf.cond()
和tf.case()
控制图流。
我有几个网络可以产生相同大小(networks
)的输出。
我还有一个额外的网络,可以输出要使用的上述哪个网络(networks_prob
。
在培训期间,我想堆叠所有的networks
结果。在测试过程中,我想构造一个张量,该张量由networks
最大的索引中networks_prob
中的结果组成。 (因此,我可以避免评估所有网络,而只评估概率最高的网络)
这是我想到的一个简单示例,但是它不起作用,我也不明白为什么。
import tensorflow as tf
networks = tf.constant([[[1], [2], [3]], [[4], [5], [6]]])
networks_prob = tf.constant([[0.2, 0.3, 0.4], [0.8, 0.1, 0.0]])
is_training = tf.placeholder(tf.bool, shape=())
network_idx = tf.argmax(networks_prob)
case_dict = {tf.equal(network_idx, i): lambda: networks[i] for i in range(networks.shape[1])}
output = tf.cond(is_training, lambda: tf.stack(networks), lambda: tf.case(case_dict, default=lambda: 1))
with tf.Session() as sess:
output_val = sess.run(output, feed_dict={is_training: False})
print(output_val)
我收到Shape must be rank 0 but is rank 1 for 'cond/case/If_0/Switch' (op: 'Switch') with input shapes: [3], [3].
错误。