从train_minibatch()收集输出

时间:2017-10-19 18:11:49

标签: cntk

我的代码段的相关部分如下:

 feature_output = network['output'].find_by_name('fc8').outputs
 _, output = trainer.train_minibatch(data, (feature_output))
print(output.keys())
print(output[dict_keys(feature_output]))

它给我一个错误如下:

dict_keys([Output('fc8', [#], [1000])])
Traceback (most recent call last):
  File "trainoverfeataccurate.py", line 325, in <module>
    warm_up=0, max_epochs=epochs)
  File "trainoverfeataccurate.py", line 250, in overfeataccuratetraining
    restore, profiling, print_freq=1)
  File "trainoverfeataccurate.py", line 145, in train_and_test
    print(output[feature_output])
KeyError: (Output('fc8', [#], [1000]),)

我也试过_, output = trainer.train_minibatch(data,{'a' : feature_output}) 但它给了我以下错误 TypeError: cannot convert key of dictionary to N4CNTK8VariableE

使用train_minibatch时收集输出的正确方法是什么?

1 个答案:

答案 0 :(得分:0)

网络的输出是元组,因此您需要获取第一个元素。然后你要找的名字是uid。 E.g。

@Injectable()
export class MessageService {
    private subject = new Subject<any>();
    message$: Observable<any> = this.subject.asObservable();

    sendMessage(message: string) {
       console.log('send message');
        this.subject.next(message);
    }

    clearMessage() {
       this.subject.next();
    }
}