CNTK: Number of sequence begin flags (1) does not match the number of sequences

Asked: 2017-07-03 04:07:53

Tags: python deep-learning sequence lstm cntk

I am getting this error:

ValueError: Value::Create:: The number of sequence begin flags (1) does not match the number of sequences (356), in

def create_inputs(output_length):
    batch_axis = ct.Axis.default_batch_axis()
    input_seq_axis = ct.Axis('inputAxis')

    input_dynamic_axes = [batch_axis, input_seq_axis]
    input_sequence = ct.input_variable(shape=1, dynamic_axes=input_dynamic_axes)
    label_sequence = ct.input_variable(shape=output_length, dynamic_axes=input_dynamic_axes)

    return input_sequence, label_sequence

def make_model_and_train(model_root_path, epochs, minibatch_dim, output_length, test_minibatches):
    vals = get_data('data.csv')
    train, test, minibatches_per_epoch = make_sets(vals, minibatch_dim, output_length, test_minibatches)

    input_sequence, label_sequence = create_inputs(output_length)

    model = create_model(output_length)

    z = model(input_sequence)

    ce = ct.squared_error(z, label_sequence)

    lt_per_sample = ct.learning_rate_schedule([(7000, 0.001),(10000, 0.0005)], ct.UnitType.sample, minibatches_per_epoch)
    clipping_threshold_per_sample = 2
    gradient_clipping_with_truncation = True

    learner = ct.momentum_sgd(z.parameters, lt_per_sample,
                              ct.momentum_as_time_constant_schedule(1100),
                              gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
                              gradient_clipping_with_truncation=gradient_clipping_with_truncation)
    progress_printer = ct.logging.ProgressPrinter(100, tag = 'Training')
    trainer = ct.Trainer(z, (ce), learner, progress_printer)

    print ("Running %d epochs with %d minibatches per epoch" % (epochs, minibatches_per_epoch))
    print('')

    for e in range(epochs):
        mask = [True]
        for b in range(minibatches_per_epoch):
            arguments = ({input_sequence: train[0][b], label_sequence: train[1][b]}, mask)
            mask = [False]
            trainer.train_minibatch(arguments)

            global_minibatch = e*minibatches_per_epoch + b
        if e % 100 == 0 and e != 0:
            model_filename = '%s/%s/%s_epoch_%g.dnn' % (model_root_path, name, name, e+1)
            z.save_model(model_filename)
            print("Saved model to '%s'" % model_filename)

It currently fails at

trainer.train_minibatch(arguments)

minibatch_dim is 356 and output_length is 356.

I mostly copied this code over from another of my LSTMs, but I keep getting this error.

How can I fix this?

1 answer:

Answer 0 (score: 0)

You need to supply one mask value per sequence in the minibatch: the i-th element of the list is True if the i-th sequence is a new sequence, and False if it is a continuation of the i-th sequence from the previous minibatch.
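Concretely, since each minibatch here holds 356 sequences, the mask must have 356 entries, not one (`[True]` / `[False]`). A minimal sketch of the fix, using a hypothetical helper `make_mask` (the training-loop lines are shown schematically in comments, matching the code in the question):

```python
def make_mask(num_sequences, first_minibatch):
    """One boolean per sequence in the minibatch: True means the sequence
    starts in this minibatch, False means it continues the same sequence
    from the previous minibatch."""
    return [first_minibatch] * num_sequences

# Inside the epoch loop the original code would become, schematically:
# mask = make_mask(356, True)           # first minibatch: all sequences are new
# for b in range(minibatches_per_epoch):
#     arguments = ({input_sequence: train[0][b],
#                   label_sequence: train[1][b]}, mask)
#     mask = make_mask(356, False)      # later minibatches: continuations
#     trainer.train_minibatch(arguments)
```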