使用tensorfow使用'if'语句对python进行for'循环

时间:2019-08-05 08:20:41

标签: python for-loop tensorflow if-statement

有人可以帮助将此代码转换为TensorFlow吗?我试图在CNN输出大于0.95的值的集合中找到数据点位置,以帮助进行伪标记。

   positions = []

   for t in range(int(dataset.shape[0] // batch_size)):
        data = dataset.next_batch
        model_output = sess.run([output], feed_dict={model_input_pl: data})

        for i in range(model_output[0].shape[0]):
            if model_output[0][i][some_nodal_position] > 0.95:
                 positions.append(batch_start_position + i)

并行处理此代码将允许测试更多模型,但是拥有上述代码将花费很长时间。

1 个答案:

答案 0 :(得分:0)

这可以实现如下:

tf.where(tf.greater(output, 0.95))

这将返回一个张量,其中索引为output大于0.95。