我有一个序列(样本,时间步长,特征)的“ myseq” =(N,500,8),其中N是批处理大小
它馈入大小为48的softmax(当我使用跨度为10的窗口查看时,代表48个可能的大小为(N,20,80)的窗口)。
然后我找到“索引”,即softmax输出值的最大索引。 =(N,1),其中第二维对应于每个单独样本的softmax的每个最大索引“ i”。
在这一点上,我想切片[:,indices:indices + 20,:]
示例代码如下:
myseq = Input(shape=(500,4))
choices = Dense(48, activation='softmax')(myseq)
indices = K.argmax(choices, axis=1)
indices = K.cast(indices,"int32")
windows = ??
因此,作为一个简短的摘要:我想使用张量“索引”作为每个批次的起始索引,从myseq的axis1可变地切片长度为20的窗口。
我现在可以通过使用一个函数来执行此操作,该函数对批处理中的每个样本进行迭代,提取每个特定的窗口,然后将它们全部重新组合成一个新的张量。但这效率极低,我什至不能100%地确定它不会破坏网络中的某些内容。这就是我要做的:
def get_window(myarr, sums, length, stride, batchsize):
indexez = K.argmax(sums, axis=1)
indices = K.cast(indexez,"int32")
for i in range(batchsize):
new_arr = Lambda(lambda x: K.slice(x, (i,indices[i]*stride,0), (1,length,myarr.shape[2])))(myarr)
if i == 0:
full_arr = new_arr
else:
full_arr = Concatenate(axis=0)([full_arr,new_arr])
full_arr = tf.compat.v1.placeholder_with_default(full_arr,[None,full_arr.shape[1],full_arr.shape[2]])
return full_arr
是否有更好的建议?谢谢。