恢复和评估预训练的LSTM模型:张量不是该图的元素

时间:2019-05-21 11:25:03

标签: tensorflow lstm

我试图恢复一个预先训练的LSTM模型,并用它来评估新数据。但是它一直说张量不是该图的元素。我尝试了许多解决方案,但都没有结果。

def main():
    graph = tf.Graph()

with graph.as_default():

        x = tf.placeholder(tf.float32, shape=[None, 8], name='Input')
        y = tf.placeholder(tf.float32, shape=[None, _NUM_CLASSES], name='Output')

        global_step = tf.Variable(initial_value=0, trainable=False, name='global_step')
        keep_prob_ = tf.placeholder(tf.float32, name='keep')
        learning_rate_ = tf.placeholder(tf.float32, shape=[], name='learning_rate')

with graph.as_default():

        lstm_in = tf.layers.dense(x, lstm_size, activation=None)

        lstm_in = tf.split(lstm_in, 1, 0)

        lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
        drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob_)
        cell = tf.contrib.rnn.MultiRNNCell([drop] * lstm_layers)
        initial_state = cell.zero_state(_BATCH_SIZE, tf.float32)

with graph.as_default():
        outputs, final_state = tf.contrib.rnn.static_rnn(cell, lstm_in, dtype=tf.float32,
                                                 initial_state = initial_state)

        softmax = tf.layers.dense(outputs[-1], units=_NUM_CLASSES, name="logits")

        y_pred_cls = tf.argmax(softmax, axis=1)

        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=softmax, labels=y))
        tf.summary.scalar("Loss",loss)
        train_op = tf.train.AdamOptimizer(learning_rate_)

        gradients = train_op.compute_gradients(loss)
        capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients]
        optimizer = train_op.apply_gradients(capped_gradients)


        correct_prediction = tf.equal(y_pred_cls, tf.argmax(y, axis=1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar("Accuracy/train", accuracy)

with graph.as_default():
    saver = tf.train.Saver()
merged = tf.summary.merge_all()
sess = tf.Session(graph=graph)

        print("Trying to restore last checkpoint ...")
        last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=_SAVE_PATH)
        saver.restore(sess, save_path=last_chk_path)

def get_emg_array(msg):
        global counter, predicted_class_last, array

        val_state = sess.run(cell.zero_state(_BATCH_SIZE, tf.float32))
        emg_data = numpy.array([list(msg.data)])
        predicted_class = sess.run(y_pred_cls, feed_dict={x: emg_data, keep_prob_: 1.0, learning_rate_: learning_rate})

sess.close()

if __name__ == '__main__':
    main()
  

[错误] [1558437001.632465]:错误的回调:回溯(最近一次调用最近):文件   “ /opt/ros/kinetic/lib/python2.7/dist-packages/rospy/topics.py”,第   750,在_invoke_callback中       cb(msg)文件“ /home/luke/bionic_hand_ws/src/bionic_hand/scripts/EMG_LSTM.py”,行   114,在get_emg_array中       val_state = sess.run(cell.zero_state(_BATCH_SIZE,tf.float32))文件   “ /home/luke/venv2/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,   929行,正在运行       run_metadata_ptr)文件“ /home/luke/venv2/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,   _run中的第1137行       self._graph,提取,feed_dict_tensor,feed_handles = feed_handles)文件   “ /home/luke/venv2/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,   第471行,在 init 中       self._fetch_mapper = _FetchMapper.for_fetch(fetches)文件“ /home/luke/venv2/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,   第261行,在for_fetch中       返回_ListFetchMapper(fetch)文件“ /home/luke/venv2/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,   第370行,在 init       self._mappers = [_FetchMapper.for_fetch(fetch)用于在提取中进行提取]文件   “ /home/luke/venv2/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,   第261行,在for_fetch中       返回_ListFetchMapper(fetch)文件“ /home/luke/venv2/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,   第370行,在 init       self._mappers = [_FetchMapper.for_fetch(fetch)用于在提取中进行提取]文件   “ /home/luke/venv2/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,   第271行,在for_fetch中       返回_ElementFetchMapper(fetches,contraction_fn)文件“ /home/luke/venv2/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,   第307行,在 init       '张量。 (%s)'%(fetch,str(e)))ValueError:无法将获取参数解释为张量。   (张量   Tensor(“ MultiRNNCellZeroState_78 / DropoutWrapperZeroState / BasicLSTMCellZeroState / zeros:0”,   shape =(128,192),dtype = float32)不是此图的元素。)

0 个答案:

没有答案