(keras)输入张量和索引,得到我想要的张量

时间:2019-09-27 07:24:12

标签: python tensorflow keras

我的输入就像(3,3,2)数组和(3,3)数组:

img = np.array([[[1,1],[2,2],[3,3]],
                [[4,4],[5,5],[6,6]],
                [[7,7],[8,8],[9,9]]])

idx = np.array([[1,0,0],
                [0,0,1],
                [1,1,0]])

我的理想输出应该是:

[[1 1]
 [6 6]
 [7 7]
 [8 8]]

我想通过自定义图层来做到这一点:

  1. 制作图层:
def extract_layer(data, idx):

    idx = tf.where(idx)
    data = tf.gather_nd(data,idx)
    data = tf.reshape(data,[-1,2])

    return data
  1. 制作模型:
input_data = kl.Input(shape=(3,3,2))
input_idxs = kl.Input(shape=(3,3))
extraction = kl.Lambda(lambda x:extract_layer(*x),name='extraction')([input_data,input_idxs])

我可以构建模型,并且可以看到模型的keras摘要, 输出是

model = Model(inputs=([input_data,input_idxs]), outputs=extraction)
model.summary()

...
input_1 (InputLayer)            (None, 3, 3, 2) 
input_2 (InputLayer)            (None, 3, 3) 
extraction (Lambda)             (None, 2)
Total params: 0
...

但是当我开始预测时:

'i have already made the two inputs into (1,3,3,2) and (1,3,3) shape'
result = model.predict(x=([img,idx]))

出现错误:

'ValueError: could not broadcast input array from shape (4,2) into shape (1,2)'

我认为shape(4,2)的张量是我想要的值 但我不知道为什么喀拉拉邦将其广播给(1,2)

有人可以帮助我吗?

非常感谢!

1 个答案:

答案 0 :(得分:0)

在您的extract_layer()函数中,data是两个暗淡张量。但是model.predict应该以额外的批处理暗淡返回结果。当data中的返回extract_layer()可以解决此错误时,只需扩大暗淡度即可。

def extract_layer(data, idx):

    idx = tf.where(idx)
    data = tf.gather_nd(data,idx)
    data = tf.reshape(data,[-1,2])

    return tf.expand_dims(data, axis=0)

注意:由于tf.gather_nd返回的结果可能有不同的长度,所以我认为每批的大小只能是1。如果我错了,请更正我。