什么是pycaffe中的求解器回调函数,我该如何使用它们?

时间:2017-01-05 16:56:38

标签: python neural-network deep-learning caffe pycaffe

查看this PR,我发现可以为on_start对象定义on_gradientcaffe.Solver个回调。

import caffe
solver = caffe.AdamSolver('solver.prototxt')
solver.add_callback(on_start, on_gradient)  # <- ??

on_starton_gradient的对象类型是什么? 这些回调是什么?
如何使用它们(一个例子会很好......)?

1 个答案:

答案 0 :(得分:2)

<强> 1。回调在何处以及如何定义?

回调是Solver的一部分,因此在solver.hpp文件中定义。确切地说,有一个Callback类,如下所示:

  // Invoked at specific points during an iteration
  class Callback {
   protected:
    virtual void on_start() = 0;
    virtual void on_gradients_ready() = 0;

    template <typename T>
    friend class Solver;
  };
  const vector<Callback*>& callbacks() const { return callbacks_; }
  void add_callback(Callback* value) {
    callbacks_.push_back(value);
  }

protected vector这样的回调,它是Solver类的成员。

  vector<Callback*> callbacks_;

因此,这基本上为add_callback类提供了Solver函数,允许人们将类型为Callback的对象添加到向量中。这是为了确保每个回调都有两种方法:on_start()on_gradients_ready()

<强> 2。回调在哪里召唤?

这发生在solver.cpp文件中,step()函数中包含主工作循环。这是主循环部分(为简单起见,有很多东西被删除):

while (iter_ < stop_iter) {

    for (int i = 0; i < callbacks_.size(); ++i) {
        callbacks_[i]->on_start();
    }

    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
        loss += net_->ForwardBackward();
    }
    loss /= param_.iter_size();

    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }

    ApplyUpdate();

    ++iter_;
}

第3。这在哪里使用?

添加多GPU支持时实现了此回调功能。我使用回调的唯一地方(我知道)是在多个GPU之间同步求解器:

P2PSync中的parallel.hpp类继承自Solver::Callback类,并实现了on_start()on_gradients_ready()方法,这些方法可以同步GPU并最终累积所有渐变更新。

<强> 4。如何使用Python的回调?

正如拉取请求#3020所解释的那样,

  

on_starton_gradient是python函数。

所以它应该是直截了当的。我创建的this Github Gist中显示了一个完整的,可运行的示例。

<强> 5。这有用吗?

由于两个回调函数不接受任何参数,因此您无法简单地使用它们来跟踪丢失或类似的事情。为此,您必须围绕Solver类创建包装函数,并使用两种方法作为回调函数调用add_callback。这允许您使用self.solver.net从回调中访问网络。在以下示例中,我使用on_start回调将数据加载到网络中,并使用on_gradients_ready回调来打印损失函数。

class SolverWithCallback:
    def __init__(self, solver_file):
        self.solver = caffe.SGDSolver(solver_file)
        self.solver.add_callback(self.load, self.loss)

    def solve(self):
        self.solver.solve()

    def load(self):
        inp = np.random.randint(0, 255)
        self.solver.net.blobs['data'].data[...] = inp
        self.solver.net.blobs['labels'].data[...] = 2 * inp

    def loss(self):
        print "loss: " + str(self.solver.net.blobs['loss'].data)

if __name__=='__main__':
    solver = SolverWithCallback('solver.prototxt')
    solver.solve()