Hub module in a custom model

Posted: 2019-11-06 21:23:19

Tags: tensorflow tensorflow2.0

I'm trying to implement this TensorFlow tutorial (https://www.tensorflow.org/tutorials/keras/text_classification_with_hub) without using the Sequential API.

The tutorial's version, built with the Keras Sequential API, works fine. Here hub_layer is a hub.KerasLayer object:

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(16, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

I implemented the same simple model by subclassing tf.keras.Model:

import tensorflow as tf
from tensorflow.keras import layers

class MyModel(tf.keras.Model):

    def __init__(self, hub_layer):
        super().__init__()
        self.embedding = hub_layer
        self.dense1 = layers.Dense(16, activation='relu')
        self.dense2 = layers.Dense(1, activation='sigmoid')

    def call(self, x):
        x = self.embedding(x)
        x = self.dense1(x)
        return self.dense2(x)

But TensorFlow keeps throwing:


    <ipython-input-103-d15973420a95>:11 call  *
        x = self.embedding(x)
    /tensorflow-2.0.0/python3.6/tensorflow_core/python/saved_model/load.py:436 _call_attribute
        return instance.__call__(*args, **kwargs)
    /tensorflow-2.0.0/python3.6/tensorflow_core/python/eager/def_function.py:457 __call__
        result = self._call(*args, **kwds)
    /tensorflow-2.0.0/python3.6/tensorflow_core/python/eager/def_function.py:524 _call
        *args, **kwds)
    /tensorflow-2.0.0/python3.6/tensorflow_core/python/eager/function.py:1650 canonicalize_function_inputs
        self._flat_input_signature)
    /tensorflow-2.0.0/python3.6/tensorflow_core/python/eager/function.py:1716 _convert_inputs_to_signature
        format_error_message(inputs, input_signature))

    ValueError: Python inputs incompatible with input_signature:
      inputs: (
        Tensor("input_1_23:0", shape=(None, 1), dtype=string))
      input_signature: (
        TensorSpec(shape=(None,), dtype=tf.string, name=None))
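The error says the function behind the hub layer was traced with a fixed input_signature of TensorSpec(shape=(None,), dtype=tf.string), i.e. a rank-1 batch of strings, while the model handed it a tensor of shape (None, 1). A minimal sketch that reproduces just that mismatch, assuming nothing about TF Hub itself: stand_in_embed is a hypothetical substitute (hash each sentence into 16 buckets and one-hot it), not the real embedding. It only shares the same input_signature.

```python
import tensorflow as tf


# Hypothetical stand-in for the hub module's function (NOT the real
# embedding): it declares the same rank-1 string input_signature.
@tf.function(input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.string)])
def stand_in_embed(sentences):
    ids = tf.strings.to_hash_bucket_fast(sentences, 16)
    return tf.one_hot(ids, depth=16)


column = tf.constant([["good movie"], ["bad movie"]])  # shape (2, 1)

try:
    stand_in_embed(column)
    rejected = False
except (ValueError, TypeError):
    # TF 2.0 raises ValueError ("Python inputs incompatible with
    # input_signature"); newer releases may raise TypeError instead.
    rejected = True

flat = tf.reshape(column, [-1])  # shape (2,), matches TensorSpec(shape=(None,))
embedded = stand_in_embed(flat)
print(rejected, embedded.shape)
```

The rank-2 tensor is rejected, while the same strings reshaped to rank 1 are accepted.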

What am I doing wrong? I use exactly the same compile and fit calls for both models.
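Going only by the shapes in the error, (None, 1) passed versus (None,) expected, one possible workaround is to flatten the input to rank 1 inside call before it reaches the hub layer. This is a sketch under that assumption, not a confirmed fix. stand_in_hub_layer is a hypothetical substitute for hub.KerasLayer so the snippet runs offline; in practice you would pass the real hub_layer to MyModel.

```python
import tensorflow as tf


# Hypothetical stand-in for hub_layer (not the real TF Hub embedding):
# it enforces the same rank-1 string input_signature shown in the error.
@tf.function(input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.string)])
def stand_in_hub_layer(sentences):
    ids = tf.strings.to_hash_bucket_fast(sentences, 16)
    return tf.one_hot(ids, depth=16)


class MyModel(tf.keras.Model):
    def __init__(self, hub_layer):
        super().__init__()
        self.embedding = hub_layer
        self.dense1 = tf.keras.layers.Dense(16, activation="relu")
        self.dense2 = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, x):
        # Flatten (batch, 1) to (batch,) so the input matches the hub
        # layer's TensorSpec(shape=(None,)); a (batch,) input is unchanged.
        x = tf.reshape(x, [-1])
        x = self.embedding(x)
        x = self.dense1(x)
        return self.dense2(x)


model = MyModel(stand_in_hub_layer)
probs = model(tf.constant([["good movie"], ["bad movie"]]))
print(probs.shape)
```

The reshape is only safe if every example is a single string; with several strings per row it would silently change the batch size.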

0 answers