TensorFlow invalid shape (InvalidArgumentError)

Time: 2019-05-06 07:13:43

Tags: python-3.x tensorflow tensor tensorflow2.0

model.fit raises an exception:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot update variable with shape [] using a Tensor with shape [32], shapes must be equal.
         [[{{node metrics/accuracy/AssignAddVariableOp}}]]
         [[loss/dense_loss/categorical_crossentropy/weighted_loss/broadcast_weights/assert_broadcastable/AssertGuard/pivot_f/_50/_63]] [Op:__inference_keras_scratch_graph_1408]

Model definition:

import tensorflow as tf

model = tf.keras.Sequential()

model.add(tf.keras.layers.InputLayer(
    input_shape=(360, 7)
))

# input_shape here is redundant: the InputLayer above already fixes it
model.add(tf.keras.layers.Conv1D(32, 1, activation='relu', input_shape=(360, 7)))
model.add(tf.keras.layers.Conv1D(32, 1, activation='relu'))
model.add(tf.keras.layers.MaxPooling1D(3))
model.add(tf.keras.layers.Conv1D(512, 1, activation='relu'))
model.add(tf.keras.layers.Conv1D(1048, 1, activation='relu'))
model.add(tf.keras.layers.GlobalAveragePooling1D())
model.add(tf.keras.layers.Dropout(0.5))
# Output: one softmax probability per class, shape (None, 32)
model.add(tf.keras.layers.Dense(32, activation='softmax'))

Input feature shape:

(105, 360, 7)

Input label shape:

(105, 32, 1)

Compile statement:

model.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])

model.fit call:

model.fit(features,
          labels,
          epochs=50000,
          validation_split=0.2,
          verbose=1)

Any help would be greatly appreciated.

2 answers:

Answer 0 (score: 1)

You can use model.summary() to view the model architecture:

print(model.summary())

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv1d (Conv1D)              (None, 360, 32)           256
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 360, 32)           1056
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 120, 32)           0
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 120, 512)          16896
_________________________________________________________________
conv1d_3 (Conv1D)            (None, 120, 1048)         537624
_________________________________________________________________
global_average_pooling1d (Gl (None, 1048)              0
_________________________________________________________________
dropout (Dropout)            (None, 1048)              0
_________________________________________________________________
dense (Dense)                (None, 32)                33568
=================================================================
Total params: 589,400
Trainable params: 589,400
Non-trainable params: 0
_________________________________________________________________
None

Your output layer has shape (None, 32), but your labels have shape (105, 32, 1). You therefore need to change the label shape to (105, 32). To remove single-dimensional entries from an array's shape, you can use a squeeze function such as NumPy's np.squeeze.
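The reshaping step the answer describes can be sketched with NumPy. This is a minimal sketch: the labels array below is a hypothetical stand-in built from random class IDs, assuming the real labels are one-hot encoded over 32 classes (the format CategoricalCrossentropy expects). Passing axis=-1 restricts the squeeze to the trailing size-1 axis, so no other axis is dropped by accident.

```python
import numpy as np

# Hypothetical stand-in for the question's labels: 105 samples, 32 classes,
# one-hot encoded, with an extra trailing axis of size 1 -> (105, 32, 1).
rng = np.random.default_rng(0)
class_ids = rng.integers(0, 32, size=105)
labels = np.eye(32)[class_ids][..., np.newaxis]
print(labels.shape)     # (105, 32, 1)

# Drop only the last axis so the labels line up with the model's
# (None, 32) output.
labels_2d = np.squeeze(labels, axis=-1)
print(labels_2d.shape)  # (105, 32)
```

The squeezed array would then be passed to model.fit in place of labels. Inside a TensorFlow pipeline, tf.squeeze(labels, axis=-1) does the same thing, and for this particular shape labels.reshape(-1, 32) is equivalent.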

Answer 1 (score: 0)

Use Flatten() before the Dense layer.
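Flatten() merges every non-batch axis of a layer's output into a single axis. The NumPy analogue below illustrates that behaviour; flatten_like_keras is a hypothetical helper written for illustration, not a Keras API. Note that Flatten acts on the model's activations, not on the labels, and in this model GlobalAveragePooling1D already outputs a 2-D (None, 1048) tensor, so flattening it leaves that shape unchanged.

```python
import numpy as np

def flatten_like_keras(x):
    """Hypothetical NumPy analogue of tf.keras.layers.Flatten:
    keep the batch axis, merge all remaining axes into one."""
    return x.reshape(x.shape[0], -1)

# 3-D activations, e.g. (batch, steps, channels) from the Conv1D stack
conv_out = np.zeros((4, 120, 1048))
print(flatten_like_keras(conv_out).shape)  # (4, 125760)

# 2-D activations, e.g. the (batch, 1048) output of GlobalAveragePooling1D
pooled = np.zeros((4, 1048))
print(flatten_like_keras(pooled).shape)    # (4, 1048), unchanged
```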