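I'm trying to convert some open-source code written with the TensorFlow API to the Keras API. I get the correct input and output shapes for every layer, but when I try to test the model itself, it complains. Please excuse the code; I simplified it a lot for brevity.
Here is the TensorFlow API version: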
```python
hidden_units_g = 100
num_generated_features = 1
seq_length = 30
batch_size = 28
latent_dim = 5

W_out_G = tf.get_variable(name='W_out_G', shape=[hidden_units_g, num_generated_features], initializer=W_out_G_initializer)
b_out_G = tf.get_variable(name='b_out_G', shape=num_generated_features, initializer=b_out_G_initializer)

# shape: (28, 30, 5)
inputs = tf.random_normal(shape=(batch_size, seq_length, latent_dim))

cell = LSTMCell(num_units=hidden_units_g,
                state_is_tuple=True,
                initializer=lstm_initializer,
                bias_start=bias_start,
                reuse=reuse)

# rnn_output shape: (28, 30, 100)
rnn_outputs, rnn_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float32,
    sequence_length=[seq_length]*batch_size,
    inputs=inputs)

# reshape shape: (840, 100)
rnn_outputs_2d = tf.reshape(rnn_outputs, [-1, hidden_units_g])
# shape: (840, 1)
logits_2d = tf.matmul(rnn_outputs_2d, W_out_G) + b_out_G
# shape: (840, 1)
output_2d = tf.nn.tanh(logits_2d)
# shape: (28, 30, 1)
output_3d = tf.reshape(output_2d, [-1, seq_length, num_generated_features])
```
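Shape-wise, the key point of this graph is that `tf.nn.dynamic_rnn` keeps the LSTM output at every timestep, and the reshape to `(840, 100)` (28 × 30 = 840) only folds batch and time together so that a single matmul applies `W_out_G` and `b_out_G` to every timestep. A small NumPy sketch with random stand-in arrays (not the real weights or LSTM outputs) checks that this flatten, matmul, reshape route matches applying the same projection to each timestep of the 3-D tensor directly:

```python
import numpy as np

batch_size, seq_length, hidden_units_g, num_generated_features = 28, 30, 100, 1
rng = np.random.default_rng(0)

# Random stand-ins for rnn_outputs, W_out_G and b_out_G (shape checking only)
rnn_outputs = rng.standard_normal((batch_size, seq_length, hidden_units_g))
W_out_G = rng.standard_normal((hidden_units_g, num_generated_features))
b_out_G = rng.standard_normal(num_generated_features)

# Route taken by the TF code: fold batch and time, one matmul, unfold again
rnn_outputs_2d = rnn_outputs.reshape(-1, hidden_units_g)               # (840, 100)
output_2d = np.tanh(rnn_outputs_2d @ W_out_G + b_out_G)                # (840, 1)
output_3d = output_2d.reshape(-1, seq_length, num_generated_features)  # (28, 30, 1)

# Same projection applied per timestep on the 3-D tensor (matmul broadcasts over batch)
per_step = np.tanh(rnn_outputs @ W_out_G + b_out_G)                    # (28, 30, 1)

print(rnn_outputs_2d.shape, output_3d.shape)  # (840, 100) (28, 30, 1)
print(np.allclose(output_3d, per_step))       # True
```

In other words, the TF code produces 30 outputs per sequence because it never discards the intermediate timesteps.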
Here is my attempt to recreate it in Keras:
```python
gen = tf.keras.models.Sequential([
    tf.keras.layers.LSTM(100, input_shape=(30, 5), return_sequences=False),
    #tf.keras.layers.Dense(1, activation='tanh'),
    tf.keras.layers.Dense(1, activation='tanh'),
    tf.keras.layers.Reshape(target_shape=(30, 1))
])
```
```
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
lstm (LSTM)                  (None, 100)               42400
_________________________________________________________________
dense (Dense)                (None, 1)                 101
_________________________________________________________________
reshape (Reshape)            (None, 30, 1)             0
=================================================================
Total params: 42,501
Trainable params: 42,501
Non-trainable params: 0
```
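For comparison with the TF graph: `tf.nn.dynamic_rnn` returns the output at all 30 timesteps, whereas `LSTM(..., return_sequences=False)` returns only the last one, which is why the summary shows `(None, 100)` for the LSTM. A minimal sketch, assuming TF 2.x and assuming the goal is only to mirror the TF graph's shapes, keeps every timestep and lets `Dense` act on the last axis, so no `Reshape` is needed. `gen_fixed` is an illustrative name; the custom initializers (`lstm_initializer`, `bias_start`, `W_out_G_initializer`) are not reproduced, so Keras' default LSTM initialization (for example `unit_forget_bias=True`) may differ from the original cell. The input shape is given via `tf.keras.Input` because newer Keras versions discourage `input_shape=` on layers.

```python
import numpy as np
import tensorflow as tf

gen_fixed = tf.keras.Sequential([
    tf.keras.Input(shape=(30, 5)),                     # (batch, seq_length, latent_dim)
    tf.keras.layers.LSTM(100, return_sequences=True),  # keep all 30 steps: (batch, 30, 100)
    tf.keras.layers.Dense(1, activation='tanh'),       # applied per timestep: (batch, 30, 1)
])

y = gen_fixed(np.random.randn(28, 30, 5).astype("float32"))
print(tuple(y.shape))            # (28, 30, 1)
print(gen_fixed.count_params())  # 42501
```

The parameter count stays at 42,501 because the same 100×1 kernel and bias are shared across all 30 timesteps, just as `W_out_G` and `b_out_G` are in the TF code.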
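I also used `tf.keras.utils.plot_model()` to confirm that my input/output shapes are exactly the same. However, when I run a mock test with `np.random.randn(28, 30, 5)`, I get an error saying it cannot reshape the Dense output: it expects 840 values but only sees 28.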
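The 840-versus-28 mismatch can be reproduced with plain NumPy arithmetic. This sketch assumes the Dense layer emits one value per sample, as its `(None, 1)` output shape in the summary indicates, and uses a zero array as a stand-in for that output:

```python
import numpy as np

batch_size, seq_length = 28, 30
dense_out = np.zeros((batch_size, 1))  # stand-in for the Dense(1) output on a batch of 28

print(dense_out.size)               # 28 values available
print(batch_size * seq_length * 1)  # 840 values needed for shape (28, 30, 1)

# Reshape(target_shape=(30, 1)) asks for (batch, 30, 1), which needs 840 values
try:
    dense_out.reshape(batch_size, seq_length, 1)
    reshape_failed = False
except ValueError as err:
    reshape_failed = True
    print("ValueError:", err)
```

`(None, 30, 1)` in the summary is only the shape `Reshape` declares; the mismatch shows up once a concrete batch size is known.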
Is there a way to print the actual batch shapes as the data moves through the model in Keras?
Thanks
Edit: here is the output of `keras.utils.plot_model()`:
```
LSTM:     input: (?, 30, 5)   output: (?, 100)
Dense:    input: (?, 100)     output: (?, 1)
Reshape:  input: (?, 1)       output: (?, 30, 1)
```