Input shape for initial_state of tf.keras.layers.LSTM

Time: 2019-10-03 13:31:38

Tags: tensorflow keras lstm tf.keras

I am trying to build a very basic, simple character-level RNN.

Suppose my dataset is embedded like this:

import numpy as np
import tensorflow as tf

# Illustrative data only; the elided values are left as written in the question.
# batch_1 = np.array([[1, 2, ...., 20], [21, .....,40], [41,....,60], [61,...., 80]])
# batch_2 = np.array([[...], [...], [...], [...]])

batch_size = 4
steps_number = 20
hidden_units = 100
keep_prob = 0.5  # passed as `dropout`, which is the fraction of units dropped
dim = tf.zeros([batch_size, hidden_units])  # zero initial state, shape (batch_size, hidden_units)
input_data = tf.keras.layers.Input(shape=(1, steps_number), batch_size=batch_size)
hidden_1, state_h, state_c = tf.keras.layers.LSTM(units=hidden_units, stateful=True, dropout=keep_prob, return_state=True)(input_data, initial_state=[dim, dim], training=True)
# The next line is where the ValueError is raised.
hidden_2 = tf.keras.layers.LSTM(units=hidden_units, stateful=True, dropout=keep_prob, return_state=False)(hidden_1, initial_state=[state_h, state_c], training=True)
hidden3 = tf.keras.layers.Dense(10, activation='relu')(hidden_1)
output = tf.keras.layers.Dense(1, activation='sigmoid')(hidden3)
model = tf.keras.models.Model(input_data, output)

Here I get this error at the hidden_2 layer: ValueError: Shape (100, 4) must have rank at least 3

The problem is that the output of the hidden_1 layer should have shape [batch_size, steps_number, hidden_units].
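The rank mismatch described above can be reproduced in isolation. The following is a minimal sketch, not the asker's model: it uses a dummy all-zeros input (an assumption made only for illustration) and compares the output shape of an LSTM layer with and without return_sequences=True. By default a Keras LSTM returns only the output of the last time step, which has rank 2, while a following LSTM layer needs a rank-3 input.

```python
import tensorflow as tf

batch_size, steps_number, hidden_units = 4, 20, 100

# Dummy input (illustrative only), laid out as (batch, timesteps, features).
x = tf.zeros([batch_size, steps_number, 1])

# Default: only the output of the last time step is returned (rank 2).
last_only = tf.keras.layers.LSTM(hidden_units)(x)

# return_sequences=True: one output per time step (rank 3).
full_seq = tf.keras.layers.LSTM(hidden_units, return_sequences=True)(x)

print(tuple(last_only.shape))  # (4, 100)
print(tuple(full_seq.shape))   # (4, 20, 100)
```

Only the rank-3 tensor has the [batch_size, steps_number, hidden_units] layout that a second recurrent layer can consume.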

1 answer:

Answer 0 (score: 0)

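This is a working solution, but I don't understand why I have to specify the input shape as a column array, shape=(steps_number, 1) instead of (1, steps_number). Note that the first LSTM layer now also sets return_sequences=True: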

import tensorflow as tf

batch_size = 4
steps_number = 20
hidden_units = 100
keep_prob = 0.5  # passed as `dropout`, which is the fraction of units dropped
dim = tf.zeros([batch_size, hidden_units])  # zero initial state, shape (batch_size, hidden_units)
# Input is now (timesteps, features) = (steps_number, 1) per sample.
input_data = tf.keras.layers.Input(shape=(steps_number, 1), batch_size=batch_size)
# return_sequences=True keeps one output per time step, so hidden_1 has rank 3.
hidden_1, state_h, state_c = tf.keras.layers.LSTM(units=hidden_units, stateful=True, dropout=keep_prob, return_state=True, return_sequences=True)(input_data, initial_state=[dim, dim], training=True)
print(hidden_1.get_shape().as_list())  # [4, 20, 100]
hidden_2 = tf.keras.layers.LSTM(units=hidden_units, stateful=True, dropout=keep_prob, return_state=False)(hidden_1, initial_state=[state_h, state_c], training=True)
# As in the question, the Dense layers read hidden_1, so hidden_2 is not connected to the model output.
hidden3 = tf.keras.layers.Dense(10, activation='relu')(hidden_1)
output = tf.keras.layers.Dense(1, activation='sigmoid')(hidden3)
model = tf.keras.models.Model(input_data, output)
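On the open question about the input shape: Keras recurrent layers read their input as (batch, timesteps, features), so the two shapes describe different data rather than the same data transposed. The sketch below uses NumPy only and assumes, purely for illustration, that batch_1 holds the integers 1 to 80 in four rows of 20 steps; it shows how the same batch looks in each layout.

```python
import numpy as np

batch_size, steps_number = 4, 20

# Assumption: the batch holds the integers 1..80, four rows of 20 steps each.
batch_1 = np.arange(1, batch_size * steps_number + 1).reshape(batch_size, steps_number)

# (batch, timesteps, features) = (4, 20, 1): 20 steps, one value per step.
as_sequence = batch_1.reshape(batch_size, steps_number, 1)

# (batch, timesteps, features) = (4, 1, 20): a single step carrying 20 features.
as_single_step = batch_1.reshape(batch_size, 1, steps_number)

print(as_sequence.shape)      # (4, 20, 1)
print(as_single_step.shape)   # (4, 1, 20)
print(as_sequence[0, :3, 0])  # [1 2 3]
```

With shape=(steps_number, 1) the LSTM steps through 20 characters one at a time, which is presumably what a character-level RNN wants. With shape=(1, steps_number) it would see each row as a single time step with 20 features, so there would be no sequence to recur over.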