Tensorflow - 如何为保护程序添加嵌入

时间:2018-01-23 14:13:55

标签: python tensorflow tensorboard

我使用TF训练模型并有两个文件myModel.pytrain_model.py。为此,我使用以下代码(train_model.py):

...
with tf.Session() as sess:
    my_model = MyModel(placeholders, learning_rate, ...) # builds the graph
    merged = tf.summary.merge_all()
    config = projector.ProjectorConfig()
    writer = tf.summary.FileWriter(save_path, sess.graph)
    for epoch in range(num_epochs):
        sess.run([model.opt_op, model.loss, model.accuracy, merged],
                 feed_dict=feed_dict)
        ...
    model.save(path, sess=sess)

模型本身实现了这样的保存和加载函数(myModel.py):

class MyModel:
    self.vars = None
    self.activations = None
    self.build() # builds the graph and fills self.vars & self.activations
    ...

    def save(self, path, sess):
        saver = tf.train.Saver(self.vars)
        saver.save(sess, path)

    def load(self, path, sess):
        saver = tf.train.Saver(self.vars)
        saver.restore(sess, path)

这种方法运行良好,因为我可以保存并加载我的模型并在那里进行一些计算。 但是,我现在想要在Tensorboard中使用PCA / TSNE可视化每层的激活。

这是我的问题:

我必须在train_model.py中添加嵌入,因为我想在训练后添加它。另一方面,我必须保存通过以下方式获得的嵌入变量:

act = output.eval(feed_dict=d, session=sess)
embedding_var = tf.Variable(act, name='activation_layer_{}'.format(i))

使用我的model.save()功能。 如何将变量添加到会话或将另一个检查点保存到同一目录?

保存变量的最佳方法是什么?你会使用另一个只包含嵌入的检查点吗?

感谢您的帮助,

罗马

0 个答案:

没有答案