我正在尝试理解(并构建)BVLC/caffe 的 PR #4681。
为此,我自己编写了虚拟数据的生成代码。但是在尝试运行训练时,文件完全无法加载,并报出以下错误:
F0302 11:40:20.412214 10108 hdf5.cpp:50] 不支持的数据类型类:H5T_ENUM
以下是我的 prototxt 文件,以及生成数据本身的代码:
layer {
name: "data"
type: "HDF5Data"
top: "data"
top: "seq_ind"
top: "labels"
hdf5_data_param {
source: "dummy_data.txt"
batch_size: 40
}
}
layer {
name: "lstm1"
type: "LSTM"
bottom: "data"
bottom: "seq_ind"
top: "lstm1"
recurrent_param {
# num + 1 for blank label! (last one)
num_output: 40
weight_filler {
type: "gaussian"
std: 0.1
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "ip1"
type: "InnerProduct"
bottom: "lstm1"
top: "ip1"
inner_product_param {
num_output: 4
weight_filler {
type: "gaussian"
std: 0.1
}
axis: 2
}
}
layer {
name: "rev1"
type: "Reverse"
bottom: "data"
top: "rev_data"
}
layer {
name: "lstm2"
type: "LSTM"
bottom: "rev_data"
bottom: "seq_ind"
top: "lstm2"
recurrent_param {
# num + 1 for blank label! (last one)
num_output: 40
weight_filler {
type: "gaussian"
std: 0.1
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "ip2"
type: "InnerProduct"
bottom: "lstm2"
top: "ip2"
inner_product_param {
num_output: 4
weight_filler {
type: "gaussian"
std: 0.1
}
axis: 2
}
}
layer {
name: "rev2"
type: "Reverse"
bottom: "ip2"
top: "rev2"
}
layer {
name: "eltwise-sum"
type: "Eltwise"
bottom: "ip1"
bottom: "rev2"
eltwise_param { operation: SUM }
top: "sum"
}
layer {
name: "loss"
type: "CTCLoss"
bottom: "sum"
bottom: "seq_ind"
bottom: "labels"
top: "ctc_loss"
}
import numpy as np
import h5py
import functools
def reduce_f(x, y):
    """Reducer step: fold one binary sample *y* into the label list *x*.

    The last entry of *x* counts the length of the current run of ones.
    While *y* is truthy the counter grows; when the run ends, the counter
    is collapsed into a class label (1 if the run was shorter than 3,
    2 if shorter than 6, 3 otherwise) and a fresh zero counter is
    appended. A zero counter followed by another zero sample is a no-op.
    """
    if y:
        # Still inside a run of ones: just grow the counter.
        x[-1] += 1
        return x
    run_length = x[-1]
    if run_length != 0:
        # Run just ended: collapse the counter into its class label
        # and open a new counter for the next run.
        x[-1] = 1 if run_length < 3 else (2 if run_length < 6 else 3)
        x.append(0)
    return x
def label_gen(data):
    """Collapse a binary sequence into its class-label sequence.

    Folds reduce_f over *data* starting from a single zero counter.
    The trailing (still-open) counter is dropped, so a run of ones that
    reaches the end of *data* produces no label.
    """
    acc = [0]
    for sample in data:
        acc = reduce_f(acc, sample)
    return acc[:-1]
def store_hdf5(filename, mapping):
    """Write every entry of *mapping* as a top-level dataset in an HDF5 file.

    Args:
        filename (str): Path of the HDF5 file to create (an existing
            file is overwritten).
        mapping (dict): Maps dataset names to numpy arrays; each item
            is stored as one dataset under its name.
    """
    print("Storing hdf5 file %s" % filename)
    with h5py.File(filename, 'w') as h5file:
        for name, array in mapping.items():
            print(" adding dataset %s with shape %s" % (name, array.shape))
            h5file.create_dataset(name, data=array)
    print(" finished")
def generate_data(T_, C_, lab_len_):
    """Generate one dummy training sequence for the CTC network.

    The data is a random binary pattern; the target sequence labels each
    run of ones as 1 (run shorter than 3), 2 (shorter than 6) or
    3 (otherwise). The sequence length is exactly T_.

    Args:
        T_ (int): Number of timesteps (must match the batch_size of the
            caffe net so CTC sees the full sequence).
        C_ (int): Number of channels/labels per timestep.
        lab_len_ (int): Maximum label length; must be <= T_. It is also
            used as the maximum allowed label, so the network's label
            size must be lab_len_ + 1 (+1 for the CTC blank label).

    Returns:
        data (numpy array): float32 array of shape (T_, C_) with the
            dummy input pattern.
        sequence_indicators (numpy array): float32 array of shape (T_,),
            0 at t=0 (sequence start) and 1 elsewhere.
        labels (numpy array): float32 array of shape (T_,) with the label
            sequence, padded with -1 past the end of the labels.
    """
    assert(lab_len_ <= T_)
    # Arbitrary non-constant pattern: True with probability ~0.7.
    raw = np.random.rand(T_, C_) > 0.3
    # BUG FIX: h5py stores a NumPy bool array as an HDF5 ENUM dataset,
    # which Caffe's hdf5.cpp rejects with
    #   "Unsupported datatype class: H5T_ENUM".
    # Cast to float32, the type the HDF5Data layer expects.
    data = raw.astype(np.float32)
    # The sequence length is exactly T_: 0 marks the sequence start,
    # 1 marks continuation.
    sequence_indicators = np.full(T_, 1, dtype=np.float32)
    sequence_indicators[0] = 0
    # Labels derived from the raw binary pattern, padded with -1 to
    # length T_ (indicating end of the label sequence). Also float32 so
    # the "labels" dataset is a plain HDF5 float dataset.
    labels = np.asarray(label_gen(raw), dtype=np.float32)
    labels = np.append(labels, [-1.0] * (T_ - labels.size)).astype(np.float32)
    return data, sequence_indicators, labels
if __name__ == "__main__":
    # Build a batch of dummy sequences and dump them all into one HDF5
    # file. Note that T_ = 40 must match batch_size: 40 in the network
    # setup, as required by the CTC algorithm to see the full sequence.
    # NOTE(review): the original comment says "label length and max
    # label is set to 5" but the call below passes 3 — confirm intent.
    num_samples = 100
    all_data = []
    all_seq_inds = []
    all_labels = []
    for _ in range(num_samples):
        data, sequence_indicators, labels = generate_data(40, 1, 3)
        all_data.append(data)
        all_seq_inds.append(sequence_indicators)
        all_labels.append(labels)
    # Stack to (N, T, ...) then swap axes 0/1 to get Caffe's
    # time-major layout (T, N, ...).
    stacked_data = np.swapaxes(np.asarray(all_data), 0, 1)
    stacked_seq_inds = np.swapaxes(np.asarray(all_seq_inds), 0, 1)
    stacked_labels = np.swapaxes(np.asarray(all_labels), 0, 1)
    # Write everything to the h5 file referenced by dummy_data.txt.
    store_hdf5("dummy_data.h5", {"data": stacked_data,
                                 "seq_ind": stacked_seq_inds,
                                 "labels": stacked_labels})