CNTK python API:如何从训练模型中获得预测?

时间:2017-03-07 11:15:13

标签: python eval prediction cntk

我有一个训练有素的模型,我使用CNTK.load_model()函数加载。我正在查看CNTK git repo上的MNIST Tutorial作为模型评估代码的参考。我创建了一个数据读取器(MinibatchSource个对象)并尝试运行model.eval(mb) mb = minibatch_source.next_minibatch(...)(类似于this answer

但是,我收到以下错误消息

Traceback (most recent call last):
    File "LID_test.py", line 162, in <module>
        test_and_evaluate()
    File "LID_test.py", line 159, in test_and_evaluate
        predictions = model.eval(mb)
    File "/home/t-asbahe/anaconda3/envs/cntk-py35/lib/python3.5/site-packages/cntk/ops/functions.py", line 228, in eval
        _, output_map = self.forward(arguments, self.outputs, device=device, as_numpy=as_numpy)
    File "/home/t-asbahe/anaconda3/envs/cntk-py35/lib/python3.5/site-packages/cntk/utils/swig_helper.py", line 62, in wrapper
        result = f(*args, **kwds)
    File "/home/t-asbahe/anaconda3/envs/cntk-py35/lib/python3.5/site-packages/cntk/ops/functions.py", line 354, in forward
        None, device)
    File "/home/t-asbahe/anaconda3/envs/cntk-py35/lib/python3.5/site-packages/cntk/utils/__init__.py", line 393, in sanitize_var_map
        if len(arguments) < len(op_arguments):
TypeError: object of type 'Variable' has no len()

我的模型中没有input_variable名为'Variable',我认为没有任何理由可以解决此错误。

P.S。:我的输入是稀疏输入(单热)

1 个答案:

答案 0 :(得分:2)

您有几个选择:

  • 将一组数据作为numpy数组(CNTK 202教程中的实例)传递,其中onehot数据作为numpy数组传入。

    pred = model.eval({model.arguments [0]:[onehot]})

  • 读取minibatch数据并将其传递给eval函数

    eval_input_map = {input:reader_eval.streams.features}
    eval_data = reader_eval.next_minibatch(eval_minibatch_size,                                   input_map = eval_input_map) mydata = eval_data [输入] .value 预测= model.eval(mydata)