Tensorflow 2显示权重的直方图

时间:2019-07-10 13:20:03

标签: python tensorflow classification conv-neural-network tensorboard

我正在尝试使用tensorflow 2的lambdacallbacks在张量图中的每个时期显示所有网络权重(CNN)的直方图:

def log_hist_weights(model,writer):
    model = model
    writer = writer

    def log_hist_weights(epoch, logs):
        # predict images
        Ws = model.get_weights()
        with writer.as_default():
            tf.summary.histogram("epoch: " + str(epoch), Ws)
    return log_hist_weights

hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))

但是问题是函数“ get_weights ”返回所有netowrk权重而没有任何名称(例如,过滤器权重,Batchnormalization权重和其他内容),但实际上我只是被CNN-过滤器权重。

如果我可以在TF2中实现像this one这样的功能,那就太好了。

如何使用TF显示过滤器权重的直方图?

谢谢

1 个答案:

答案 0 :(得分:0)

对于其他有相同问题的人,这是我最终使用Tensorflow 2解决问题的方式:

def log_hist_weights(model,writer):
    model = model
    writer = writer

    def log_hist(epoch, logs):
        # predict images
        with writer.as_default():
            for tf_var in baseline_model.trainable_weights:
                    tf.summary.histogram(tf_var.name, tf_var.numpy(), step=epoch)
    return log_hist

    hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))

enter image description here