有人可以帮助将此代码转换为TensorFlow吗?我试图在CNN输出大于0.95的值的集合中找到数据点位置,以帮助进行伪标记。
positions = []
for t in range(int(dataset.shape[0] // batch_size)):
data = dataset.next_batch
model_output = sess.run([output], feed_dict={model_input_pl: data})
for i in range(model_output[0].shape[0]):
if model_output[0][i][some_nodal_position] > 0.95:
positions.append(batch_start_position + i)
并行处理此代码将允许测试更多模型,但是拥有上述代码将花费很长时间。
答案 0 :(得分:0)
这可以实现如下:
tf.where(tf.greater(output, 0.95))
这将返回一个张量,其中索引为output
大于0.95。