I'm trying to train a neural network with a categorical cross-entropy loss, but I get
InvalidArgumentError: Expected size[0] in [0, 0], but got 1
[[node loss_4/dense_18_loss/categorical_crossentropy/softmax_cross_entropy_with_logits/Slice_1
(defined at C:...\keras\backend\tensorflow_backend.py:3009) ]] [Op:__inference_keras_scratch_graph_2929]
Function call stack:
keras_scratch_graph
I don't understand what it is telling me. Does it mean it expects a scalar value in the range 0 to 0?
Here is my code:
import tensorflow as tf
from keras import layers
from keras.models import Sequential

# Two hidden ReLU layers, then one softmax output per action
model = Sequential()
model.add(layers.Dense(32, activation='relu',
                       input_shape=self.obsSpace))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(self.actionSpace, activation='softmax'))
model.compile(loss="categorical_crossentropy",
              optimizer=tf.keras.optimizers.Adam())
model.train_on_batch(batchStates, batchQVals)
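To make the shapes flowing through this model explicit, here is a NumPy sketch of the same three-layer forward pass (Dense 32 ReLU, Dense 32 ReLU, Dense softmax) with random weights. It is not Keras itself; the helper `dense` and the sizes are assumptions for illustration (4 input features matches batchStates, while 2 actions is hypothetical because actionSpace is not given).

```python
import numpy as np

def dense(x, w, b, activation=None):
    """Hypothetical stand-in for a Dense layer: x @ w + b, then an activation."""
    z = x @ w + b
    if activation == "relu":
        return np.maximum(z, 0.0)
    if activation == "softmax":
        z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
        e = np.exp(z)
        return e / e.sum(axis=-1, keepdims=True)
    return z

rng = np.random.default_rng(0)
obs_dim, hidden, action_space = 4, 32, 2  # assumed sizes; action_space = 2 is hypothetical
batch = rng.normal(size=(22, obs_dim))     # random stand-in with the shape of batchStates

layers_params = [
    (rng.normal(scale=0.1, size=(obs_dim, hidden)), np.zeros(hidden), "relu"),
    (rng.normal(scale=0.1, size=(hidden, hidden)), np.zeros(hidden), "relu"),
    (rng.normal(scale=0.1, size=(hidden, action_space)), np.zeros(action_space), "softmax"),
]

out = batch
for w, b, act in layers_params:
    out = dense(out, w, b, act)

print(out.shape)  # (22, 2)
```

Each output row is a probability distribution over the actions, so the network produces an array of shape (batch_size, actionSpace), not one number per sample.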
batchStates is:
tf.Tensor(
[[-4.16258150e-02 2.25938538e-01 4.36936698e-02 -2.81657789e-01]
[-3.71070442e-02 4.20410887e-01 3.80605141e-02 -5.60246049e-01]
[-2.86988265e-02 2.24775998e-01 2.68555931e-02 -2.55819147e-01]
[-2.42033065e-02 4.19504426e-01 2.17392101e-02 -5.39911869e-01]
[-1.58132180e-02 2.24083740e-01 1.09409728e-02 -2.40459278e-01]
[-1.13315432e-02 4.19047703e-01 6.13178719e-03 -5.29671139e-01]
[-2.95058915e-03 2.23840031e-01 -4.46163558e-03 -2.35062400e-01]
[ 1.52621147e-03 2.87821088e-02 -9.16288359e-03 5.62098541e-02]
[ 2.10185365e-03 2.24034234e-01 -8.03868651e-03 -2.39349889e-01]
[ 6.58253833e-03 2.90280372e-02 -1.28256843e-02 5.07866067e-02]
[ 7.16309907e-03 2.24331524e-01 -1.18099522e-02 -2.45915177e-01]
[ 1.16497296e-02 4.19620142e-01 -1.67282557e-02 -5.42299721e-01]
[ 2.00421324e-02 2.24737244e-01 -2.75742501e-02 -2.54934152e-01]
[ 2.45368773e-02 4.20241828e-01 -3.26729332e-02 -5.56185350e-01]
[ 3.29417139e-02 6.15806894e-01 -4.37966402e-02 -8.58980519e-01]
[ 4.52578517e-02 8.11497186e-01 -6.09762505e-02 -1.16510658e+00]
[ 6.14877955e-02 6.17219641e-01 -8.42783820e-02 -8.92147760e-01]
[ 7.38321883e-02 8.13377481e-01 -1.02121337e-01 -1.21008870e+00]
[ 9.00997379e-02 1.00965894e+00 -1.26323111e-01 -1.53294851e+00]
[ 1.10292917e-01 1.20605640e+00 -1.56982082e-01 -1.86223761e+00]
[ 1.34414045e-01 1.01296538e+00 -1.94226834e-01 -1.62212597e+00]
[ 1.54673352e-01 8.20588296e-01 -2.26669353e-01 -1.39573053e+00]], shape=(22, 4), dtype=float64)
batchQVals is:
tf.Tensor(
[6.4799747 6.447029 6.40827 6.3626704 6.309024 6.2459106 6.1716595
6.0843053 5.9815354 5.86063 5.718388 5.551045 5.354171 5.122554
4.8500633 4.529486 4.1523366 3.7086313 3.186625 2.5725 1.85
1. ], shape=(22,), dtype=float32)
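One detail visible in these shapes: batchStates is 2-D with shape (22, 4), but batchQVals is 1-D with shape (22,), while the softmax layer predicts one probability per action, i.e. shape (22, actionSpace). Categorical cross-entropy, -sum_k y_true[k] * log(y_pred[k]), is defined for targets and predictions of the same (batch_size, num_classes) shape. The sketch below is a simplified NumPy version of that formula, not Keras's implementation; the function name, the 2-action size, and the one-hot targets built from hypothetical chosen actions are all assumptions for illustration.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Simplified per-sample categorical cross-entropy (hypothetical helper).

    Both arguments must have shape (batch_size, num_classes).
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if y_true.shape != y_pred.shape:
        raise ValueError(f"shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

batch_size, num_actions = 22, 2  # num_actions = 2 is hypothetical
y_pred = np.full((batch_size, num_actions), 1.0 / num_actions)  # uniform softmax output
q_vals = np.zeros(batch_size)  # placeholder with the same 1-D shape as batchQVals

print(q_vals.shape, y_pred.shape)  # (22,) (22, 2)

# A 1-D target vector does not line up with the 2-D prediction.
mismatch_caught = False
try:
    categorical_crossentropy(q_vals, y_pred)
except ValueError as err:
    mismatch_caught = True
    print(err)

# One-hot targets of shape (22, 2) do line up.
actions = np.zeros(batch_size, dtype=int)  # hypothetical chosen action per sample
y_true = np.eye(num_actions)[actions]
loss = categorical_crossentropy(y_true, y_pred)
print(loss.shape)  # (22,)
```

With a uniform prediction over two actions, every per-sample loss equals log(2), roughly 0.693.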