我正在尝试使用tensorflow 2的lambdacallbacks在张量图中的每个时期显示所有网络权重(CNN)的直方图:
def log_hist_weights(model,writer):
model = model
writer = writer
def log_hist_weights(epoch, logs):
# predict images
Ws = model.get_weights()
with writer.as_default():
tf.summary.histogram("epoch: " + str(epoch), Ws)
return log_hist_weights
hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))
但是问题是函数“ get_weights ”返回所有netowrk权重而没有任何名称(例如,过滤器权重,Batchnormalization权重和其他内容),但实际上我只是被CNN-过滤器权重。
如果我可以在TF2中实现像this one这样的功能,那就太好了。
如何使用TF显示过滤器权重的直方图?
谢谢
答案 0 :(得分:0)
对于其他有相同问题的人,这是我最终使用Tensorflow 2解决问题的方式:
def log_hist_weights(model,writer):
model = model
writer = writer
def log_hist(epoch, logs):
# predict images
with writer.as_default():
for tf_var in baseline_model.trainable_weights:
tf.summary.histogram(tf_var.name, tf_var.numpy(), step=epoch)
return log_hist
hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))