例如:
BUFFER_SIZE = 10000
BATCH_SIZE = 64
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.padded_batch(BATCH_SIZE, tf.compat.v1.data.get_output_shapes(train_dataset))
test_dataset = test_dataset.padded_batch(BATCH_SIZE, tf.compat.v1.data.get_output_shapes(test_dataset))
def pad_to_size(vec, size):
zeros = [0] * (size - len(vec))
vec.extend(zeros)
return vec
...
model = tf.keras.Sequential([
tf.keras.layers.Embedding(encoder.vocab_size, 64),
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=False)),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
print(model.summary())
打印内容为:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (None, None, 64) 523840
_________________________________________________________________
bidirectional (Bidirectional (None, 128) 66048
_________________________________________________________________
dense (Dense) (None, 64) 8256
_________________________________________________________________
dense_1 (Dense) (None, 1) 65
=================================================================
Total params: 598,209
Trainable params: 598,209
Non-trainable params: 0
我有以下问题:
1)对于嵌入层,输出形状为何为(None, None, 64)
。我知道'64'是向量长度。为什么另外两个没有?
2)双向层的输出形状如何为(None, 128)
?为什么是128
?
答案 0 :(得分:1)
对于嵌入层,输出形状为何为(None,None,64)。我知道'64'是向量长度。为什么另两个没有?
如果您未将(None,None)
定义为input_shape=(None,)
,则可以看到this function产生input_shape
(包括批次尺寸)(换句话说,它会(None, None)
作为默认值)顺序模型的第一层。
如果将大小为(None, None, 64)
的输入张量传递给嵌入层,则假定嵌入尺寸为64,它会生成一个None
张量。第一个input_length
是批处理尺寸,而第二个是时间维度(指的是(None, None, 64)
参数)。这就是为什么您得到Bidirectional
大小的输出的原因。
双向层的输出形状是(None,128)?为什么是128?
在这里,您有一个LSTM
LSTM。您的(None, 64)
层会生成return_sequences=False
大小的输出(当Bidirectional
时)。当您有一个LSTM
图层时,就像有两个merge_mode
图层(一个向前,另一个向后)。并且您的默认concat
为(None, 128)
,这意味着向前和向后的两个输出状态将串联在一起。这将为您提供RepositoryNotFoundError: No repository for "User" was found.
Looks like this entity is not registered in current "default" connection?
message: 'No repository for "User" was found. Looks like tes.js)
{his entity is not registered in current "default" connection?'
his entity is not registered in current "default" connection?'}
大小的输出。