使用argmax在tensorflow中切割张量

时间:2018-02-08 08:14:24

标签: python tensorflow

我想在tensorflow中创建一个动态丢失函数。我想计算信号FFT的能量,更具体地说,只计算最主要峰值周围的3号窗口。我无法在TF中实现,因为它会引发很多错误,例如StrideInvalidArgumentError (see above for traceback): Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,64], [1,64], and [1] instead.

我的代码是:

self.spec = tf.fft(self.signal)
self.spec_mag = tf.complex_abs(self.spec[:,1:33])
self.argm = tf.cast(tf.argmax(self.spec_mag, 1), dtype=tf.int32)
self.frac = tf.reduce_sum(self.spec_mag[self.argm-1:self.argm+2], 1)

由于我的64位批处理和数据维度也是64,因此self.signal的形状为(64,64)。我希望只计算FFT的AC分量。由于信号是真正有价值的,只有一半的频谱可以完成工作。因此,self.spec_mag的形状为(64,32)

此fft中的最大值位于self.argm,其形状为(64,1)

现在我想通过self.spec_mag[self.argm-1:self.argm+2]计算最大峰值周围3个元素的能量。

然而,当我运行代码并尝试获取self.frac的值时,我会遇到多个错误。

2 个答案:

答案 0 :(得分:4)

在访问argm时,您似乎缺少了索引。以下是1,64版本的固定版本。

import tensorflow as tf
import numpy as np

x = np.random.rand(1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

spec_mag = tf.abs(spec[:,1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

frac = tf.reduce_sum(spec_mag[0][(argm[0]-1):(argm[0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())

这是扩展版本(batch,m,n)

import tensorflow as tf
import numpy as np

x = np.random.rand(1, 1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

spec_mag = tf.abs(spec[:, :, 1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

argm = tf.cast(tf.argmax(spec_mag, 2), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

frac = tf.reduce_sum(spec_mag[0][0][(argm[0][0]-1):(argm[0][0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())

你可能想要修复函数名,因为我在较新版本的tensorflow中编辑了这段代码。

答案 1 :(得分:3)

Tensorflow索引使用tf.Tensor.getitem

  

此操作从张量中提取指定的区域。符号类似于NumPy,其限制目前仅支持基本索引。这意味着目前不允许使用张量作为输入

因此,使用tf.slicetf.strided_slice也是不可能的。

tf.gather indices中将切片定义为Tensor的第一维,在tf.gather_nd中,indices将切片定义为第一个N Tensor的尺寸,N = indices.shape[-1]

由于你想要max周围的3个值,我使用列表推导手动提取第一,第二和第三个元素,然后是tf.stack

import tensorflow as tf

signal = tf.placeholder(shape=(64, 64), dtype=tf.complex64)
spec = tf.fft(signal)
spec_mag = tf.abs(spec[:,1:33])
argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)

frac = tf.stack([tf.gather_nd(spec,tf.transpose(tf.stack(
             [tf.range(64), argm+i]))) for i in [-1, 0, 1]])

frac = tf.reduce_sum(frac, 1)

对于argm是行中的第一个或最后一个元素的极端情况,这将失败,但它应该很容易解决。