我正在尝试建立一个简单的模型并保存未经训练的图层。 (我稍后会训练它)。我试图在不依赖Keras层的情况下使用tensorflow核心API,以便我可以更直接地控制我使用的内容并最大限度地提高与TFLite的兼容性。
import numpy as np
import tensorflow as tf
class BasicModel(tf.Module):
def __init__(self):
self.const = None
@tf.function(input_signature=[
tf.TensorSpec(shape=[None,20],dtype=tf.int32),
])
def rnn(self, captions):
# ENCODER
weights = tf.Variable(tf.random.normal([10000, 724]))#, shape=[vocab_size,embedding_dimension], name="embedding_weights")
embedding_output = tf.nn.embedding_lookup(weights,captions)
#activation is tanh for GRUCell
sequence = tf.unstack(embedding_output,num=20, axis=1)
cell = tf.compat.v1.nn.rnn_cell.GRUCell(20)
print(sequence)
gru_layer = tf.compat.v1.nn.static_rnn(cell, sequence, dtype=tf.float32)
return gru_layer
root = BasicModel()
concrete_function = root.rnn.get_concrete_function()
tf.saved_model.save(root,"model",concrete_function)
我希望有一个未经训练的模型可以保存,但是我收到一个错误:
Traceback (most recent call last):
File "model_tensorflow_2.py", line 24, in <module>
concrete_function = root.rnn.get_concrete_function()#tf.constant(images), tf.constant(captions), tf.constant(cap_lens))
File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 782, in get_concrete_function
return self._stateless_fn.get_concrete_function(*args, **kwargs)
File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1891, in get_concrete_function
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2150, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2041, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 915, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 358, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2658, in bound_method_wrapper
return wrapped_fn(*args, **kwargs)
File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 905, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:
model_tensorflow_2.py:13 rnn *
weights = tf.Variable(tf.random.normal([10000, 724]))#, shape=[vocab_size,embedding_dimension], name="embedding_weights")
/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:260 __call__
return cls._variable_v2_call(*args, **kwargs)
/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
shape=shape)
/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:65 getter
return captured_getter(captured_previous, **kwargs)
/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py:413 invalid_creator_scope
"tf.function-decorated function tried to create "
ValueError: tf.function-decorated function tried to create variables on non-first call.
答案 0 :(得分:0)
tf.function
不允许在非首次调用时创建变量,因为其语义尚不清楚:是否应在每次调用时重新创建变量?应该隐式缓存它们吗? (请参阅TF Summit 2019的“ tf.function and AutoGraph”演讲中的this bit。
一个常见的解决方法是让助手函数创建变量,并确保每个实例最多调用一次:
class BasicModel(tf.Module):
# ...
def _create_parameters(self, ...):
self._weights = tf.Variable(...)
self._parameters_created = True
def rnn(self, ...):
if not self._parameters_created:
self._create_parameters(...)
...