如何使用批处理训练模型来预测单个输入?

时间:2019-02-10 17:56:00

标签: tensorflow keras

我有在数据集上训练过的RNN模型:

train = tf.data.Dataset.from_tensor_slices((data_x[:train_size],
                          data_y[:train_size])).batch(batch_size).repeat()

型号:

    model = tf.keras.Sequential()
    model.add(tf.keras.layers.GRU(units=lstm_num_units,
                                   return_sequences=True,
                                   kernel_initializer='random_uniform',
                                   recurrent_initializer='random_uniform',
                                   bias_initializer='random_uniform',
                                   batch_size=batch_size,
                                   input_shape = [seq_len, num_features]))
    model.add(tf.keras.layers.LSTM(units=lstm_num_units,
                                   batch_size=batch_size,
                                   return_sequences=True,
                                   input_shape = [seq_len, num_features]))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(units=dence_units))
    model.add(tf.keras.layers.Dropout(drop_flat))
    model.add(tf.keras.layers.Dense(units=out_units))
    model.add(tf.keras.layers.Softmax())   

    model.compile(loss="sparse_categorical_crossentropy",
            optimizer=tf.train.RMSPropOptimizer(opt),
            metrics=['accuracy'])

 model.fit(train, epochs=EPOCHS,
                        steps_per_epoch=repeat_size_train,
                        validation_data=validate,
                        validation_steps=repeat_size_validate,
                        verbose=1,
                        shuffle=True)
                        callbacks=[tensorboard, cp_callback])

我需要对seq_len的单个输入进行预测,但是看起来我的输入必须具有批处理大小:

ar = np.random.randint(98, size=[batch_size, seq_len])
ar = np.reshape(ar, [batch_size, seq_len, 1])
prediction = model.m.predict(ar)

有没有办法使它在形状为[1,seq_len,1]的单个输入上工作?

1 个答案:

答案 0 :(得分:2)

是的,只需在第一层中重建模型而没有批处理大小即可。

复制旧模型的权重。

newModel.set_weights(oldModel.get_weights())

仅在stateful=True模型中存在批次大小的目的,以保持批次之间的一致性。

尽管如此,由于批次大小,数学上也没有变化。