我想通过一个时间序列让 LSTM 进行预测,但我犯了这个错误。
我的 X_train 和 y_train 形状
X_train_seasonal.shape
(893, 93)
y_train_seasonal.shape
(893,)
我的 LSTM
def getModel():
model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(100, activation='relu', input_shape=(X_train_seasonal.shape[0],X_train_seasonal[1])))
model.add(tf.keras.layers.Dense(units=1))
return model
我的模型编译:
model = getModel()
model.compile(
loss='mean_squared_error',
optimizer=tf.keras.optimizers.Adam(0.001)
)
我的历史:
history_seasonal = model.fit(
X_train_seasonal, y_train_seasonal,
epochs=100,
batch_size=32,
validation_split=0.1,
shuffle=False
)
我得到了那个错误:
<块引用>ValueError: 层序列 1 的输入 0 与层不兼容:预期 ndim=3,发现 ndim=2。收到完整形状:(无,93)
我是 LSTM 的新手,希望得到任何帮助。谢谢各位。
答案 0 :(得分:0)
我能够使用示例代码复制您的问题,如下所示
import tensorflow as tf
import numpy as np
inputs = tf.random.normal([10, 8])
simple_lstm = tf.keras.layers.LSTM(4)
output = simple_lstm(inputs)
print(output)
输出:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-1-2cb62bec2981> in <module>()
4 inputs = tf.random.normal([10, 8])
5 simple_lstm = tf.keras.layers.LSTM(4)
----> 6 output = simple_lstm(inputs)
7 print(output)
2 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/input_spec.py in assert_input_compatibility(input_spec, inputs, layer_name)
217 'expected ndim=' + str(spec.ndim) + ', found ndim=' +
218 str(ndim) + '. Full shape received: ' +
--> 219 str(tuple(shape)))
220 if spec.max_ndim is not None:
221 ndim = x.shape.rank
ValueError: Input 0 of layer lstm is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: (10, 8)
固定代码:
LSTM
期望输入一个 3D 张量,形状为 [batch, timesteps, feature]
。
import tensorflow as tf
import numpy as np
inputs = tf.random.normal([32, 10, 8])
simple_lstm = tf.keras.layers.LSTM(4)
output = simple_lstm(inputs)
print(output)
输出:
tf.Tensor(
[[ 0.06065299 0.3409816 0.32014534 0.06411142]
[ 0.00128129 -0.38577533 -0.11682745 0.10254318]
[ 0.20580113 -0.14564514 0.11878442 -0.10241538]
[-0.19621769 0.159462 -0.14459077 -0.06216513]
[-0.10081916 0.17638563 0.07971784 0.1896367 ]
[ 0.33309937 0.07162716 -0.08868891 -0.00883376]
[ 0.17272277 -0.34112597 0.099504 0.0996887 ]
[ 0.21682273 0.00900807 0.5081149 0.02028211]
[ 0.2525146 0.04386558 -0.09498325 0.10461893]
[ 0.21941815 -0.3566848 -0.05213086 0.18148176]
[ 0.22719224 -0.29461107 0.07673237 -0.1128229 ]
[ 0.00436124 -0.14181408 0.0085922 -0.49300092]
[-0.15231487 0.06897711 -0.30905092 0.06128961]
[ 0.25731358 -0.03430091 -0.2770667 0.14336488]
[-0.09124507 0.12587348 0.04689778 -0.12776485]
[ 0.15820538 -0.03368861 0.01726492 -0.02723333]
[ 0.04661836 -0.06789393 0.0413182 0.14790519]
[-0.04050795 0.18242228 -0.14569572 0.00631422]
[ 0.12048664 -0.01098521 -0.19472744 0.15155892]
[-0.20725198 0.1710444 -0.3829169 0.01446645]
[ 0.06099659 0.15198827 -0.18342684 0.10251417]
[ 0.01376235 -0.07642762 0.16639794 0.02458677]
[ 0.21958975 -0.08766301 -0.02507084 0.00333961]
[-0.15135197 -0.00785332 -0.02620712 -0.15565342]
[ 0.07218299 -0.0798007 0.01710635 -0.2993008 ]
[ 0.41292062 -0.2292722 -0.14371048 0.2036194 ]
[ 0.19662695 -0.10295419 -0.01334361 -0.3022645 ]
[-0.2587392 -0.10956616 0.10394819 -0.3500641 ]
[-0.0293685 -0.25388837 0.07153057 0.02657588]
[ 0.23911244 -0.3574759 0.06245361 -0.04481344]
[-0.32070398 0.03763141 0.03036258 -0.2610327 ]
[-0.13514674 -0.14885807 0.2496089 -0.12311874]], shape=(32, 4), dtype=float32)