I can use the LSTM layer on its own like this:
import tensorflow as tf
import numpy as np
lstm = tf.keras.layers.LSTM(10)
input_ = tf.cast(np.arange(24).reshape(-1, 24, 1), tf.float64)
lstm(input_)
The output has shape [1, 10], which is the value of 'h' in the LSTM. What I actually want is the value of 'c', so I tried using LSTMCell as follows:
lstmcell = tf.keras.layers.LSTMCell(10)
input_ = tf.cast(np.arange(24).reshape(-1, 24, 1), tf.float64)
lstmcell(input_)
I get the following error:
Traceback (most recent call last):
File "<ipython-input-11-009171321caf>", line 3, in <module>
lstmcell(input_)
File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 891, in __call__
outputs = self.call(cast_inputs, *args, **kwargs)
TypeError: call() missing 1 required positional argument: 'states'
I then changed the code to pass a states argument:
lstmcell = tf.keras.layers.LSTMCell(10)
input_ = tf.cast(np.arange(24).reshape(-1, 24, 1), tf.float64)
states = tf.cast(np.arange(20).reshape(2, 10), tf.float64)
lstmcell(input_, states)
Now I get a different error:
Traceback (most recent call last):
File "<ipython-input-23-2cc50f67a804>", line 1, in <module>
lstmcell(input_, states)
File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 891, in __call__
outputs = self.call(cast_inputs, *args, **kwargs)
File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\keras\layers\recurrent.py", line 2262, in call
z += K.dot(h_tm1, self.recurrent_kernel)
File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\keras\backend.py", line 1703, in dot
out = math_ops.matmul(x, y)
File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\util\dispatch.py", line 180, in wrapper
return target(*args, **kwargs)
File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\ops\math_ops.py", line 2765, in matmul
a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\ops\gen_math_ops.py", line 6125, in mat_mul
_six.raise_from(_core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
InvalidArgumentError: cannot compute MatMul as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:MatMul] name: lstm_cell_1/MatMul/
So how should I use LSTMCell on its own to get the values of both 'h' and 'c'?
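
For reference, here is a minimal sketch of one way this might be done. It rests on a few assumptions that are not confirmed above: that the layer weights are float32 by default (which would explain the MatMul dtype complaint when float64 states are passed), that an LSTMCell handles a single timestep with input of shape (batch, features), and that states should be a list [h, c] with each entry of shape (batch, units) rather than one (2, 10) tensor. The names `sequence`, `cell`, `final_h` and `final_c` are illustrative only.

```python
import numpy as np
import tensorflow as tf

batch, timesteps, features, units = 1, 24, 1, 10

# Use float32 so the inputs and states match the default dtype of the layer weights.
sequence = tf.cast(np.arange(24).reshape(batch, timesteps, features), tf.float32)

cell = tf.keras.layers.LSTMCell(units)

# Initial states: a list [h, c], each of shape (batch, units).
h = tf.zeros((batch, units), dtype=tf.float32)
c = tf.zeros((batch, units), dtype=tf.float32)

# The cell is applied one timestep at a time; each call takes a
# (batch, features) slice and returns the output plus the new [h, c].
for t in range(timesteps):
    output, (h, c) = cell(sequence[:, t, :], [h, c])

# Alternative: let the full LSTM layer iterate over the sequence and
# return the final states alongside the last output.
lstm = tf.keras.layers.LSTM(units, return_state=True)
last_output, final_h, final_c = lstm(sequence)
```

After the loop, `h` and `c` hold the hidden and cell states for the last timestep. If only the final states are needed, the `return_state=True` variant is probably the simpler choice, since the layer then handles the loop over timesteps and the initial zero states itself.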