Getting layer activations from a TensorFlow Estimator

Date: 2018-02-21 22:51:47

Tags: tensorflow

I have a neural network built as a TensorFlow Estimator, which I call classifier, and I want to print the activations of one of its layers, called pool5.

In the model function, I have:

if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {"last_layer": pool5}
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

Then in the main script, I have:

predictions = classifier.predict(input_fn=input_fn)
print(predictions["last_layer"])

But I get this error:

Traceback (most recent call last):
  File "C:/Users/John/AppData/Local/Programs/Python/Python35/Scripts/Estimator_5minutes.py", line 177, in <module>
    tf.app.run()
  File "C:\Users\John\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\platform\app.py", line 124, in run
    _sys.exit(main(argv))
  File "C:/Users/John/AppData/Local/Programs/Python/Python35/Scripts/Estimator_5minutes.py", line 152, in main
    print(predictions["last_layer"])
TypeError: 'generator' object is not subscriptable

1 answer:

Answer 0 (score: 1)

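This may not be the perfect answer, but here is what I did to solve the problem. classifier.predict() does not return a dict; it returns a generator that yields one prediction dict per input example, so it cannot be indexed with "last_layer" directly. Wrapping it in list() runs prediction over the whole input and gives a list of dicts, which I then saved to a .mat file: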

import scipy.io

predictions = list(classifier.predict(input_fn=input_fn))  # consume the generator
scipy.io.savemat('C:/activations.mat', {"activations": predictions})
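To print the activations instead of saving them, one option is to iterate over the generator and stack the per-example arrays with NumPy. The sketch below assumes each prediction dict has a "last_layer" entry holding a NumPy array of the same shape for every example. fake_predict and collect_activations are hypothetical helpers: fake_predict stands in for classifier.predict(input_fn=input_fn) so the example runs without TensorFlow.

```python
import numpy as np


def fake_predict():
    """Hypothetical stand-in for classifier.predict(input_fn=input_fn).

    Like Estimator.predict, it is a generator that yields one dict per
    input example instead of returning a single dict.
    """
    for i in range(3):
        yield {"last_layer": np.full((2, 2), float(i))}


def collect_activations(prediction_gen, key="last_layer"):
    """Consume the generator and stack one key into a (num_examples, ...) array."""
    return np.stack([pred[key] for pred in prediction_gen])


predictions = fake_predict()
# predictions["last_layer"] at this point would raise:
# TypeError: 'generator' object is not subscriptable
activations = collect_activations(predictions)
print(activations.shape)  # (3, 2, 2)

# A generator can only be consumed once; to look at just the first
# example, take it with next() from a fresh generator.
first = next(fake_predict())
print(first["last_layer"])
```

With the real estimator, the equivalent call would be collect_activations(classifier.predict(input_fn=input_fn)). np.stack requires every example's activation to have the same shape; if they differ, keep the plain list instead.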