tensorflow python期望density_input具有2维,但数组的形状为(5,28,5)

时间:2018-11-19 01:02:04

标签: python tensorflow

我是tensorflow的完全新手,试图了解它并解决问题。我尝试了很多教程,但是他们都讨论了相同的分类图像或mnist内容,因此我遵循了文档并试图找出一些答案。

目标是找到一种模式来预测输入为[[1000,10,5,3,1744 ... etc。只有5种情况,值为300 400、500、600、700,形状为28.5,每个结果为28.2列表。数据从文件加载并分配给tf.tensor。

这是我的代码:

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(28, activation=tf.nn.relu, input_shape=(5,)))
model.add(tf.keras.layers.Dense(28, activation=tf.nn.relu, input_shape=(5,)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(28))

model.compile(optimizer='adam',
              loss='mean_squared_error',
              metrics=['accuracy'])

model.fit(newData, newResults, epochs=3, steps_per_epoch=5)

newData:

[[[300, 10, 5, 3, 1744], [300, 10, 5, 5, 2848], [300, 10, 5, 4, 2418], [300, 10, 5, 2, 1152], [300, 10, 5, 3, 1126], [300, 10, 5, 3, 1897], [300, 10, 5, 3, 1089], [300, 10, 5, 2, 1581], [300, 10, 5, 4, 1793], [300, 10, 5, 3, 1525], [300, 10, 5, 2, 1529], [300, 10, 5, 3, 1052], [300, 10, 5, 2, 1556], [300, 10, 5, 3, 1569], [300, 10, 5, 5, 2873], [300, 10, 5, 4, 2269], [300, 10, 5, 3, 3003], [300, 10, 5, 3, 1310], [300, 10, 5, 3, 1464], [300, 10, 5, 3, 2807], [300, 10, 5, 2, 1262], [300, 10, 5, 3, 1734], [300, 10, 5, 2, 2709], [300, 10, 5, 3, 2234], [300, 10, 5, 3, 1961], [300, 10, 5, 2, 1594], [300, 10, 5, 2, 1836], [300, 10, 5, 2, 1345]], 
[[400, 10, 5, 3, 1744], [400, 10, 5, 5, 2848], [400, 10, 5, 4, 2418], [400, 10, 5, 2, 1152], [400, 10, 5, 3, 1126], [400, 10, 5, 3, 1897], [400, 10, 5, 3, 1089], [400, 10, 5, 2, 1581], [400, 10, 5, 4, 1793], [400, 10, 5, 3, 1525], [400, 10, 5, 2, 1529], [400, 10, 5, 3, 1052], [400, 10, 5, 2, 1556], [400, 10, 5, 3, 1569], [400, 10, 5, 5, 2873], [400, 10, 5, 4, 2269], [400, 10, 5, 3, 3003], [400, 10, 5, 3, 1310], [400, 10, 5, 3, 1464], [400, 10, 5, 3, 2807], [400, 10, 5, 2, 1262], [400, 10, 5, 3, 1734], [400, 10, 5, 2, 2709], [400, 10, 5, 3, 2234], [400, 10, 5, 3, 1961], [400, 10, 5, 2, 1594], [400, 10, 5, 2, 1836], [400, 10, 5, 2, 1345]], 
[[500, 10, 5, 3, 1744], [500, 10, 5, 5, 2848], [500, 10, 5, 4, 2418], [500, 10, 5, 2, 1152], [500, 10, 5, 3, 1126], [500, 10, 5, 3, 1897], [500, 10, 5, 3, 1089], [500, 10, 5, 2, 1581], [500, 10, 5, 4, 1793], [500, 10, 5, 3, 1525], [500, 10, 5, 2, 1529], [500, 10, 5, 3, 1052], [500, 10, 5, 2, 1556], [500, 10, 5, 3, 1569], [500, 10, 5, 5, 2873], [500, 10, 5, 4, 2269], [500, 10, 5, 3, 3003], [500, 10, 5, 3, 1310], [500, 10, 5, 3, 1464], [500, 10, 5, 3, 2807], [500, 10, 5, 2, 1262], [500, 10, 5, 3, 1734], [500, 10, 5, 2, 2709], [500, 10, 5, 3, 2234], [500, 10, 5, 3, 1961], [500, 10, 5, 2, 1594], [500, 10, 5, 2, 1836], [500, 10, 5, 2, 1345]], 
[[600, 10, 5, 3, 1744], [600, 10, 5, 5, 2848], [600, 10, 5, 4, 2418], [600, 10, 5, 2, 1152], [600, 10, 5, 3, 1126], [600, 10, 5, 3, 1897], [600, 10, 5, 3, 1089], [600, 10, 5, 2, 1581], [600, 10, 5, 4, 1793], [600, 10, 5, 3, 1525], [600, 10, 5, 2, 1529], [600, 10, 5, 3, 1052], [600, 10, 5, 2, 1556], [600, 10, 5, 3, 1569], [600, 10, 5, 5, 2873], [600, 10, 5, 4, 2269], [600, 10, 5, 3, 3003], [600, 10, 5, 3, 1310], [600, 10, 5, 3, 1464], [600, 10, 5, 3, 2807], [600, 10, 5, 2, 1262], [600, 10, 5, 3, 1734], [600, 10, 5, 2, 2709], [600, 10, 5, 3, 2234], [600, 10, 5, 3, 1961], [600, 10, 5, 2, 1594], [600, 10, 5, 2, 1836], [600, 10, 5, 2, 1345]], 
[[700, 10, 5, 3, 1744], [700, 10, 5, 5, 2848], [700, 10, 5, 4, 2418], [700, 10, 5, 2, 1152], [700, 10, 5, 3, 1126], [700, 10, 5, 3, 1897], [700, 10, 5, 3, 1089], [700, 10, 5, 2, 1581], [700, 10, 5, 4, 1793], [700, 10, 5, 3, 1525], [700, 10, 5, 2, 1529], [700, 10, 5, 3, 1052], [700, 10, 5, 2, 1556], [700, 10, 5, 3, 1569], [700, 10, 5, 5, 2873], [700, 10, 5, 4, 2269], [700, 10, 5, 3, 3003], [700, 10, 5, 3, 1310], [700, 10, 5, 3, 1464], [700, 10, 5, 3, 2807], [700, 10, 5, 2, 1262], [700, 10, 5, 3, 1734], [700, 10, 5, 2, 2709], [700, 10, 5, 3, 2234], [700, 10, 5, 3, 1961], [700, 10, 5, 2, 1594], [700, 10, 5, 2, 1836], [700, 10, 5, 2, 1345]]]

newResult:

[[[29.0, 8.92], [52.0, 21.67], [41.0, 14.38], [7.0, 1.49], [26.0, 8.25], [18.0, 4.53], [24.0, 6.61], [21.0, 9.54], [17.0, 5.53], [27.0, 9.61], [11.0, 0.35], [22.0, 8.11], [7.0, 1.22], [36.0, 15.49], [57.0, 31.44], [43.0, 16.52], [34.0, 11.46], [15.0, 2.49], [20.0, 2.34], [16.0, 4.86], [10.0, 0.8], [8.0, 0.4], [1.0, 0.0], [30.0, 7.57], [24.0, 7.21], [5.0, 0.58], [14.0, 0.73], [4.0, 0.15]], 
[[45.0, 8.17], [100.0, 43.28], [54.0, 16.05], [10.0, 2.77], [37.0, 8.86], [27.0, 6.12], [33.0, 9.13], [34.0, 14.03], [20.0, 5.06], [45.0, 15.42], [21.0, 0.69], [26.0, 8.83], [11.0, 2.14], [44.0, 17.74], [73.0, 43.39], [43.0, 18.8], [46.0, 21.56], [29.0, 9.16], [21.0, 3.76], [20.0, 7.39], [16.0, 2.54], [1.0, 1.63], [1.0, 0.02], [28.0, 12.14], [30.0, 12.35], [7.0, 1.18], [19.0, 3.29], [4.0, 0.16]], 
[[59.0, 18.74], [100.0, 75.18], [69.0, 32.13], [11.0, 3.04], [49.0, 15.76], [30.0, 10.33], [45.0, 14.51], [43.0, 20.82], [37.0, 8.2], [69.0, 24.53], [1.0, 0.3], [38.0, 12.57], [1.0, 3.67], [65.0, 24.77], [91.0, 57.39], [53.0, 18.22], [47.0, 27.07], [34.0, 16.31], [25.0, 5.39], [31.0, 11.5], [23.0, 5.73], [19.0, 4.11], [2.0, 0.11], [35.0, 15.52], [41.0, 18.15], [7.0, 1.48], [25.0, 7.53], [3.0, 0.14]], 
[[80.0, 30.29], [100.0, 85.22], [94.0, 52.73], [11.0, 2.45], [72.0, 30.7], [46.0, 14.75], [70.0, 22.81], [50.0, 28.26], [40.0, 14.19], [60.0, 26.82], [14.0, 0.28], [45.0, 19.1], [16.0, 4.72], [82.0, 40.98], [100.0, 78.96], [66.0, 27.05], [67.0, 31.09], [34.0, 16.92], [23.0, 7.03], [48.0, 21.28], [27.0, 8.19], [21.0, 3.95], [2.0, 0.17], [43.0, 19.96], [55.0, 23.54], [8.0, 1.47], [28.0, 12.04], [4.0, 0.13]], 
[[95.0, 38.09], [100.0, 92.88], [99.0, 58.96], [13.0, 3.54], [96.0, 45.78], [33.0, 12.05], [87.0, 38.11], [62.0, 34.97], [48.0, 15.49], [84.0, 33.13], [10.0, 0.09], [63.0, 25.52], [16.0, 4.87], [100.0, 55.9], [100.0, 91.32], [90.0, 34.24], [96.0, 45.36], [37.0, 15.13], [27.0, 9.28], [49.0, 26.3], [30.0, 10.92], [22.0, 3.72], [3.0, 0.14], [67.0, 24.82], [73.0, 31.32], [8.0, 1.36], [31.0, 15.03], [4.0, 0.2]]]

运行时出现此错误:

  File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1536, in fit
    validation_split=validation_split)
  File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\keras\engine\training.py", line 992, in _standardize_user_data
    class_weight, batch_size)
  File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1117, in _standardize_weights
    exception_prefix='input')
  File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 323, in standardize_input_data
    'with shape ' + str(data_shape))
ValueError: Error when checking input: expected dense_input to have 2 dimensions, but got array with shape (5, 28, 5)

我知道我的模型肯定有问题,但是我不太清楚是什么原因。除了上述示例外,我很难找到其他信息。

1 个答案:

答案 0 :(得分:0)

您编写的代码从28个神经元的密集层开始,它们期望形状为(*, 5)。错误消息显示它正在接收的形状是((5,28,5)。

查看您的数据,它的形状为(5, 28, 5)。也许您还有另外一组括号?