Keras不能将数字数组相加而产生标量吗?

时间:2019-04-10 14:14:35

标签: python tensorflow keras

这个问题与this question有关,但是简单一些。

我希望Keras能够接受一组数字并将其求和。数组应具有由批处理尺寸定义的任意长度。

以下简单示例不起作用:

from keras.models import Model
from keras.layers import Input, Lambda
from keras import backend as K

inp = Input(shape = (1,))
out = Lambda(lambda x: K.sum(x))(inp)
m = Model(inp, out)

m.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_7 (InputLayer)         (None, 1)                 0         
_________________________________________________________________
lambda_2 (Lambda)            ()                        0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

输出形状应为标量:(1)。无批次尺寸。如上所述,m无法编译。

期望的结果将是m.predict(np.array([1,2,3]))产生6。也许包含6的数组或张量。

可以在Keras中完成此基本任务-作为函数应用到批次维度并返回标量吗?如果没有,可以在纯张量流中完成吗?

编辑:我刚刚了解到,您可以从此模型进行预测而无需编译:

m.predict(np.array([1,2]))

收益

array([3., 3.], dtype=float32)

我想我现在的问题是如何将其压缩为标量,然后进行编译。

1 个答案:

答案 0 :(得分:1)

tf.reduce_sum()axis=None一起使用。它将缩小所有尺寸:

import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras import backend as K

inp = Input(shape = (3,))
out = Lambda(lambda x: tf.reduce_sum(x, axis=None))(inp)
m = Model(inp, out)

m.predict(np.array([[1, 2, 4]])) # array([7.], dtype=float32)

m.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 3)                 0         
_________________________________________________________________
lambda_1 (Lambda)            ()                        0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: