Dense layer expects 1 input but received 2 input tensors — how do I change this?

Date: 2020-09-15 16:54:17

Tags: tensorflow python keras

I'm a beginner just getting started with TensorFlow 2 and Keras. I was playing around trying to build a model when I stumbled upon the following error:

Traceback (most recent call last):
  File "/home/arch_poppin/dev/AI/reviews/rev.py", line 7, in <module>
    x = layers.Dense(8, activation='relu')([input1, input2])
  File "/usr/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 930, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "/usr/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1068, in _functional_construction_call
    outputs = self._keras_tensor_symbolic_call(
  File "/usr/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 801, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/usr/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 841, in _infer_output_signature
    self._maybe_build(inputs)
  File "/usr/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 2647, in _maybe_build
    input_spec.assert_input_compatibility(
  File "/usr/lib/python3.8/site-packages/tensorflow/python/keras/engine/input_spec.py", line 204, in assert_input_compatibility
    raise ValueError('Layer ' + layer_name + ' expects ' +
ValueError: Layer dense expects 1 input(s), but it received 2 input tensors. Inputs received: [<tf.Tensor 'Placeholder:0' shape=(None,) dtype=float32>, <tf.Tensor 'Placeholder_1:0' shape=(None,) dtype=float32>].

How can I make my Dense layer accept the outputs of the two preceding Input layers? My code is below:

X1 = tf.constant([2, 3, 4, 5, 6, 7])
X2 = tf.constant([2, 3, 4, 5, 6, 7])
yTrain = tf.constant([4, 6, 8, 10, 12, 14])

input1 = keras.Input(shape=(X1.shape[1:]))
input2 = keras.Input(shape=(X2.shape[1:]))
x = layers.Dense(8, activation='relu')([input1, input2])
outputs = layers.Dense(2)(x)
mlp = keras.Model(input1, input2, outputs)

mlp.compile(loss='mean_squared_error',
            optimizer='adam', metrics=['accuracy'])

mlp.fit(X1, X2, yTrain, batch_size=1, epochs=10, validation_split=0.2,
        shuffle=True)

mlp.evaluate(X1, X2, yTrain)

1 answer:

Answer 0: (score: 0)

You have to use a concatenate layer to feed multiple inputs into a single Dense layer. I suggest you read this tutorial — Models with multiple inputs and outputs. Also note how the inputs are passed to mlp.fit and mlp.evaluate.

You need to modify your code as below —

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

X1 = tf.constant([2, 3, 4, 5, 6, 7])
X2 = tf.constant([2, 3, 4, 5, 6, 7])
yTrain = tf.constant([4, 6, 8, 10, 12, 14])

input1 = keras.Input(shape=(1,))
input2 = keras.Input(shape=(1,))

x = layers.concatenate([input1, input2])
x = layers.Dense(8, activation='relu')(x)
outputs = layers.Dense(2)(x)
mlp = keras.Model([input1, input2], outputs)

mlp.summary()

mlp.compile(loss='mean_squared_error',
            optimizer='adam', metrics=['accuracy'])

mlp.fit([X1, X2], yTrain, batch_size=1, epochs=10, validation_split=0.2,
        shuffle=True)

mlp.evaluate([X1, X2], yTrain)

Output —

Model: "functional_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_16 (InputLayer)           [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_17 (InputLayer)           [(None, 1)]          0                                            
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 2)            0           input_16[0][0]                   
                                                                 input_17[0][0]                   
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 8)            24          concatenate_3[0][0]              
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 2)            18          dense_8[0][0]                    
==================================================================================================
Total params: 42
Trainable params: 42
Non-trainable params: 0
__________________________________________________________________________________________________
Epoch 1/10
4/4 [==============================] - 0s 32ms/step - loss: 36.3828 - accuracy: 0.0000e+00 - val_loss: 113.8439 - val_accuracy: 0.0000e+00
Epoch 2/10
4/4 [==============================] - 0s 5ms/step - loss: 35.9418 - accuracy: 0.0000e+00 - val_loss: 112.5230 - val_accuracy: 0.0000e+00
Epoch 3/10
4/4 [==============================] - 0s 5ms/step - loss: 35.4931 - accuracy: 0.0000e+00 - val_loss: 111.2055 - val_accuracy: 0.0000e+00
Epoch 4/10
4/4 [==============================] - 0s 5ms/step - loss: 35.1113 - accuracy: 0.0000e+00 - val_loss: 109.8443 - val_accuracy: 0.0000e+00
Epoch 5/10
4/4 [==============================] - 0s 6ms/step - loss: 34.6126 - accuracy: 0.0000e+00 - val_loss: 108.5272 - val_accuracy: 0.0000e+00
Epoch 6/10
4/4 [==============================] - 0s 5ms/step - loss: 34.1876 - accuracy: 0.0000e+00 - val_loss: 107.2023 - val_accuracy: 0.0000e+00
Epoch 7/10
4/4 [==============================] - 0s 6ms/step - loss: 33.7342 - accuracy: 0.0000e+00 - val_loss: 105.8807 - val_accuracy: 0.0000e+00
Epoch 8/10
4/4 [==============================] - 0s 6ms/step - loss: 33.2854 - accuracy: 0.0000e+00 - val_loss: 104.5553 - val_accuracy: 0.0000e+00
Epoch 9/10
4/4 [==============================] - 0s 6ms/step - loss: 32.8268 - accuracy: 0.0000e+00 - val_loss: 103.2337 - val_accuracy: 0.0000e+00
Epoch 10/10
4/4 [==============================] - 0s 6ms/step - loss: 32.3614 - accuracy: 0.0000e+00 - val_loss: 101.9115 - val_accuracy: 0.0000e+00
1/1 [==============================] - 0s 1ms/step - loss: 55.3572 - accuracy: 0.0000e+00
[55.35719299316406, 0.0]
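As a side note, when the two inputs are just two feature columns of the same length (as here), an alternative to multiple Input layers plus Concatenate is to stack them into a single (6, 2) tensor before feeding the model. This is a minimal sketch of that idea, assuming TensorFlow 2.x; it is not part of the original answer:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

X1 = tf.constant([2, 3, 4, 5, 6, 7])
X2 = tf.constant([2, 3, 4, 5, 6, 7])
yTrain = tf.constant([4, 6, 8, 10, 12, 14])

# Stack the two feature vectors column-wise into one (6, 2) tensor,
# so the model needs only a single Input layer and no Concatenate.
X = tf.stack([tf.cast(X1, tf.float32), tf.cast(X2, tf.float32)], axis=1)

inputs = keras.Input(shape=(2,))
x = layers.Dense(8, activation='relu')(inputs)
outputs = layers.Dense(1)(x)  # one output unit for the scalar target
mlp = keras.Model(inputs, outputs)

mlp.compile(loss='mean_squared_error', optimizer='adam')
mlp.fit(X, yTrain, batch_size=1, epochs=10, verbose=0)
```

Which form to prefer depends on the data: separate Input layers are useful when the branches need different preprocessing before merging; stacking is simpler when the inputs are homogeneous features.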