Tensorflow argmax沿多个维度

时间:2017-05-31 15:03:14

标签: python tensorflow deep-learning

我是def select(input_layer): shape = input_layer.get_shape().as_list() rel = tf.nn.relu(input_layer) print (rel) redu = tf.reduce_sum(rel,3) print (redu) location2 = tf.argmax(redu, 1) print (location2) sess = tf.InteractiveSession() I = tf.random_uniform([32, 3, 3, 5], minval = -541, maxval = 23, dtype = tf.float32) matI, matO = sess.run([I, select(I, 3)]) print(matI, matO) 的新手,我正在尝试获取Tensor中最大值的索引。这是代码:

Tensor("Relu:0", shape=(32, 3, 3, 5), dtype=float32)
Tensor("Sum:0", shape=(32, 3, 3), dtype=float32)
Tensor("ArgMax:0", shape=(32, 3), dtype=int64)
...

这是输出:

argmax

由于Tensor("ArgMax:0") = (32,3)函数中的维度= 1,因此argmax的形状。在应用(32,)之前,有没有办法在不reshape的情况下获得argmax输出张量大小= {{1}}?

1 个答案:

答案 0 :(得分:1)

您可能不希望输出大小为(32,)的输出,因为当您沿着多个方向argmax时,您通常希望所有缩小尺寸的坐标都为最大值。在您的情况下,您希望输出大小为(32,2)

你可以像这样做一个二维argmax

import numpy as np
import tensorflow as tf

x = np.zeros((10,9,8))
# pick a random position for each batch image that we set to 1
pos = np.stack([np.random.randint(9,size=10), np.random.randint(8,size=10)])

posext = np.concatenate([np.expand_dims([i for i in range(10)], axis=0), pos])
x[tuple(posext)] = 1

a = tf.argmax(tf.reshape(x, [10, -1]), axis=1)
pos2 = tf.stack([a // 8, tf.mod(a, 8)]) # recovered positions, one per batch image

sess = tf.InteractiveSession()
# check that the recovered positions are as expected
assert (pos == pos2.eval()).all(), "it did not work"