我正在尝试按列索引选取输入的不同部分,分别送入模型中不同的分支。在 numpy 中,可以用索引列表这样切片:
# A single row holding the values 0..9, i.e. an array of shape (1, 10).
x = numpy.arange(10).reshape(-1, 10)

# Indexing the column axis with a list of indices ("fancy" indexing)
# selects exactly those columns, in the given order, contiguous or not.
contiguous_cols = [1, 2, 3]
print(x[:, contiguous_cols])
# >>> [[1 2 3]]
scattered_cols = [2, 5, 9]
print(x[:, scattered_cols])
# >>> [[2 5 9]]
但同样的写法在 Keras 中却不起作用:
# NOTE(review): the numpy / keras / tf imports are not shown in this snippet;
# presumably `import numpy`, `import keras` and `import tensorflow as tf`
# appear earlier in test.py — confirm.
# Same single-row input as the numpy example: values 0..9, shape (1, 10).
features = numpy.arange(10).reshape(-1, 10)
# Baseline: an identity Lambda layer passes the input through unchanged.
inputs = keras.layers.Input(shape=(10,))
predictions = keras.layers.Lambda(lambda x: x)(inputs)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.predict(features))
# >>> [[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]
# Second attempt: a basic (start:stop) slice on the column axis works;
# 1:3 keeps columns 1 and 2.
inputs = keras.layers.Input(shape=(10,))
predictions = keras.layers.Lambda(lambda x: x[:,1:3])(inputs)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.predict(features))
# >>> [[1. 2.]]
# Third attempt: numpy-style indexing with a list of column indices.
# The traceback below points at the Lambda line, not at predict: the failure
# happens while the layer is being built, so predict is never reached.
inputs = keras.layers.Input(shape=(10,))
# Per the traceback, TensorFlow's `_slice_helper` handles the list [1,2,3]
# as if it were a single scalar index (`end.append(s + 1)`), which raises
# TypeError instead of selecting the three columns.
predictions = keras.layers.Lambda(lambda x: x[:,[1,2,3]])(inputs)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.predict(features))
# >>> Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# File "/home/user/test.py", line 589, in <module>
# predictions = keras.layers.Lambda(lambda x: x[:,[1,2,3]])(inputs)
# File "/home/user/.local/lib/python3.6/site-packages/keras/engine/base_layer.py", line 451, in __call__
# output = self.call(inputs, **kwargs)
# File "/home/user/.local/lib/python3.6/site-packages/keras/layers/core.py", line 716, in call
# return self.function(inputs, **arguments)
# File "/home/user/test.py", line 589, in <lambda>
# predictions = keras.layers.Lambda(lambda x: x[:,[1,2,3]])(inputs)
# File "/home/user/.local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 491, in _slice_helper
# end.append(s + 1)
#TypeError: can only concatenate list (not "int") to list
# Fourth attempt: tf.gather_nd with a flat list of indices.
inputs = keras.layers.Input(shape=(10,))
# The error below rejects this at graph-construction time because
# indices.shape[-1] (3, from the flat list [1,2,3]) exceeds params.rank
# (2, for the input of shape [?,10]).
predictions = keras.layers.Lambda(lambda x: tf.gather_nd(x, [1,2,3]))(inputs)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.predict(features))
#>>> Traceback (most recent call last):
# File "/home/user/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1628, in _create_c_op
# c_op = c_api.TF_FinishOperation(op_desc)
#tensorflow.python.framework.errors_impl.InvalidArgumentError: indices.shape[-1] must be <= params.rank, but saw indices shape: [3] and params shape: [?,10] for 'lambda_28/GatherNd' (op: 'GatherNd') with input shapes: [?,10], [3].
>>> tf.__version__
'1.12.0'
Stack Overflow 的机器人提示我补充更多细节,好让这个问题中的代码占比少一些——我应该补充哪些内容?是否应该删掉尝试过的代码以及对应的报错信息,以便有人能帮我解答?