我想在tensorflow中创建一个动态丢失函数。我想计算信号FFT的能量,更具体地说,只计算最主要峰值周围的3号窗口。我无法在TF中实现,因为它会引发很多错误,例如Stride
和InvalidArgumentError (see above for traceback): Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,64], [1,64], and [1] instead.
我的代码是:
self.spec = tf.fft(self.signal)
self.spec_mag = tf.complex_abs(self.spec[:,1:33])
self.argm = tf.cast(tf.argmax(self.spec_mag, 1), dtype=tf.int32)
self.frac = tf.reduce_sum(self.spec_mag[self.argm-1:self.argm+2], 1)
由于我的64位批处理和数据维度也是64,因此self.signal
的形状为(64,64)
。我希望只计算FFT的AC分量。由于信号是真正有价值的,只有一半的频谱可以完成工作。因此,self.spec_mag
的形状为(64,32)
。
此fft中的最大值位于self.argm
,其形状为(64,1)
。
现在我想通过self.spec_mag[self.argm-1:self.argm+2]
计算最大峰值周围3个元素的能量。
然而,当我运行代码并尝试获取self.frac
的值时,我会遇到多个错误。
答案 0 :(得分:4)
在访问argm时,您似乎缺少了索引。以下是1,64版本的固定版本。
import tensorflow as tf
import numpy as np
x = np.random.rand(1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)
signal = xt
print('signal', signal.shape)
print('signal', signal.eval())
spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())
spec_mag = tf.abs(spec[:,1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())
argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())
frac = tf.reduce_sum(spec_mag[0][(argm[0]-1):(argm[0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())
这是扩展版本(batch,m,n)
import tensorflow as tf
import numpy as np
x = np.random.rand(1, 1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)
signal = xt
print('signal', signal.shape)
print('signal', signal.eval())
spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())
spec_mag = tf.abs(spec[:, :, 1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())
argm = tf.cast(tf.argmax(spec_mag, 2), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())
frac = tf.reduce_sum(spec_mag[0][0][(argm[0][0]-1):(argm[0][0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())
你可能想要修复函数名,因为我在较新版本的tensorflow中编辑了这段代码。
答案 1 :(得分:3)
Tensorflow索引使用tf.Tensor.getitem:
此操作从张量中提取指定的区域。符号类似于NumPy,其限制目前仅支持基本索引。这意味着目前不允许使用张量作为输入
因此,使用tf.slice
和tf.strided_slice
也是不可能的。
在tf.gather
indices
中将切片定义为Tensor
的第一维,在tf.gather_nd
中,indices
将切片定义为第一个N
Tensor
的尺寸,N = indices.shape[-1]
由于你想要max
周围的3个值,我使用列表推导手动提取第一,第二和第三个元素,然后是tf.stack
import tensorflow as tf
signal = tf.placeholder(shape=(64, 64), dtype=tf.complex64)
spec = tf.fft(signal)
spec_mag = tf.abs(spec[:,1:33])
argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)
frac = tf.stack([tf.gather_nd(spec,tf.transpose(tf.stack(
[tf.range(64), argm+i]))) for i in [-1, 0, 1]])
frac = tf.reduce_sum(frac, 1)
对于argm
是行中的第一个或最后一个元素的极端情况,这将失败,但它应该很容易解决。