Simple RNN model doesn't work because of a tf.Variable issue in TF2

Posted: 2019-10-02 20:39:08

Tags: tensorflow

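I'm trying to build a simple model and save the untrained layers (I'll train it later). I'm trying to use the TensorFlow core API without relying on Keras layers, so that I have more direct control over what I use and get the best possible compatibility with TFLite.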

import numpy as np
import tensorflow as tf

class BasicModel(tf.Module):
    def __init__(self):
        self.const = None

    @tf.function(input_signature=[
            tf.TensorSpec(shape=[None,20],dtype=tf.int32),
    ])
    def rnn(self, captions):
        # ENCODER
        weights = tf.Variable(tf.random.normal([10000, 724]))#, shape=[vocab_size,embedding_dimension], name="embedding_weights")
        embedding_output = tf.nn.embedding_lookup(weights,captions)
        #activation is tanh for GRUCell
        sequence = tf.unstack(embedding_output,num=20, axis=1) 
        cell = tf.compat.v1.nn.rnn_cell.GRUCell(20)
        print(sequence)
        gru_layer = tf.compat.v1.nn.static_rnn(cell, sequence, dtype=tf.float32)
        return gru_layer

root = BasicModel()
concrete_function = root.rnn.get_concrete_function()
tf.saved_model.save(root,"model",concrete_function)

I expected to end up with a saved, untrained model, but instead I get an error:

Traceback (most recent call last):
  File "model_tensorflow_2.py", line 24, in <module>
    concrete_function = root.rnn.get_concrete_function()#tf.constant(images), tf.constant(captions), tf.constant(cap_lens))
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 782, in get_concrete_function
    return self._stateless_fn.get_concrete_function(*args, **kwargs)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1891, in get_concrete_function
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2150, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2041, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 358, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2658, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 905, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:

    model_tensorflow_2.py:13 rnn  *
        weights = tf.Variable(tf.random.normal([10000, 724]))#, shape=[vocab_size,embedding_dimension], name="embedding_weights")
    /Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:260 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
        shape=shape)
    /Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:65 getter
        return captured_getter(captured_previous, **kwargs)
    /Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py:413 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

1 Answer

Answer 0 (score: 0)

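tf.function does not allow creating variables on non-first calls, because the semantics of doing so are unclear: should the variables be recreated on every call? Should they be cached implicitly? (See this bit of the "tf.function and AutoGraph" talk from TF Summit 2019.)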
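The restriction can be reproduced in isolation with the minimal sketch below (the function name is hypothetical). The error appears even though the function is called only once; presumably this is because tf.function traces the Python function again after the trace that created variables, and on that later trace the unconditional tf.Variable(...) call counts as creation on a non-first call. The exact wording of the message varies between TensorFlow versions, but it is raised as a ValueError.

```python
import tensorflow as tf


@tf.function
def creates_variable_every_trace(x):
    # Unconditional variable creation: every trace of this function
    # constructs a new tf.Variable, just like `weights` in the question.
    v = tf.Variable(1.0)
    return v + x


try:
    creates_variable_every_trace(tf.constant(1.0))  # a single call
    error = None
except ValueError as e:
    error = e

print(type(error).__name__, error)
```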

A common workaround is to have a helper function create the variables, and to make sure it is called at most once per instance:

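class BasicModel(tf.Module):
    def __init__(self):
        super().__init__()
        # The flag must exist before rnn() reads it for the first time
        self._parameters_created = False
        # ...

    def _create_parameters(self, ...):
        self._weights = tf.Variable(...)
        self._parameters_created = True

    @tf.function(input_signature=[...])
    def rnn(self, ...):
        # Only the first trace sees False, so the variables are created once
        if not self._parameters_created:
            self._create_parameters(...)
        ...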
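Applied to the question's model, the pattern might look like the following. This is a minimal sketch under stated assumptions, not a drop-in replacement: the tf.compat.v1 GRUCell/static_rnn pair is swapped for a hand-written tanh RNN cell built from core ops (the weight names _w_x, _w_h and _b are hypothetical), and rnn returns a flat dict of tensors, since SavedModel signature outputs are expected to be flat while static_rnn returns a nested (list of outputs, state) tuple.

```python
import tempfile

import tensorflow as tf

# Dimensions taken from the question's code; HIDDEN_UNITS mirrors GRUCell(20).
VOCAB_SIZE = 10000
EMBEDDING_DIM = 724
SEQ_LEN = 20
HIDDEN_UNITS = 20


class BasicModel(tf.Module):
    def __init__(self):
        super().__init__()
        # Plain Python flag: it is read at trace time, not inside the graph.
        self._parameters_created = False

    def _create_parameters(self):
        # Runs during the first trace only; later traces skip it.
        self._embedding = tf.Variable(
            tf.random.normal([VOCAB_SIZE, EMBEDDING_DIM]), name="embedding_weights")
        # Hypothetical weights for a plain tanh RNN cell (stand-in for GRUCell).
        self._w_x = tf.Variable(
            tf.random.normal([EMBEDDING_DIM, HIDDEN_UNITS], stddev=0.01), name="w_x")
        self._w_h = tf.Variable(
            tf.random.normal([HIDDEN_UNITS, HIDDEN_UNITS], stddev=0.01), name="w_h")
        self._b = tf.Variable(tf.zeros([HIDDEN_UNITS]), name="b")
        self._parameters_created = True

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, SEQ_LEN], dtype=tf.int32),
    ])
    def rnn(self, captions):
        if not self._parameters_created:
            self._create_parameters()
        # [batch, SEQ_LEN, EMBEDDING_DIM]
        embedded = tf.nn.embedding_lookup(self._embedding, captions)
        state = tf.zeros([tf.shape(captions)[0], HIDDEN_UNITS])
        outputs = []
        # tf.unstack returns a Python list, so this loop is unrolled at trace time.
        for x_t in tf.unstack(embedded, num=SEQ_LEN, axis=1):
            state = tf.tanh(
                tf.matmul(x_t, self._w_x) + tf.matmul(state, self._w_h) + self._b)
            outputs.append(state)
        # A flat dict of tensors can be used as a SavedModel signature output.
        return {"outputs": tf.stack(outputs, axis=1), "final_state": state}


root = BasicModel()
concrete_function = root.rnn.get_concrete_function()  # no ValueError now
export_dir = tempfile.mkdtemp()
tf.saved_model.save(root, export_dir, signatures=concrete_function)
```

If you keep tf.compat.v1.nn.rnn_cell.GRUCell instead, note that the cell creates its own weights the first time it is called, so the cell object itself presumably also has to be constructed once (for example inside the helper) rather than inside rnn on every trace.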