Pycaffe Net forward_all()函数不起作用

时间:2016-10-21 16:47:02

标签: python caffe pycaffe

在我训练了一些CNN的权重之后,我决定使用相同的网络架构来进行预测。 我设置了数据batch_size = 64

我可以正常运行pred_net.forward()函数,我可以从blobs['prob']获取预测的类。

我的数据集中有20000个样本。如果我将forward()函数调用i次,我会将64*i个样本转发到网上。所以我不能覆盖20000个样本而不转发一些样本两次。

因此,我尝试了forward_all()功能。但是我得到了一个没有任何有用信息的例外。我不知道出了什么问题。

我希望forward()forward_all()相似(但不是)。

以下是我的代码部分和错误消息:

pred_net = caffe.Net(pred_net_proto_file, 'kg_trained.caffemodel', caffe.TEST)

pred_net.forward_all()
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-6-cefd35621a35> in <module>()
----> 1 pred_net.forward_all()

/home/microos/Space/caffe-master/python/caffe/pycaffe.pyc in _Net_forward_all(self, blobs, **kwargs)
    197         all_outs[out] = np.asarray(all_outs[out])
    198     # Discard padding.
--> 199     pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs)))
    200     if pad:
    201         for out in all_outs:

StopIteration: 

希望我能清楚地描述一些事情。

1 个答案:

答案 0 :(得分:2)

您必须将要转发的数据传递给forward_all()函数:

pred_net = caffe.Net(pred_net_proto_file, 'kg_trained.caffemodel', caffe.TEST)
pred_net.forward_all(data=data_samples)

假设您的CNN需要形状图像(3,224,224),那么您的data_samples应具有形状(20000,3,224,224)