Slicing in a Keras Lambda layer does not work as expected

Time: 2019-08-26 21:05:53

Tags: python keras slice

I am trying to pick out different parts of the input to feed into different streams. In numpy, slicing works like this:

import numpy

x = numpy.arange(10).reshape(-1, 10)
print(x[:, [1, 2, 3]])
# >>> [[1 2 3]]
print(x[:, [2, 5, 9]])
# >>> [[2 5 9]]
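For comparison, the NumPy fancy indexing above selects the same elements as `numpy.take` along `axis=1`. TensorFlow's `tf.gather` accepts an `axis` argument in the same way, which is why it is the usual tensor-side counterpart. A minimal check of the NumPy equivalence:

```python
import numpy

x = numpy.arange(10).reshape(-1, 10)
idx = [2, 5, 9]

# Integer-array ("fancy") indexing on the column axis...
fancy = x[:, idx]
# ...picks the same elements as numpy.take along axis=1.
taken = numpy.take(x, idx, axis=1)

print(fancy)                             # [[2 5 9]]
print(numpy.array_equal(fancy, taken))   # True
```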

But in Keras, only the identity and basic slicing work; indexing with a list of columns fails:

import numpy
import keras
import tensorflow as tf

features = numpy.arange(10).reshape(-1, 10)

inputs = keras.layers.Input(shape=(10,))
predictions = keras.layers.Lambda(lambda x: x)(inputs)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.predict(features))
# >>> [[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]


inputs = keras.layers.Input(shape=(10,))
predictions = keras.layers.Lambda(lambda x: x[:,1:3])(inputs)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.predict(features))
# >>> [[1. 2.]]

inputs = keras.layers.Input(shape=(10,))
predictions = keras.layers.Lambda(lambda x: x[:,[1,2,3]])(inputs)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.predict(features))
# >>> Traceback (most recent call last):
#  File "<stdin>", line 1, in <module>
#  File "/home/user/test.py", line 589, in <module>
#    predictions = keras.layers.Lambda(lambda x: x[:,[1,2,3]])(inputs)
#  File "/home/user/.local/lib/python3.6/site-packages/keras/engine/base_layer.py", line 451, in __call__
#    output = self.call(inputs, **kwargs)
#  File "/home/user/.local/lib/python3.6/site-packages/keras/layers/core.py", line 716, in call
#    return self.function(inputs, **arguments)
#  File "/home/user/test.py", line 589, in <lambda>
#    predictions = keras.layers.Lambda(lambda x: x[:,[1,2,3]])(inputs)
#  File "/home/user/.local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 491, in _slice_helper
#    end.append(s + 1)
#TypeError: can only concatenate list (not "int") to list
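The traceback shows the error is raised while the Lambda layer is being built, before `predict` runs, inside `_slice_helper` (the function behind `Tensor.__getitem__` in TF 1.12). That helper appears to handle only basic indexing (integers, slices, `...`, `tf.newaxis`); a list falls through to the single-integer branch, so `end.append(s + 1)` tries to add 1 to a list. A common substitute is `tf.gather(..., axis=1)`. The sketch below assumes `tf.gather` with an `axis` argument (present in TF 1.12 and 2.x) and uses `tensorflow.keras` instead of the standalone `keras` package; it feeds two streams with different column lists.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras  # assumption: tf.keras instead of standalone keras

features = np.arange(10, dtype="float32").reshape(-1, 10)

inputs = keras.layers.Input(shape=(10,))
# tf.gather along axis=1 takes the listed columns from every row in the batch,
# mirroring numpy's x[:, [1, 2, 3]] without relying on Tensor.__getitem__.
stream_a = keras.layers.Lambda(lambda x: tf.gather(x, [1, 2, 3], axis=1))(inputs)
stream_b = keras.layers.Lambda(lambda x: tf.gather(x, [2, 5, 9], axis=1))(inputs)
model = keras.models.Model(inputs=inputs, outputs=[stream_a, stream_b])

out_a, out_b = model.predict(features, verbose=0)
print(out_a)  # [[1. 2. 3.]]
print(out_b)  # [[2. 5. 9.]]
```

The column lists are fixed constants in the graph here; that is a simplifying assumption, since the question does not say whether the selected columns change at run time.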

inputs = keras.layers.Input(shape=(10,))
predictions = keras.layers.Lambda(lambda x: tf.gather_nd(x, [1,2,3]))(inputs)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.predict(features))
#>>> Traceback (most recent call last):
#  File "/home/user/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1628, in _create_c_op
#    c_op = c_api.TF_FinishOperation(op_desc)
#tensorflow.python.framework.errors_impl.InvalidArgumentError: indices.shape[-1] must be <= params.rank, but saw indices shape: [3] and params shape: [?,10] for 'lambda_28/GatherNd' (op: 'GatherNd') with input shapes: [?,10], [3].
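The `InvalidArgumentError` follows from how `tf.gather_nd` reads its indices: the last axis of `indices` is a coordinate into the leading axes of `params`, so a flat `[1, 2, 3]` is read as one 3-element coordinate into a rank-2 tensor. The hypothetical helper `gather_nd_reference` below is a NumPy illustration of those semantics (no `batch_dims`), not TensorFlow's implementation. It shows that picking columns 1, 2, 3 of every row with gather_nd would need one (row, column) pair per element:

```python
import numpy as np


def gather_nd_reference(params, indices):
    """Hypothetical NumPy illustration of tf.gather_nd semantics (no batch_dims).

    The last axis of `indices` holds coordinates into the leading axes of
    `params`; any remaining axes of `params` are copied through unchanged.
    """
    params = np.asarray(params)
    indices = np.asarray(indices)
    k = indices.shape[-1]
    if k > params.ndim:
        # Mirrors TensorFlow's check "indices.shape[-1] must be <= params.rank".
        raise ValueError(
            f"indices.shape[-1] must be <= params.rank, got {k} > {params.ndim}")
    flat = indices.reshape(-1, k)
    picked = np.stack([params[tuple(coord)] for coord in flat])
    return picked.reshape(indices.shape[:-1] + params.shape[k:])


x = np.arange(10).reshape(-1, 10)  # shape (1, 10), like `features`

# A flat [1, 2, 3] is a single 3-element coordinate, but x has rank 2.
try:
    gather_nd_reference(x, [1, 2, 3])
except ValueError as err:
    print(err)  # indices.shape[-1] must be <= params.rank, got 3 > 2

# Selecting columns 1, 2, 3 of every row needs explicit (row, column) pairs.
rows = np.arange(x.shape[0])[:, None]   # shape (batch, 1)
cols = np.array([1, 2, 3])[None, :]     # shape (1, 3)
pairs = np.stack(np.broadcast_arrays(rows, cols), axis=-1)  # shape (batch, 3, 2)
print(gather_nd_reference(x, pairs))    # [[1 2 3]]
```

Inside a Lambda the batch size is unknown when the graph is built, so constructing such pairs would need the dynamic shape; gathering along `axis=1` avoids that.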


>>> tf.__version__
'1.12.0'

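What kind of details would the Stack Overflow bot like me to add so that this question contains less code? Should I remove my attempts and the errors they produced so that someone can clarify?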

0 answers