How do I use 'LSTMCell' on its own in TensorFlow 2.0?

Time: 2019-10-10 10:46:56

Tags: lstm tensorflow2.0

I can use 'LSTM' on its own as follows:

import tensorflow as tf
import numpy as np
lstm = tf.keras.layers.LSTM(10)
input_ = tf.cast(np.arange(24).reshape(-1, 24, 1), tf.float64)
output_ = lstm(input_)  # final hidden state h, shape (batch, units)

output_.shape is [1, 10], which is the value of 'h' in the 'LSTM'. What I actually want is the value of 'c' in the 'LSTM', so I tried 'LSTMCell' as follows:

lstmcell = tf.keras.layers.LSTMCell(10)
input_ = tf.cast(np.arange(24).reshape(-1, 24, 1), tf.float64)
lstmcell(input_)

I got the following error:

Traceback (most recent call last):

  File "<ipython-input-11-009171321caf>", line 3, in <module>
    lstmcell(input_)

  File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 891, in __call__
    outputs = self.call(cast_inputs, *args, **kwargs)

TypeError: call() missing 1 required positional argument: 'states'
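For context, this first error happens because an 'LSTMCell' computes only a single timestep. Its call needs the current input of shape (batch, features) together with the previous state, given as a list [h, c] in which each tensor has shape (batch, units). It returns (output, [h, c]). The sketch below is a minimal example of one step. It assumes a zero initial state, which is also what the 'LSTM' layer uses by default, and float32 data so that the input matches the cell's weights. units = 10 mirrors the question.

```python
import numpy as np
import tensorflow as tf

units = 10
cell = tf.keras.layers.LSTMCell(units)

# One timestep for a batch of 1: shape (batch, features) = (1, 1), float32 like the cell's weights
x_t = tf.constant([[0.0]], dtype=tf.float32)

# Assumed zero initial state: a list [h, c], each of shape (batch, units)
h0 = tf.zeros((1, units), dtype=tf.float32)
c0 = tf.zeros((1, units), dtype=tf.float32)

# The cell returns (output, [h, c]); for an LSTMCell the output is the new h
output, (h1, c1) = cell(x_t, [h0, c0])
print(h1.shape, c1.shape)  # each is (batch, units)
print(np.allclose(output.numpy(), h1.numpy()))  # True
```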

Then I changed the code as follows:

lstmcell = tf.keras.layers.LSTMCell(10)
input_ = tf.cast(np.arange(24).reshape(-1, 24, 1), tf.float64)
states = tf.cast(np.arange(20).reshape(2, 10), tf.float64)
lstmcell(input_, states)

Then I got this error:

Traceback (most recent call last):

  File "<ipython-input-23-2cc50f67a804>", line 1, in <module>
    lstmcell(input_, states)

  File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 891, in __call__
    outputs = self.call(cast_inputs, *args, **kwargs)

  File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\keras\layers\recurrent.py", line 2262, in call
    z += K.dot(h_tm1, self.recurrent_kernel)

  File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\keras\backend.py", line 1703, in dot
    out = math_ops.matmul(x, y)

  File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\util\dispatch.py", line 180, in wrapper
    return target(*args, **kwargs)

  File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\ops\math_ops.py", line 2765, in matmul
    a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)

  File "D:\Program\anaconda3\lib\site-packages\tensorflow_core\python\ops\gen_math_ops.py", line 6125, in mat_mul
    _six.raise_from(_core._status_to_exception(e.code, message), None)

  File "<string>", line 3, in raise_from

InvalidArgumentError: cannot compute MatMul as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:MatMul] name: lstm_cell_1/MatMul/
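Reading the tracebacks, the second attempt appears to have three separate problems:

1. 'states' is one (2, 10) tensor, but the cell expects a list [h, c] of two (batch, units) tensors.
2. The input is the whole (1, 24, 1) sequence, but the cell expects one timestep of shape (batch, features).
3. The 'cast_inputs' line in the traceback suggests that Keras casts only the first argument to the layer's float32 dtype. The float64 'states' would therefore reach the float32 'recurrent_kernel' uncast, which matches the MatMul dtype error.

The sketch below avoids all three problems by stepping the cell over the sequence by hand and keeping the last h and c. It assumes a zero initial state and casts the question's data to float32.

```python
import numpy as np
import tensorflow as tf

units = 10
cell = tf.keras.layers.LSTMCell(units)

# Same values as in the question, but float32 so they match the cell's weights
sequence = tf.constant(np.arange(24).reshape(1, 24, 1), dtype=tf.float32)

# Assumed zero initial state: two separate (batch, units) tensors
h = tf.zeros((1, units), dtype=tf.float32)
c = tf.zeros((1, units), dtype=tf.float32)

# Feed one timestep at a time; each slice has shape (batch, features) = (1, 1)
for t in range(24):
    _, (h, c) = cell(sequence[:, t, :], [h, c])

print(h.shape, c.shape)  # final h and c, each (batch, units)
```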

So how should I use 'LSTMCell' on its own to get the values of 'h' and 'c'?
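If the goal is only the final h and c, stepping a cell by hand may not be needed, because Keras recurrent layers can return their states directly:

- tf.keras.layers.LSTM(..., return_state=True) returns [output, h, c].
- Wrapping an 'LSTMCell' in tf.keras.layers.RNN(..., return_state=True) returns the same three values, and the layer handles the loop over timesteps.

The sketch below shows both options. It assumes float32 input built from the question's data.

```python
import numpy as np
import tensorflow as tf

input_ = tf.constant(np.arange(24).reshape(-1, 24, 1), dtype=tf.float32)

# Option 1: ask the LSTM layer to also return its final states
lstm = tf.keras.layers.LSTM(10, return_state=True)
output, h, c = lstm(input_)  # output is the last h, so it equals h here

# Option 2: wrap an LSTMCell in the generic RNN layer
rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(10), return_state=True)
output2, h2, c2 = rnn(input_)

print(h.shape, c.shape, h2.shape, c2.shape)
```

The two layers are created independently, so they have different random weights. Only the shapes of their results match, not the values.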

0 answers:

No answers yet