这段代码的输出是什么意思?

时间:2015-11-25 17:18:02

标签: python deep-learning caffe conv-neural-network pycaffe

我使用this tutorial作为示例来构建我的caffe自定义训练功能。在第15节中有以下代码:

def train():
    niter = 200
    test_interval = 25 
    train_loss = zeros(niter)
    test_acc = zeros(int(np.ceil(niter / test_interval)))

    ### HERE ###
    output = zeros((niter, 8, 10))
    ###      ###

在第8行有一个ndarray(输出),这个代码的含义是什么,它是demensions。 (niter, 8, 10)的含义是什么?为什么niter,为什么是8,为什么是10?我应该根据自己的数据集更改此数组吗?如果是,我应该使用什么尺寸?有人可以解释一下吗?

2 个答案:

答案 0 :(得分:2)

如果你仔细阅读教程,你会发现它处理数字分类,因此 10 类。此外,他们使用技巧将8个示例拼凑在一起(第11节,靠近In [11]:):

  

#我们使用一个小技巧来平铺前八个图像

因此 8 维度。

第15节显示了跟踪网络进度的示例。它保存了每次迭代的输出预测概率。每次迭代有 10 类次 8 示例,并且要跟踪 niter 次迭代。所有这些信息都存储在3D output数组中。

答案 1 :(得分:1)

看起来像是对numpy.zeros的调用,其中shape = (niter, 8, 10)创建了一个200 * 8 * 10的float 0数组。