我正在尝试使用tf.nn.max_pool_with_argmax()
的argmax结果来索引另一个张量。为简单起见,假设我正在尝试实现以下内容:
output, argmax = tf.nn.max_pool_with_argmax(input, ksize, strides, padding)
tf.assert_equal(input[argmax],output)
现在我的问题是如何实现必要的索引操作input[argmax]
以实现所需的结果?我猜这涉及tf.gather_nd()
和相关电话的一些用法,但我无法弄明白。如有必要,我们可以假设输入具有[BatchSize, Height, Width, Channel]
维度。
谢谢你的帮助!
垫
答案 0 :(得分:1)
我找到了一个使用tf.gather_nd
的解决方案,虽然看起来不那么优雅,但它确实有效。我使用了here发布的函数unravel_argmax
。
def unravel_argmax(argmax, shape):
output_list = []
output_list.append(argmax // (shape[2] * shape[3]))
output_list.append(argmax % (shape[2] * shape[3]) // shape[3])
return tf.stack(output_list)
def max_pool(input, ksize, strides,padding):
output, arg_max = tf.nn.max_pool_with_argmax(input=input,ksize=ksize,strides=strides,padding=padding)
shape = input.get_shape()
arg_max = tf.cast(arg_max,tf.int32)
unraveld = unravel_argmax(arg_max,shape)
indices = tf.transpose(unraveld,(1,2,3,4,0))
channels = shape[-1]
bs = tf.shape(iv.m)[0]
t1 = tf.range(channels,dtype=arg_max.dtype)[None, None, None, :, None]
t2 = tf.tile(t1,multiples=(bs,) + tuple(indices.get_shape()[1:-2]) + (1,1))
t3 = tf.concat((indices,t2),axis=-1)
t4 = tf.range(tf.cast(bs, dtype=arg_max.dtype))
t5 = tf.tile(t4[:,None,None,None,None],(1,) + tuple(indices.get_shape()[1:-2].as_list()) + (channels,1))
t6 = tf.concat((t5, t3), -1)
return tf.gather_nd(input,t6)
如果有人有更优雅的解决方案,我仍然很想知道。
垫
答案 1 :(得分:0)
我是以这种方式做的:
def max_pool(input, ksize, strides,padding):
output, arg_max = tf.nn.max_pool_with_argmax(input=input,ksize=ksize,strides=strides,padding=padding)
shape=tf.shape(output)
output1=tf.reshape(tf.gather(tf.reshape(input,[-1]),arg_max),shape)
err=tf.reduce_sum(tf.square(tf.subtract(output,output1)))
return output1, err
答案 2 :(得分:0)
此小片段有效:
def get_results(data,other_tensor):
pooled_data, indices = tf.nn.max_pool_with_argmax(data,ksize=[1,ksize,ksize,1],strides=[1,stride,stride,1],padding='VALID',include_batch_in_index=True)
b,w,h,c = other_tensor.get_shape.as_list()
other_tensor_pooled = tf.gather(tf.reshape(other_tensor,shape= [b*w*h*c,]),indices)
return other_tensor_pooled
以上indices
可用于索引张量。此函数实际上返回平整的索引,并将其与batch_size > 1
一起使用时,您需要将include_batch_in_index
作为True
传递,以获得正确的结果。我在这里假设othertensor
与data.