为什么我的TensorBoard训练图不完整(使用Jupyter / TensorFlow / Keras)?

时间:2019-11-11 13:15:51

标签: tensorflow keras tensorboard

当我使用TensorBoard可视化训练集和验证集上的错误时,TensorBoard中的训练集错误系列通常会提前停止(不显示所有步骤)。

使用以下代码时,序列会随机停止:

import os
import numpy as np
import matplotlib.pyplot as plt
import time
import tensorflow as tf
import tensorflow.keras as keras

# Generate some training data (Y = X0 - X2^2).
X = np.random.rand(1000,2)
Y = X[:,0] + np.square(X[:,1])
X_val = np.random.rand(100,2)
Y_val = X_val[:,0] + np.square(X_val[:,1])

# Tensorboard logs.
log_dir = os.path.join('./tensorboard-logs/', 'stack-overflow', time.strftime('%Y_%m_%d-%H_%M_%S'))
os.makedirs(log_dir, exist_ok=True)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)

# Tensorflow model (create, compile, train)
model = tf.keras.Sequential([keras.layers.Flatten(input_shape=X[0].shape),
                             keras.layers.Dense(1, use_bias=True, activation='linear')])
model.compile(optimizer='adam', loss='mse', metrics=['mse'])
hist = model.fit(X, Y, epochs=500, batch_size=1000, validation_data=(X_val, Y_val), verbose=0, callbacks=[tb_callback])

当我绘制hist对象时,我得到了我期望的图形-两个系列,值从0到499:

# Plot history locally.
plt.plot(range(len(hist.history['mse'])), hist.history['mse'], 'r')
plt.plot(range(len(hist.history['val_mse'])), hist.history['val_mse'], 'b')
plt.figure()

上面的代码产生完整的图。但是,如果我看一下TensorBoard向我显示的内容,则训练图通常在值499之前就停止了。为什么TensorBoard训练图通常不完整?

1 个答案:

答案 0 :(得分:0)

像往常一样,在一天中的大部分时间里与这个问题作斗争之后,我认为我在发布后几分钟就找到了答案:只需将--reload_multifile添加到tensorboard命令中即可。