How to make predictions with an RNN model

Time: 2019-10-16 12:19:00

Tags: python tensorflow machine-learning recurrent-neural-network

My trained RNN model is saved in export_path, and I want to make predictions with the following code:

export_path = '/content/gdrive/My Drive/' +'/model/'+'20161004044008'

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["myTag"], export_path)         
    graph = tf.get_default_graph()

#    print(graph.get_operations())
    input = graph.get_tensor_by_name('input:0')
    output = graph.get_tensor_by_name('output:0')    
    print(sess.run(output,
                feed_dict={input: df_2}))

But it fails with an error:

INFO:tensorflow:Restoring parameters from /content/gdrive/My Drive//model/20161004044008/variables/variables
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-16-48c85ab7d8ff> in <module>()
      9     input = graph.get_tensor_by_name('input:0')
     10     output = graph.get_tensor_by_name('output:0')
---> 11     print(sess.run(output,feed_dict={input: df_2}))
     12 
     13 #compare: outPred = sess.run(outputPrediction, feed_dict={inputImage:sampleImages(test_data), isTraining:False})

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1154                 'Cannot feed value of shape %r for Tensor %r, '
   1155                 'which has shape %r' %
-> 1156                 (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
   1157           if not self.graph.is_feedable(subfeed_t):
   1158             raise ValueError('Tensor %s may not be fed.' % subfeed_t)

ValueError: Cannot feed value of shape (126, 60480) for Tensor 'input:0', which has shape '(?, 252, 1)'
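The error compares element counts. The fed array has 126 × 60480 = 7,620,480 values, while each sample of 'input:0' holds 252 × 1 values. The feed can therefore only pass the shape check after the data is reshaped so that the unknown batch dimension (the "?") absorbs the rest. The pure-Python sketch here does that arithmetic for a placeholder with exactly one unknown dimension. resolve_feed_shape is a hypothetical helper for illustration, not a TensorFlow API.

```python
from functools import reduce
from operator import mul
from typing import Optional, Sequence, Tuple


def resolve_feed_shape(data_shape: Sequence[int],
                       placeholder_shape: Sequence[Optional[int]]) -> Tuple[int, ...]:
    """Hypothetical helper: fill in the single unknown (None) dimension of
    placeholder_shape so it holds as many elements as data_shape.

    Raises ValueError when no such shape exists, i.e. when no reshape of the
    data could satisfy the placeholder.
    """
    total = reduce(mul, data_shape, 1)  # number of values actually being fed
    unknown = [i for i, d in enumerate(placeholder_shape) if d is None]
    if len(unknown) != 1:
        raise ValueError("expected exactly one unknown (None) dimension")
    # Product of the fixed dimensions, e.g. 252 * 1 for shape (?, 252, 1)
    known = reduce(mul, (d for d in placeholder_shape if d is not None), 1)
    if total % known != 0:
        raise ValueError("cannot fit %d values into shape %r"
                         % (total, tuple(placeholder_shape)))
    shape = list(placeholder_shape)
    shape[unknown[0]] = total // known
    return tuple(shape)


# Shapes taken from the error message: data (126, 60480), placeholder (?, 252, 1)
print(resolve_feed_shape((126, 60480), (None, 252, 1)))  # (30240, 252, 1)
```

Here the numbers divide evenly (60480 = 240 × 252). Assuming df_2 is a pandas DataFrame of numeric values, feeding df_2.values.reshape(-1, 252, 1) would pass the shape check. Whether those 30240 windows of 252 steps match how the training inputs were laid out is an assumption. It has to be checked against the training code, because a reshape that passes the check can still feed the model the wrong sequences.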

Can anyone help? Thanks.

0 answers:

No answers yet