如何在像Keras这样的tf.estimator.Estimator中监控培训过程?

时间:2017-12-31 00:05:34

标签: tensorflow keras

我使用Estimator将Keras模型移植到tf.keras.estimator.model_to_estimator()。但是,当我打电话给estimator.train()

时,我无法弄清楚如何在Keras中监控培训过程(进度条)

代码

# define a Keras model
model   = Model(inputs=inputs, outputs=outputs)
model.compile(...)

# Convert into Estimator object
estimator = tf.keras.estimator.model_to_estimator(keras_model=model,
                                                  model_dir='./logs')

# **PROBLEM HERE** I could not monitor the training process here
estimator.train(input_fn=lambda: create_tfdata(data.train.images,
                                               data.train.labels,
                                               num_epochs=epochs,
                                               shuffle=True,
                                               batch_size=batch_size),
                max_steps=steps_per_epoch)

如果我致电model.fit(),它会显示如下进度条:

Epoch 1/5
469/469 [==============================] - 116s 248ms/step - loss: 0.1870 - acc: 0.9433
Epoch 2/5
 67/469 [===>..........................] - ETA: 1:20 - loss: 0.0613 - acc: 0.9813

0 个答案:

没有答案