tf.keras模型的输入形状不匹配

时间:2020-07-07 10:44:35

标签: tf.keras

tf.keras模型的输入形状不匹配。下面给出了带有堆栈跟踪的代码块。我正在使用hub.keraslayer作为我的第一层。该模型正准备使用Tensor Flow Federated(TFF)进行训练。模型的输入是可变长度的字符串。请提出一个出路。

#Making a Tensorflow Model
from tensorflow import keras

def create_keras_model():
 encoder = hub.load("https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1")
 return tf.keras.models.Sequential([
  hub.KerasLayer(encoder, input_shape=[],dtype=tf.string,trainable=True),
  keras.layers.Dense(32, activation='relu'),
  keras.layers.Dense(16, activation='relu'),
  keras.layers.Dense(1, activation='sigmoid'),
])

def model_fn():
# We _must_ create a new model here, and _not_ capture it from an external
# scope. TFF will call this within different graph contexts.
keras_model = create_keras_model()
return tff.learning.from_keras_model(
  keras_model,
  input_spec=preprocessed_example_dataset.element_spec,
  loss=tf.keras.losses.BinaryCrossentropy(),
  metrics=[tf.keras.metrics.Accuracy()])

iterative_process = tff.learning.build_federated_averaging_process(
model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

WARNING:tensorflow:From /usr/local/lib/python3.6/dist- 
packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling 
BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with 
constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist- 
packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling 
BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with 
constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:Model was constructed with shape (None,) for input 
Tensor("keras_layer_input:0", shape=(None,), dtype=string), but it was called on an input 
with incompatible shape (None, None).
WARNING:tensorflow:Model was constructed with shape (None,) for input 
Tensor("keras_layer_input:0", shape=(None,), dtype=string), but it was called on an input 
with incompatible shape (None, None).
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-27-68fa27e65b7e> in <module>()
  3     model_fn,
  4     client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
----> 5     server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

18 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in 
wrapper(*args, **kwargs)
966           except Exception as e:  # pylint:disable=broad-except
967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
969             else:
970               raise

ValueError: in user code:

/usr/local/lib/python3.6/dist- 
packages/tensorflow_federated/python/learning/federated_averaging.py:91 __call__  *
    num_examples_sum = dataset.reduce(
/usr/local/lib/python3.6/dist- 
 packages/tensorflow_federated/python/learning/model_utils.py:152 forward_pass  *
    self._model.forward_pass(batch_input, training), model_lib.BatchOutput)
/usr/local/lib/python3.6/dist- 
packages/tensorflow_federated/python/learning/keras_utils.py:391 forward_pass  *
    return self._forward_pass(batch_input, training=training)
/usr/local/lib/python3.6/dist- 
packages/tensorflow_federated/python/learning/keras_utils.py:359 _forward_pass  *
    predictions = self._keras_model(inputs, training=training)
/usr/local/lib/python3.6/dist-packages/tensorflow_hub/keras_layer.py:222 call  *
    result = f()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/load.py:486 
_call_attribute  **
    return instance.__call__(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:580 __call__
    result = self._call(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:618 _call
    results = self._stateful_fn(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2419 __call__
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2735 
_maybe_define_function
    *args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2238 
canonicalize_function_inputs
    self._flat_input_signature)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2305 
_convert_inputs_to_signature
    format_error_message(inputs, input_signature))

ValueError: Python inputs incompatible with input_signature:
  inputs: (
    Tensor("batch_input:0", shape=(None, None), dtype=string))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.string, name=None))

0 个答案:

没有答案