如何为Keras实现Beholder(Tensorboard插件)?

时间:2019-04-24 23:44:34

标签: python tensorflow tensorboard

我正在尝试将Tensorboard中的Beholder插件实现为简单的CNN代码(我是Tensorflow的初学者),但是我不确定将visualizer.update(session=session)放在哪里。 一开始我有:

from tensorboard.plugins.beholder import Beholder
LOG_DIRECTORY='/tmp/tensorflow_logs'
visualizer = Beholder(logdir=LOG_DIRECTORY)

我像这样训练我的模型:

model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(253,27,3))) 
.
.
.
model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])

我应该将visualizer.update(session=session)放在什么地方,我应该在代码中另外放置什么,因为现在它说没有找到Beholder数据。谢谢!

1 个答案:

答案 0 :(得分:0)

创建一个custom Keras callback是合适的,这样您就可以在每个时期的末尾(或您希望的任何时候)调用$ cat tst.awk { addr = gensub(/.* ([^:]+):.*$/,"\\1",1) } /peer holds all/ { peers[addr] } /no free leases/ { frees[addr] } END { PROCINFO["sorted_in"] = "@ind_str_asc" print "Peer Holds Leases - Via:" for (addr in peers) { print addr } print "No Free Leases:" for (addr in frees) { print addr } } $ awk -f tst.awk file Peer Holds Leases - Via: 1.2.3.188 1.2.3.189 No Free Leases: 1.2.64.0/24 1.2.65.0/24 。这是一个示例,显示这种回调的样子:

visualizer.update(session=session)

然后,在定义模型后,实例化回调并将其传递给model.fit

from tensorboard.plugins.beholder import Beholder
import tensorflow as tf
import keras.backend as K
import keras

LOG_DIRECTORY='/tmp/tensorflow_logs'


class BeholderCallback(keras.callbacks.Callback):
    def __init__(self, frame, logdir=LOG_DIRECTORY, sess=None):
        self.visualizer = Beholder(logdir=logdir)
        self.sess = sess
        if sess is None:
            self.sess = K.get_session()
        self.frame = frame

    def on_epoch_end(self, epoch, logs=None):
        self.visualizer.update(
            session=self.sess,
            frame=self.frame
        )

您还可以类似的方式使用# Define your Keras model # ... # Prepare callback sess = K.get_session() beholder_callback = BeholderCallback(your_frame, sess=sess) # Fit data into model and pass callback to model.fit model.fit(x=x_train, y=y_train, callbacks=[beholder_callback]) 的参数arrays