tf.keras模型的输入形状不匹配。下面给出了带有堆栈跟踪的代码块。我正在使用hub.keraslayer作为我的第一层。该模型正准备使用Tensor Flow Federated(TFF)进行训练。模型的输入是可变长度的字符串。请提出一个出路。
#Making a Tensorflow Model
from tensorflow import keras
def create_keras_model():
encoder = hub.load("https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1")
return tf.keras.models.Sequential([
hub.KerasLayer(encoder, input_shape=[],dtype=tf.string,trainable=True),
keras.layers.Dense(32, activation='relu'),
keras.layers.Dense(16, activation='relu'),
keras.layers.Dense(1, activation='sigmoid'),
])
def model_fn():
# We _must_ create a new model here, and _not_ capture it from an external
# scope. TFF will call this within different graph contexts.
keras_model = create_keras_model()
return tff.learning.from_keras_model(
keras_model,
input_spec=preprocessed_example_dataset.element_spec,
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[tf.keras.metrics.Accuracy()])
iterative_process = tff.learning.build_federated_averaging_process(
model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-
packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling
BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with
constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-
packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling
BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with
constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:Model was constructed with shape (None,) for input
Tensor("keras_layer_input:0", shape=(None,), dtype=string), but it was called on an input
with incompatible shape (None, None).
WARNING:tensorflow:Model was constructed with shape (None,) for input
Tensor("keras_layer_input:0", shape=(None,), dtype=string), but it was called on an input
with incompatible shape (None, None).
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-27-68fa27e65b7e> in <module>()
3 model_fn,
4 client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
----> 5 server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
18 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in
wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
ValueError: in user code:
/usr/local/lib/python3.6/dist-
packages/tensorflow_federated/python/learning/federated_averaging.py:91 __call__ *
num_examples_sum = dataset.reduce(
/usr/local/lib/python3.6/dist-
packages/tensorflow_federated/python/learning/model_utils.py:152 forward_pass *
self._model.forward_pass(batch_input, training), model_lib.BatchOutput)
/usr/local/lib/python3.6/dist-
packages/tensorflow_federated/python/learning/keras_utils.py:391 forward_pass *
return self._forward_pass(batch_input, training=training)
/usr/local/lib/python3.6/dist-
packages/tensorflow_federated/python/learning/keras_utils.py:359 _forward_pass *
predictions = self._keras_model(inputs, training=training)
/usr/local/lib/python3.6/dist-packages/tensorflow_hub/keras_layer.py:222 call *
result = f()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/load.py:486
_call_attribute **
return instance.__call__(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:580 __call__
result = self._call(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:618 _call
results = self._stateful_fn(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2419 __call__
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2735
_maybe_define_function
*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2238
canonicalize_function_inputs
self._flat_input_signature)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2305
_convert_inputs_to_signature
format_error_message(inputs, input_signature))
ValueError: Python inputs incompatible with input_signature:
inputs: (
Tensor("batch_input:0", shape=(None, None), dtype=string))
input_signature: (
TensorSpec(shape=(None,), dtype=tf.string, name=None))