My dataset is X. The shape of X is (423, 320, 3): there are 423 data items, and each has length 320. I am using the Python package spektral.
X.shape # (423,320,3)
The adjacency matrix is A. The shape of A is (423, 423).
A.shape # (423,423)
My labels are y. The shape of y is (320, 1).
y.shape # (320,1)
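Before building the model it can help to check that the three arrays agree on the number of nodes. In Spektral's single mode (one graph), the first dimension of X, both dimensions of A and, assuming one target per node, the first dimension of y should all equal the number of nodes N; Keras treats that first axis as the sample axis and requires inputs and targets to have the same number of samples. A minimal NumPy sketch, using random placeholder arrays with the shapes above and a hypothetical helper `check_single_mode_shapes` (not part of Spektral or Keras):

```python
import numpy as np

def check_single_mode_shapes(X, A, y):
    """Return a list of problems; an empty list means the shapes agree.

    Assumes single mode: one graph with N nodes, X of shape (N, F),
    A of shape (N, N) and one target row per node in y.
    """
    problems = []
    n = A.shape[0]
    if A.ndim != 2 or A.shape[1] != n:
        problems.append(f"A should be square (N, N), got {A.shape}")
    if X.ndim != 2:
        problems.append(f"X should be 2-D (N, F), got {X.ndim}-D shape {X.shape}")
    if X.shape[0] != n:
        problems.append(f"X has {X.shape[0]} rows but A has {n} nodes")
    if y.shape[0] != n:
        problems.append(f"y has {y.shape[0]} rows but A has {n} nodes")
    return problems

# Placeholder arrays with the shapes described above (values are random).
X = np.random.rand(423, 320, 3)
A = np.random.rand(423, 423)
y = np.random.rand(320, 1)

for problem in check_single_mode_shapes(X, A, y):
    print(problem)
```

With these shapes the check reports two problems: X is 3-D, and y has 320 rows while A describes 423 nodes.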
My model is below. I think it is quite simple, but it doesn't work.
from tensorflow.keras.layers import Input, Dropout
from tensorflow.keras.models import Model
from spektral.layers import GraphConv

N = A.shape[0]   # number of nodes (423)
F = X.shape[-1]  # size of the last axis of X (3)
n_classes = 1

X_in = Input(shape=(423,320,))
A_in = Input((N, ), sparse=True)
X_1 = GraphConv(16, 'relu')([X_in, A_in])
X_1 = Dropout(0.5)(X_1)
X_2 = GraphConv(n_classes, 'relu')([X_1, A_in])
model = Model(inputs=[X_in, A_in], outputs=X_2)

A = GraphConv.preprocess(A).astype('f4')
model.compile(optimizer='adam',
              loss='mean_squared_error',
              weighted_metrics=['accuracy'])
model.summary()
model.fit([X, A], y)
The model summary is below:
Model: "model_32"
__________________________________________________________________________________________________
Layer (type)                    Output Shape          Param #    Connected to
==================================================================================================
input_63 (InputLayer)           [(None, 423, 320)]    0
__________________________________________________________________________________________________
input_64 (InputLayer)           [(None, None)]        0
__________________________________________________________________________________________________
graph_conv_52 (GraphConv)       (None, 423, 16)       5136       input_63[0][0]
                                                                 input_64[0][0]
__________________________________________________________________________________________________
dropout_25 (Dropout)            (None, 423, 16)       0          graph_conv_52[0][0]
__________________________________________________________________________________________________
graph_conv_53 (GraphConv)       (None, 423, 1)        17         dropout_25[0][0]
                                                                 input_64[0][0]
==================================================================================================
Total params: 5,153
Trainable params: 5,153
Non-trainable params: 0
__________________________________________________________________________________________________
The error is below:
ValueError Traceback (most recent call last)
<ipython-input-272-d00160881a92> in <module>
18 model.summary()
19
---> 20 model.fit([X, A], y)
~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
817 max_queue_size=max_queue_size,
818 workers=workers,
--> 819 use_multiprocessing=use_multiprocessing)
820
821 def evaluate(self,
~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
233 max_queue_size=max_queue_size,
234 workers=workers,
--> 235 use_multiprocessing=use_multiprocessing)
236
237 total_samples = _get_total_number_of_samples(training_data_adapter)
~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_training_inputs(model, x, y, batch_size, epochs, sample_weights, class_weights, steps_per_epoch, validation_split, validation_data, validation_steps, shuffle, distribution_strategy, max_queue_size, workers, use_multiprocessing)
591 max_queue_size=max_queue_size,
592 workers=workers,
--> 593 use_multiprocessing=use_multiprocessing)
594 val_adapter = None
595 if validation_data:
~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, mode, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
644 standardize_function = None
645 x, y, sample_weights = standardize(
--> 646 x, y, sample_weight=sample_weights)
647 elif adapter_cls is data_adapter.ListsOfScalarsDataAdapter:
648 standardize_function = standardize
~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
2381 is_dataset=is_dataset,
2382 class_weight=class_weight,
-> 2383 batch_size=batch_size)
2384
2385 def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,
~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, is_dataset, class_weight, batch_size)
2408 feed_input_shapes,
2409 check_batch_axis=False, # Don't enforce the batch size.
-> 2410 exception_prefix='input')
2411
2412 # Get typespecs for the input data and sanitize it if necessary.
~/anaconda3/envs/tensor2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
580 ': expected ' + names[i] + ' to have shape ' +
581 str(shape) + ' but got array with shape ' +
--> 582 str(data_shape))
583 return data
584
**ValueError: Error when checking input: expected input_63 to have shape (423, 320) but got array with shape (320, 3)**
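The shapes in this message follow from how Keras reads `Input`: the shape passed to `Input` excludes the leading batch (sample) axis, so `Input(shape=(423, 320))` expects arrays of shape (batch, 423, 320). Keras therefore treats X, of shape (423, 320, 3), as 423 samples of shape (320, 3), and compares (320, 3) with (423, 320). The sketch below is a simplified mimic of that comparison; `per_sample_shape` and `matches_input` are hypothetical helpers, not Keras functions (the real check is `standardize_input_data`, visible in the traceback).

```python
import numpy as np

def per_sample_shape(data):
    # Keras treats axis 0 as the batch axis, so each sample has shape data.shape[1:].
    return data.shape[1:]

def matches_input(data, input_shape):
    # Simplified mimic of the Keras input check: compare the per-sample shape
    # with the declared Input shape. (Keras also accepts None as a wildcard
    # dimension; that case is not handled here.)
    return per_sample_shape(data) == tuple(input_shape)

X = np.zeros((423, 320, 3))
print(per_sample_shape(X))            # (320, 3)
print(matches_input(X, (423, 320)))   # False

# With one row per node, X of shape (N, F) matches an Input declared as (F,).
X_flat = X.reshape(423, 320 * 3)
print(matches_input(X_flat, (960,)))  # True
```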
Answer
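Not sure if you still need help, but the problem is with your inputs.
Your node features `X` have shape (423, 320, 3), but your data represents only a single graph with 423 nodes. Spektral does not support multi-dimensional node attributes, so you should reshape `X` to (423, 320 * 3):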
X = X.reshape(423, 320 * 3)
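One way to see why single mode wants X as an (N, F) matrix: a GCN layer such as `GraphConv` computes, roughly, act(Â X W + b), where Â is the normalized adjacency matrix (what `GraphConv.preprocess` is used for above), and the product Â X only lines up when X is 2-D with one row per node. Below is a minimal NumPy sketch of that propagation rule, assuming the symmetric normalization Â = D^(-1/2) (A + I) D^(-1/2) from the original GCN formulation; it illustrates the idea and is not Spektral's implementation. The graph and weights are random placeholders, and `normalize_adjacency` and `gcn_layer` are hypothetical names. After the reshape F = 320 * 3 = 960, so the Keras input would presumably be declared as `Input(shape=(F,))` instead of `Input(shape=(423, 320))`.

```python
import numpy as np

rng = np.random.default_rng(0)

N, T, C = 423, 320, 3                  # nodes, series length, channels
X = rng.random((N, T, C))              # placeholder node data
A = (rng.random((N, N)) < 0.01).astype(float)
A = np.maximum(A, A.T)                 # make the placeholder graph undirected

def normalize_adjacency(A):
    """Symmetric GCN normalization: D^-1/2 (A + I) D^-1/2."""
    A_hat = A + np.eye(A.shape[0])     # self-loops keep every degree >= 1
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    return A_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(A_norm, X, W, b):
    """One GCN propagation step with ReLU: relu(A_norm @ X @ W + b)."""
    return np.maximum(A_norm @ X @ W + b, 0.0)

# Flatten each node's (320, 3) block into one feature vector of length 960.
X_flat = X.reshape(N, T * C)
F = X_flat.shape[1]

A_norm = normalize_adjacency(A)
W = rng.standard_normal((F, 16)) * 0.01   # placeholder weights, 16 channels
b = np.zeros(16)

H = gcn_layer(A_norm, X_flat, W, b)
print(X_flat.shape, H.shape)              # (423, 960) (423, 16)
```

The output has one 16-dimensional row per node, which is the (N, 16) shape a 16-channel first layer should produce.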
Also, since you are using `model.fit()`, you should set the batch size to `N`, or do something like:
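epochs = 200  # placeholder: number of training epochs
for epoch in range(epochs):
    # The whole graph (all N nodes) is passed as one batch at every step
    model.train_on_batch([X, A], y)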
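The reason for full-batch training is that each graph convolution multiplies by the full N x N adjacency matrix, so the feature matrix it is combined with must contain all N node rows, in their original order. By default `model.fit` uses a batch size of 32, so Keras would split the node rows into 32-row chunks and the product no longer lines up. A NumPy sketch of that mismatch, with placeholder arrays and hypothetical helpers `propagate` and `try_propagate`:

```python
import numpy as np

N, F = 423, 960
A_norm = np.eye(N)          # placeholder normalized adjacency, shape (N, N)
X = np.ones((N, F))         # placeholder node features, shape (N, F)

def propagate(A_norm, X_rows):
    # A graph convolution aggregates over all N nodes, so X_rows needs N rows.
    return A_norm @ X_rows

def try_propagate(A_norm, X_rows):
    # Return the result shape, or a short message if the shapes do not line up.
    try:
        return propagate(A_norm, X_rows).shape
    except ValueError as err:
        return f"shape mismatch: {err}"

print(try_propagate(A_norm, X))        # (423, 960): the full graph works
print(try_propagate(A_norm, X[:32]))   # a 32-row mini-batch fails
```

`train_on_batch` treats the arrays it receives as a single batch, which is why the loop avoids the problem. The `fit` alternative would be along the lines of `model.fit([X, A], y, batch_size=N, shuffle=False)`; passing `shuffle=False` is an addition here, meant to keep the rows of X aligned with A.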