I have a neural network TensorFlow Estimator, which I call classifier, and I want to print out the activations from one layer in the network, called pool5.
In the model function, I have:
if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {"last_layer": pool5}
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
Then, in the main script, I have:
predictions = classifier.predict(input_fn=input_fn)
print(predictions["last_layer"])
But I get this error:
Traceback (most recent call last):
  File "C:/Users/John/AppData/Local/Programs/Python/Python35/Scripts/Estimator_5minutes.py", line 177, in <module>
    tf.app.run()
  File "C:\Users\John\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\platform\app.py", line 124, in run
    _sys.exit(main(argv))
  File "C:/Users/John/AppData/Local/Programs/Python/Python35/Scripts/Estimator_5minutes.py", line 152, in main
    print(predictions["last_layer"])
TypeError: 'generator' object is not subscriptable
Answer (score: 1)
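This may not be the perfect answer, but here is what I did to solve the problem. classifier.predict returns a generator that yields one predictions dict per input example, which is why it cannot be indexed with ["last_layer"]. Wrapping it in list() runs the whole prediction loop and gives a list of those dicts, which can then be saved: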
import scipy.io
predictions = list(classifier.predict(input_fn=input_fn))  # consume the generator into a list of per-example dicts
scipy.io.savemat('C:/activations.mat', {"activations": predictions})
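Estimator.predict is a generator that yields one dict per example and stops only when input_fn runs out of data. The minimal pure-Python sketch below reproduces the original TypeError and shows two ways around it. fake_predict is a hypothetical stand-in for classifier.predict, not TensorFlow code. Calling it with num_examples=None mimics an input_fn whose dataset repeats forever. In that case list() would never return, so itertools.islice caps the number of predictions taken.

```python
from itertools import count, islice


def fake_predict(num_examples=None):
    """Hypothetical stand-in for classifier.predict (not TensorFlow code).

    Like Estimator.predict, it is a generator yielding one predictions dict
    per example. With num_examples=None it never stops, mimicking an
    input_fn whose dataset repeats forever.
    """
    indices = count() if num_examples is None else range(num_examples)
    for i in indices:
        # Each dict mirrors the `predictions` dict returned in PREDICT mode.
        yield {"last_layer": [float(i), float(i) * 2]}


# 1. Subscripting the generator itself reproduces the error from the question.
try:
    fake_predict(3)["last_layer"]
except TypeError as err:
    print(err)  # 'generator' object is not subscriptable

# 2. A finite generator can be consumed with list(), as in the answer.
predictions = list(fake_predict(3))
print(predictions[0]["last_layer"])  # [0.0, 0.0]

# 3. An endless generator must be capped; list() on it would never return.
activations = [p["last_layer"] for p in islice(fake_predict(), 4)]
print(len(activations))  # 4
```

In the real script, you would pass classifier.predict(input_fn=input_fn) where fake_predict appears. If NumPy is available, np.stack(activations) combines the per-example activations into one array with a leading example dimension before it is passed to scipy.io.savemat.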