Is it possible to map the batch size of a dataset between layers?

Asked: 2019-09-05 15:36:19

Tags: tensorflow keras

Consider the following:

import tensorflow as tf
from tensorflow.keras.layers import Dense, LSTM

model = tf.keras.models.Sequential([
    Dense(10, batch_input_shape=(32, None, 100)),
    LSTM(1, stateful=True)
])
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (32, None, 10)            1010      
_________________________________________________________________
lstm  (LSTM)                 (32, 1)                   48        
=================================================================
Total params: 1,058
Trainable params: 1,058
Non-trainable params: 0
_________________________________________________________________

Regardless of whether this model makes sense, the batch size is set on the first (Dense) layer only because the LSTM has stateful=True and therefore requires a fixed batch size. The batch size has to be supplied through the first layer, which is why it is the Dense layer that specifies it.
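For reference, a minimal sketch (my addition, assuming TF 2.x) of the same requirement when the stateful LSTM is itself the first layer, in which case batch_input_shape can be given on the LSTM directly:

import tensorflow as tf
from tensorflow.keras.layers import LSTM

# A stateful LSTM needs a fixed batch size; when it is the first layer,
# batch_input_shape can carry that information instead of an earlier Dense layer.
model = tf.keras.models.Sequential([
    LSTM(1, stateful=True, batch_input_shape=(32, None, 100))
])
model.summary()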

I would like to know whether there is a way to make something like the following work:

import tensorflow as tf
from tensorflow.keras.layers import Dense, LSTM

model = tf.keras.models.Sequential([
    Dense(10, batch_input_shape=(None, 32, 100)),
    #Going from (None, 32, 10) to (32, None, 10)
    LSTM(1, stateful=True)
])

I know this is possible before the data reaches the model, using the Dataset class methods (map, window, batch). But I'm wondering whether it can be done between layers.
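For context, here is a minimal sketch (my addition, with hypothetical data) of that pre-model approach, where tf.data fixes the batch size so every batch the model sees has exactly 32 examples:

import numpy as np
import tensorflow as tf

# Hypothetical data, only to illustrate shaping the input pipeline.
features = np.random.rand(320, 50, 100).astype("float32")  # (samples, timesteps, features)
labels = np.random.rand(320, 1).astype("float32")

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# drop_remainder=True guarantees that every batch has exactly 32 examples,
# which is what a model built with batch_input_shape=(32, ...) expects.
dataset = dataset.batch(32, drop_remainder=True)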

1 Answer:

Answer 0 (score: 0)

Apparently, you can do this with a Lambda layer:

import tensorflow as tf
from tensorflow.keras.layers import Dense, LSTM

model = tf.keras.models.Sequential([
    Dense(10, batch_input_shape=(None, 32, 100)),
    tf.keras.layers.Lambda(lambda x: tf.reshape(x, (32, -1, 10))),
    LSTM(1, stateful=True)
])
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 32, 10)            1010      
_________________________________________________________________
lambda (Lambda)              (32, None, 10)            0         
_________________________________________________________________
lstm   (LSTM)                (32, 1)                   48        
=================================================================
Total params: 1,058
Trainable params: 1,058
Non-trainable params: 0
_________________________________________________________________

Who knew!?
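One caveat worth adding (my note, not part of the original answer): tf.reshape to (32, -1, 10) reinterprets the flat memory layout, so elements from different batch rows and timesteps get regrouped rather than the first two axes simply being swapped. If the intent is only to exchange the batch and time axes while keeping each sequence intact, a Lambda wrapping tf.transpose might be closer, for example:

import tensorflow as tf
from tensorflow.keras.layers import Dense, LSTM, Lambda

# Sketch only: transpose swaps the first two axes without reordering the
# elements within each sequence, unlike reshape, which regroups the flat buffer.
model = tf.keras.models.Sequential([
    Dense(10, batch_input_shape=(None, 32, 100)),
    Lambda(lambda x: tf.transpose(x, perm=[1, 0, 2])),  # (None, 32, 10) -> (32, None, 10)
    LSTM(1, stateful=True)
])

Which of the two matches the intended data layout depends on what the (None, 32, 100) tensor actually represents.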