在TF

时间:2019-09-17 16:29:51

标签: tensorflow keras

我正在尝试在TensorFlow中实现以下内容:

Input * const
640x800x6的

矩阵乘法  这是代码     ssValues = np.zeros(shape =(6,640,800),dtype = np.float16)

inputPlaceHolder = tf.compat.v1.placeholder(shape=(6,640,800), name='InputTensor', dtype=tf.dtypes.float16)
inputLayer = tf.keras.Input(shape=(6,640,800,),
                            batch_size=1,
                            name='inputLayer',
                            dtype=tf.dtypes.float16,
                            tensor=inputPlaceHolder)

ssConstant = tf.constant(ssValues, dtype=tf.dtypes.float16, shape=(6,640,800), name='ss')
ssm = tf.keras.layers.Multiply()([inputPlaceHolder,inputPlaceHolder])
model = tf.keras.models.Model(inputs=inputLayer, outputs=ssm)

input = np.zeros(shape=(6,640,800),dtype=np.float16)

output = model.predict(input)

我收到以下错误: ValueError:(''检查模型输入时出错:预期没有数据,但得到了:',array([[[[1。,1.,1.,...,1.,1.,1。],

  • 如何克服此错误并运行预测功能?

  • 为什么tf.keras.layers.multiply不返回Layer对象?

2 个答案:

答案 0 :(得分:1)

使用Input(shape)时已经有一个占位符。创建占位符以将其传递给Input(tensor=placeholder)是没有意义的,因为Keras并非如此。

您必须:

inputs = Input(shape=(6,640,800))
ssm_tensor = Multiply()([inputs, inputs])
model = Model(inputs, ssm)

由于Keras始终具有批次大小,因此:

input = np.zeros(shape=(1,6,640,800))

答案 1 :(得分:1)

您的问题来自以下事实:您在v1占位符上声明了您的操作,而该操作仅应使用inputLayer(它已充当遵循所提供规范的输入的占位符)。

此外,当我认为您想要$ x \ times constant $时,您编写了一个返回$ x \ times x $的乘法;所以这是代码:

inputLayer = tf.keras.Input(shape=(6,640,800,),
                            batch_size=1,
                            name='inputLayer',
                            dtype=tf.dtypes.float16)
ssConstant = tf.constant(  # also fixed a shape issue here
    ssValues, dtype=tf.dtypes.float16, shape=(1, 6,640,800), name='ss'
)
ssm = tf.keras.layers.Multiply(dtype=tf.dtypes.float16)([inputLayer, ssConstant])
model = tf.keras.models.Model(inputs=inputLayer, outputs=ssm)

inputs = np.zeros(shape=(1,6,640,800), dtype=np.float16)
output = model.predict(inputs)

此外,由于这不是实际模型,因此从某种意义上来说,它使用的是恒定且不可学习的权重,因此您可能希望使用tf.keras.backend.function而不是tf.keras.Model(但这实际上取决于您)。

请注意,形状可能与您实际想要的形状不符,批大小为1 ...请考虑使用批大小为6来删除无用的尺寸。