如果从caffe中的MemoryData读取数据,如何读取标签数据

时间:2016-09-24 06:16:56

标签: caffe

这是net.prototxt的数据层:

layer {
    name: "csv"
    type: "MemoryData"
    top: "data"
    top: "label"
    include {
        phase: TRAIN
    }
    memory_data_param {
        batch_size: 10
        channels: 1
        width: 14
        height: 1
    }
}

我找到了函数

MemoryDataLayer<Dtype>::Reset(Dtype* data, Dtype* labels, int n)

但我不知道我应该在哪里添加此功能?

现在我想知道在哪里  是标签数据来自?因为我只在Datum struct中看到标签关键字。

1 个答案:

答案 0 :(得分:0)

当我通过pycaffe模块训练网络时,我总是使用MemoryData层。就像这个

solver = caffe.SGDSolver(solver_file)

X = np.zeros((batch_size, 3, im_height, im_width), dtype = np.float32)
Y = np.zeros((batch_size, ), dtype = np.float32)
# put processed images into X, put labels into Y

solver.net.set_input_arrays(X,Y)

你可以参考caffe_root / python / caffe / pycaffe.py和_caffe.cpp了解详情