无法还原RNN模型

时间:2018-06-29 08:42:02

标签: tensorflow lstm restore recurrent-neural-network

我已经在Tensorflow中训练了RNN模型,并使用tf.train.Saver()保存了它,但是当我尝试恢复它时,它会很生气。我一直在寻找原因,但仍然找不到任何解决方案。在训练期间,我使用features/testing.fest来查找准确性,该准确性没有任何错误。但是当我在还原后尝试相同的操作时,会出现错误

RNN模型值也与我训练的相同。

这是predict.py的代码

import tensorflow as tf
import numpy as np
import pickle

from models import RecurNet_Models
from preprocess_data import *

timesteps = 36
input_size = 52
hidden_features = 512
lexicon_file = 'dataset.lexi'
saved_model_path = 'trained_model/model'
model = RecurNet_Models.Simple_RNN(hidden_features=hidden_features, no_classes=2, timesteps=timesteps)

x = tf.placeholder('float', [None, timesteps, input_size])
y = tf.placeholder('float', [None, 2])

def predict(x, model=None, saved_model_path=None, string=None, y=None):
    prediction = model(x)

    saver = tf.train.Saver()

    correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver.restore(sess, saved_model_path)

        print("Accuracy: {}".format(accuracy.eval({x: test_img, y: test_labels})))



with open(lexicon_file, 'rb') as f:
    lexicon = pickle.load(f)

with open('features/testing.feat', 'rb') as f:
    [string, label] = pickle.load(f)

string = np.reshape(string, [-1, timesteps, input_size])
predict(x, model=model.model, saved_model_path=saved_model_path, string=string, y=label)

这是错误:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1322, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1307, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1409, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [564,2048] rhs shape= [104,208]
     [[Node: save/Assign_3 = Assign[T=DT_FLOAT, _class=["loc:@rnn/basic_lstm_cell/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/basic_lstm_cell/kernel, save/RestoreV2:3)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "predict.py", line 42, in <module>
    predict(x, model=model.model, saved_model_path=saved_model_path, string=string, y=label)
  File "predict.py", line 29, in predict
    saver.restore(sess, saved_model_path)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py", line 1752, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [564,2048] rhs shape= [104,208]
     [[Node: save/Assign_3 = Assign[T=DT_FLOAT, _class=["loc:@rnn/basic_lstm_cell/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/basic_lstm_cell/kernel, save/RestoreV2:3)]]

Caused by op 'save/Assign_3', defined at:
  File "predict.py", line 42, in <module>
    predict(x, model=model.model, saved_model_path=saved_model_path, string=string, y=label)
  File "predict.py", line 21, in predict
    saver = tf.train.Saver()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py", line 1284, in __init__
    self.build()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py", line 1296, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py", line 1333, in _build
    build_save=build_save, build_restore=build_restore)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py", line 781, in _build_internal
    restore_sequentially, reshape)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py", line 422, in _AddRestoreOps
    assign_ops.append(saveable.restore(saveable_tensors, shapes))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py", line 113, in restore
    self.op.get_shape().is_fully_defined())
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/state_ops.py", line 219, in assign
    validate_shape=validate_shape)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 60, in assign
    use_locking=use_locking, name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 3414, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 1740, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [564,2048] rhs shape= [104,208]
     [[Node: save/Assign_3 = Assign[T=DT_FLOAT, _class=["loc:@rnn/basic_lstm_cell/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/basic_lstm_cell/kernel, save/RestoreV2:3)]]

我正在Google Colab中训练模型,测试相同代码时出现的错误可以解决这个问题,希望对您有帮助

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-41-4c07d32efced> in <module>()
     41 
     42 string = np.reshape(string, [-1, timesteps, input_size])
---> 43 predict(x, model=model.model, saved_model_path=saved_model_path, string=string, y=label)

<ipython-input-41-4c07d32efced> in predict(x, model, saved_model_path, string, y)
     18 def predict(x, model=None, saved_model_path=None, string=None, y=None):
     19     tf.reset_default_graph()
---> 20     prediction = model(x)
     21 
     22     saver = tf.train.Saver()

/content/sentiment/models/RecurNet_Models.py in model(self, x)
     15 
     16         lstm_instance = rnn.BasicLSTMCell(self.hidden_features)
---> 17         rnn_output, states = rnn.static_rnn(lstm_instance, x, dtype=tf.float32)
     18 
     19         output = tf.add(tf.matmul(rnn_output[-1], weights), biases)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn.py in static_rnn(cell, inputs, initial_state, dtype, sequence_length, scope)
   1313             state_size=cell.state_size)
   1314       else:
-> 1315         (output, state) = call_cell()
   1316 
   1317       outputs.append(output)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn.py in <lambda>()
   1300         varscope.reuse_variables()
   1301       # pylint: disable=cell-var-from-loop
-> 1302       call_cell = lambda: cell(input_, state)
   1303       # pylint: enable=cell-var-from-loop
   1304       if sequence_length is not None:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn_cell_impl.py in __call__(self, inputs, state, scope, *args, **kwargs)
    337     # method.  See the class docstring for more details.
    338     return base_layer.Layer.__call__(self, inputs, state, scope=scope,
--> 339                                      *args, **kwargs)
    340 
    341 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/layers/base.py in __call__(self, inputs, *args, **kwargs)
    327 
    328       # Actually call layer
--> 329       outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
    330 
    331     if not context.executing_eagerly():

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    686 
    687       if not in_deferred_mode:
--> 688         outputs = self.call(inputs, *args, **kwargs)
    689         if outputs is None:
    690           raise ValueError('A layer\'s `call` method should return a Tensor '

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn_cell_impl.py in call(self, inputs, state)
    636 
    637     gate_inputs = math_ops.matmul(
--> 638         array_ops.concat([inputs, h], 1), self._kernel)
    639     gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    640 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py in matmul(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, name)
   1944       are both set to True.
   1945   """
-> 1946   with ops.name_scope(name, "MatMul", [a, b]) as name:
   1947     if transpose_a and adjoint_a:
   1948       raise ValueError("Only one of transpose_a and adjoint_a can be True.")

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in __enter__(self)
   6001       if self._values is None:
   6002         self._values = []
-> 6003       g = _get_graph_from_inputs(self._values)
   6004       self._g_manager = g.as_default()
   6005       self._g_manager.__enter__()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _get_graph_from_inputs(op_input_list, graph)
   5661         graph = graph_element.graph
   5662       elif original_graph_element is not None:
-> 5663         _assert_same_graph(original_graph_element, graph_element)
   5664       elif graph_element.graph is not graph:
   5665         raise ValueError("%s is not from the passed-in graph." % graph_element)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _assert_same_graph(original_item, item)
   5597   if original_item.graph is not item.graph:
   5598     raise ValueError("%s must be from the same graph as %s." % (item,
-> 5599                                                                 original_item))
   5600 
   5601 

ValueError: Tensor("rnn/basic_lstm_cell/kernel:0", shape=(563, 2048), dtype=float32_ref) must be from the same graph as Tensor("concat:0", shape=(?, 563), dtype=float32).

0 个答案:

没有答案