查看this PR,我发现可以为on_start
对象定义on_gradient
和caffe.Solver
个回调。
import caffe
solver = caffe.AdamSolver('solver.prototxt')
solver.add_callback(on_start, on_gradient) # <- ??
on_start
和on_gradient
的对象类型是什么?
这些回调是什么?
如何使用它们(一个例子会很好......)?
答案 0 :(得分:2)
<强> 1。回调在何处以及如何定义?
回调是Solver的一部分,因此在solver.hpp
文件中定义。确切地说,有一个Callback
类,如下所示:
// Invoked at specific points during an iteration
class Callback {
protected:
virtual void on_start() = 0;
virtual void on_gradients_ready() = 0;
template <typename T>
friend class Solver;
};
const vector<Callback*>& callbacks() const { return callbacks_; }
void add_callback(Callback* value) {
callbacks_.push_back(value);
}
和protected vector这样的回调,它是Solver
类的成员。
vector<Callback*> callbacks_;
因此,这基本上为add_callback
类提供了Solver
函数,允许人们将类型为Callback
的对象添加到向量中。这是为了确保每个回调都有两种方法:on_start()
和on_gradients_ready()
。
<强> 2。回调在哪里召唤?
这发生在solver.cpp
文件中,step()
函数中包含主工作循环。这是主循环部分(为简单起见,有很多东西被删除):
while (iter_ < stop_iter) {
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
ApplyUpdate();
++iter_;
}
第3。这在哪里使用?
添加多GPU支持时实现了此回调功能。我使用回调的唯一地方(我知道)是在多个GPU之间同步求解器:
P2PSync
中的parallel.hpp
类继承自Solver::Callback
类,并实现了on_start()
和on_gradients_ready()
方法,这些方法可以同步GPU并最终累积所有渐变更新。
<强> 4。如何使用Python的回调?
正如拉取请求#3020所解释的那样,
on_start
和on_gradient
是python函数。
所以它应该是直截了当的。我创建的this Github Gist中显示了一个完整的,可运行的示例。
<强> 5。这有用吗?
由于两个回调函数不接受任何参数,因此您无法简单地使用它们来跟踪丢失或类似的事情。为此,您必须围绕Solver
类创建包装函数,并使用两种方法作为回调函数调用add_callback
。这允许您使用self.solver.net
从回调中访问网络。在以下示例中,我使用on_start
回调将数据加载到网络中,并使用on_gradients_ready
回调来打印损失函数。
class SolverWithCallback:
def __init__(self, solver_file):
self.solver = caffe.SGDSolver(solver_file)
self.solver.add_callback(self.load, self.loss)
def solve(self):
self.solver.solve()
def load(self):
inp = np.random.randint(0, 255)
self.solver.net.blobs['data'].data[...] = inp
self.solver.net.blobs['labels'].data[...] = 2 * inp
def loss(self):
print "loss: " + str(self.solver.net.blobs['loss'].data)
if __name__=='__main__':
solver = SolverWithCallback('solver.prototxt')
solver.solve()