# Read the data from the database
def read_input(model, batch_size, db, db_type):
# load the data
data, binary, dist = model.TensorProtosDBInput(
[], ['data', 'binary'], batch_size=batch_size,
db=db, db_type=db_type)
# Get the absolute distance.
data = model.StopGradient(data, data)
binary = model.StopGradient(binary, binary)
return data, binary
def define_network(model, inp, b_classifier):
fc1 = brew.fc(model, inp, 'fc1', dim_in=2, dim_out=10)
fc1 = brew.relu(model, fc1, fc1)
fc2 = brew.fc(model, fc1, 'fc2', dim_in=10, dim_out=20)
fc2 = brew.relu(model, fc2, fc2)
fc3 = brew.fc(model, fc2, 'fc3', dim_in=20, dim_out=10)
fc3 = brew.relu(model, fc3, fc3)
nDimOut = 1+int(b_classifier)
predict = brew.fc(model, fc3, 'predict', dim_in=10, dim_out=nDimOut)
if b_classifier:
softmax = brew.softmax(model, predict, 'softmax')
return softmax
return predict
def add_classifier_loos_and_acc(model, softmax, gt):
crossEnt = model.LabelCrossEntropy([softmax, gt], 'crossEnt')
loss = model.AveragedLoss(crossEnt, "loss")
accuracy = brew.accuracy(model, [softmax, gt], "accuracy")
return loss, accuracy
def add_training_ops(model, loss):
optimizer.build_rms_prop(model, base_learning_rate=0.001)
batchSize = 128
# Create a training network
train_model = model_helper.ModelHelper(name="binary_train")
data, binary = read_input(train_model, batch_size=batchSize, db='Data/' + folder + '/train.minidb', db_type='minidb')
softmax = define_network(train_model, data, True)
loss, _ = add_classifier_loos_and_acc(train_model, softmax, binary)
add_training_ops(train_model, loss)
# Train and check.
# creating the network
workspace.CreateNet(train_model.net, overwrite=True)
# Train once
testRes = workspace.FetchBlob('softmax')
gt = workspace.FetchBlob('binary')
crossEnt = workspace.FetchBlob('crossEnt')
avgCrossEnt = np.mean(crossEnt)
loss = workspace.FetchBlob('loss')
当我运行代码并进入 workspace.RunNet(train_model.net)的行时,我的代码崩溃了:
RuntimeError: [enforce fail at accuracy_op.cc:29] label.ndim() == 1. 2 vs 1 Error from operator:
input: "softmax" input: "binary" output: "accuracy" name: "" type: "Accuracy"
We've got an error while stopping in post-mortem: <type 'exceptions.KeyboardInterrupt'>
我试着理解它大约2天而我什么也没得到。 有谁知道我做错了什么?
答案 0 :(得分:0)
我发现了问题。 问题是二进制文件的形状为(batchSize,1),它的大小应为(batchSize,)。 添加FlattenToVec()解决了问题