这是net.prototxt的数据层:
layer {
name: "csv"
type: "MemoryData"
top: "data"
top: "label"
include {
phase: TRAIN
}
memory_data_param {
batch_size: 10
channels: 1
width: 14
height: 1
}
}
我找到了函数
MemoryDataLayer<Dtype>::Reset(Dtype* data, Dtype* labels, int n)
但我不知道我应该在哪里添加此功能?
现在我想知道在哪里 是标签数据来自?因为我只在Datum struct中看到标签关键字。
答案 0 :(得分:0)
当我通过pycaffe模块训练网络时,我总是使用MemoryData层。就像这个
solver = caffe.SGDSolver(solver_file)
X = np.zeros((batch_size, 3, im_height, im_width), dtype = np.float32)
Y = np.zeros((batch_size, ), dtype = np.float32)
# put processed images into X, put labels into Y
solver.net.set_input_arrays(X,Y)
你可以参考caffe_root / python / caffe / pycaffe.py和_caffe.cpp了解详情