获取错误:
ValueError:Value :: Create ::序列开始标志的数字(1)与序列的数量(356)不匹配, 在
def create_inputs(output_length):
batch_axis = ct.Axis.default_batch_axis()
input_seq_axis = ct.Axis('inputAxis')
input_dynamic_axes = [batch_axis, input_seq_axis]
input_sequence = ct.input_variable(shape=1, dynamic_axes=input_dynamic_axes)
label_sequence = ct.input_variable(shape=output_length, dynamic_axes=input_dynamic_axes)
return input_sequence, label_sequence
def make_model_and_train(model_root_path, epochs, minibatch_dim, output_length, test_minibatches):
vals = get_data('data.csv')
train, test, minibatches_per_epoch = make_sets(vals, minibatch_dim, output_length, test_minibatches)
input_sequence, label_sequence = create_inputs(output_length)
model = create_model(output_length)
z = model(input_sequence)
ce = ct.squared_error(z, label_sequence)
lt_per_sample = ct.learning_rate_schedule([(7000, 0.001),(10000, 0.0005)], ct.UnitType.sample, minibatches_per_epoch)
clipping_threshold_per_sample = 2
gradient_clipping_with_truncation = True
learner = ct.momentum_sgd(z.parameters, lt_per_sample, ct.momentum_as_time_constant_schedule(1100),gradient_clipping_threshold_per_sample = clipping_threshold_per_sample, gradient_clipping_with_truncation = gradient_clipping_with_truncation)
progress_printer = ct.logging.ProgressPrinter(100, tag = 'Training')
trainer = ct.Trainer(z, (ce), learner, progress_printer)
print ("Running %d epochs with %d minibatches per epoch" % (epochs, minibatches_per_epoch))
print('')
for e in range(epochs):
mask = [True]
for b in range(minibatches_per_epoch):
arguments = ({input_sequence: train[0][b], label_sequence: train[1][b]}, mask)
mask = [False]
trainer.train_minibatch(arguments)
global_minibatch = e*minibatches_per_epoch + b
if e % 100 == 0 and e != 0:
model_filename = '%s/%s/%s_epoch_%g.dnn' % (model_root_path, name, name, e+1)
z.save_model(model_filename)
print("Saved model to '%s'" % model_filename)
目前
trainer.train_minibatch(arguments)
minibatch_dim是356, output_length是356
我大部分都是从我的其他LSTM中复制了代码,但是我一直收到这个错误。
我该如何解决这个问题?
答案 0 :(得分:0)
你需要为minibatch中的每个序列提供一个掩码值列表,如果第i个序列是新序列,或者如果它是i-的延续,则列表中的第i个元素为True。来自前一个小批量的序列