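I am trying to use the TensorFlow Dataset API (`tf.data.Dataset`) with Keras in TensorFlow 2.0.0 to predict a multi-dimensional output from a multi-dimensional input.
I am using tensorflow-datasets 1.3.0 on TensorFlow 2.0.0 with Python 3.6.9.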
Below is my example code, which I have also reproduced in a runnable [Colab notebook](https://colab.research.google.com/drive/1WMccCeLOrQU4k5D2noC4S_5rMe7-krEk):
```python
import tensorflow as tf
data = [[1,2],[11,22]]
label = [[3,4,5], [33,44,55]]
dataset = tf.data.Dataset.from_tensor_slices((data,label))
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(3))
model.compile('adam','mse',metrics=['mse'])
model.fit(dataset, validation_data=dataset)
```
In this example, I am trying to predict `[1,2] -> [3,4,5]` and `[11,22] -> [33,44,55]`, but I get the following error:
```
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1609   try:
-> 1610     c_op = c_api.TF_FinishOperation(op_desc)
   1611   except errors.InvalidArgumentError as e:

InvalidArgumentError: Dimensions must be equal, but are 2 and 3 for 'loss/output_1_loss/SquaredDifference' (op: 'SquaredDifference') with input shapes: [2,3], [3,1].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
29 frames
/tensorflow-2.0.0/python3.6/tensorflow_core/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1611   except errors.InvalidArgumentError as e:
   1612     # Convert to ValueError for backwards compatibility.
-> 1613     raise ValueError(str(e))
   1614 
   1615   return c_op

ValueError: Dimensions must be equal, but are 2 and 3 for 'loss/output_1_loss/SquaredDifference' (op: 'SquaredDifference') with input shapes: [2,3], [3,1].
```
Answer (score: 0)
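Following thushv89's comment on the question, batching the dataset fixes the code. My original code was more complex than this example, but batching fixed it as well.
Without `.batch()`, each dataset element is a single example (input shape `(2,)`, label shape `(3,)`), and Keras treats the leading axis of each element as the batch axis. The input `[1,2]` is therefore read as 2 samples with 1 feature each, giving predictions of shape `[2,3]`, while the label `[3,4,5]` is read as 3 samples, giving shape `[3,1]`; these are exactly the mismatched shapes in the error. Calling `.batch(n)` adds the leading batch axis that `model.fit` expects.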
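The shapes in the error message can be reproduced with a small NumPy sketch. It assumes, as the error suggests, that the leading axis of each unbatched element is read as the batch axis; the helpers `dense` and `mse_shapes` and the all-ones weights are hypothetical and only the shapes matter. This mimics the shape arithmetic, not Keras's actual code path.

```python
import numpy as np

def dense(x, units=3):
    # Hypothetical Dense layer: all-ones weights, zero bias; only shapes matter here.
    w = np.ones((x.shape[-1], units), dtype=np.float32)
    b = np.zeros(units, dtype=np.float32)
    return x @ w + b

def mse_shapes(x, y):
    """Return (prediction shape, label shape, whether the squared difference is computable)."""
    pred = dense(x)
    try:
        np.square(pred - y)  # broadcasting fails if the batch axes disagree
        ok = True
    except ValueError:
        ok = False
    return pred.shape, y.shape, ok

# Unbatched: the leading axis of one element is (assumed to be) read as the batch axis.
x_elem = np.array([1, 2], dtype=np.float32).reshape(-1, 1)     # (2, 1): 2 "samples", 1 feature
y_elem = np.array([3, 4, 5], dtype=np.float32).reshape(-1, 1)  # (3, 1): 3 "samples"
print(mse_shapes(x_elem, y_elem))    # ((2, 3), (3, 1), False)

# Batched: both examples stacked along a new leading batch axis, as .batch(2) does.
x_batch = np.array([[1, 2], [11, 22]], dtype=np.float32)          # (2, 2)
y_batch = np.array([[3, 4, 5], [33, 44, 55]], dtype=np.float32)   # (2, 3)
print(mse_shapes(x_batch, y_batch))  # ((2, 3), (2, 3), True)
```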
```python
import tensorflow as tf
data = [[1,2],[11,22]]
label = [[3,4,5], [33,44,55]]
# .batch(2) adds a leading batch axis: inputs become (2, 2), labels (2, 3)
dataset = tf.data.Dataset.from_tensor_slices((data,label)).batch(2)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(3))
model.compile('adam','mse',metrics=['mse'])
model.fit(dataset, validation_data=dataset)
```
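To check what `.batch()` changes, here is a sketch that prints the shapes of the first element of each dataset. It assumes TensorFlow 2.x running eagerly; `first_element_shapes` is a hypothetical helper, not a TensorFlow API.

```python
import tensorflow as tf

data = [[1, 2], [11, 22]]
label = [[3, 4, 5], [33, 44, 55]]

unbatched = tf.data.Dataset.from_tensor_slices((data, label))
batched = unbatched.batch(2)      # one batch holding both examples
batched_one = unbatched.batch(1)  # two batches of one example each

def first_element_shapes(ds):
    # Hypothetical helper: take one element and return its (input, label) shapes.
    for x, y in ds.take(1):
        return tuple(x.shape.as_list()), tuple(y.shape.as_list())

print(first_element_shapes(unbatched))    # ((2,), (3,))
print(first_element_shapes(batched))      # ((2, 2), (2, 3))
print(first_element_shapes(batched_one))  # ((1, 2), (1, 3))
```

Any batch size should work with the model above, since each element then carries the leading batch axis; `.batch(2)` simply puts both examples in a single batch.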