I'm trying to do transfer learning on an image classification problem in Google Colab, and when I run this code:
import tensorflow as tf
import tensorflow_hub as hub

# Setup input shape to the model
INPUT_SHAPE = [None, 244, 244, 3] # batch, height, width, colour channels
# Setup output shape of the model
OUTPUT_SHAPE = 120
# Setup model URL from TensorFlow Hub
MODEL_URL = "https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4"

# Create a function which builds a Keras model
def create_model(input_shape=INPUT_SHAPE, output_shape=OUTPUT_SHAPE, model_url=MODEL_URL):
    print("Building model with:", MODEL_URL)
    # Setup the model layers
    model = tf.keras.Sequential([
        hub.KerasLayer(MODEL_URL), # Layer 1 (input layer)
        tf.keras.layers.Dense(units=OUTPUT_SHAPE,
                              activation="softmax") # Layer 2 (output layer)
    ])
    # Compile the model
    model.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=["accuracy"]
    )
    # Build the model
    model.build(INPUT_SHAPE)
    return model

model = create_model()
model.summary()
I get this error:
Building model with: https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-43-0fd4f47c95c0> in <module>()
----> 1 model = create_model()
2 model.summary()
5 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
263 except Exception as e: # pylint:disable=broad-except
264 if hasattr(e, 'ag_error_metadata'):
--> 265 raise e.ag_error_metadata.to_exception(e)
266 else:
267 raise
ValueError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow_hub/keras_layer.py:229 call *
result = smart_cond.smart_cond(training,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/load.py:486 _call_attribute **
return instance.__call__(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:580 __call__
result = self._call(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:627 _call
self._initialize(args, kwds, add_initializers_to=initializers)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:506 _initialize
*args, **kwds))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2446 _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2777 _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2667 _create_graph_function
capture_by_value=self._capture_by_value),
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:981 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:441 wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/function_deserialization.py:261 restored_function_body
"\n\n".join(signature_descriptions)))
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (4 total):
* Tensor("inputs:0", shape=(None, 244, 244, 3), dtype=float32)
* False
* False
* 0.99
Keyword arguments: {}
Expected these arguments to match one of the following 4 option(s):
Option 1:
Positional arguments (4 total):
* TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')
* False
* True
* TensorSpec(shape=(), dtype=tf.float32, name='batch_norm_momentum')
Keyword arguments: {}
Option 2:
Positional arguments (4 total):
* TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')
* True
* False
* TensorSpec(shape=(), dtype=tf.float32, name='batch_norm_momentum')
Keyword arguments: {}
Option 3:
Positional arguments (4 total):
* TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')
* True
* True
* TensorSpec(shape=(), dtype=tf.float32, name='batch_norm_momentum')
Keyword arguments: {}
Option 4:
Positional arguments (4 total):
* TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')
* False
* False
* TensorSpec(shape=(), dtype=tf.float32, name='batch_norm_momentum')
Keyword arguments: {}
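I tried installing tf-nightly and older versions of TensorFlow to see whether that would make it run, but it didn't help. I also tried older versions of tensorflow_hub, which only led to more errors. I tried a factory reset of the Colab runtime and ran it again, but got the same error. If I comment out model.build(INPUT_SHAPE), the error does not appear. Beyond that, I'm not sure how to fix it.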
Answer 0 (score: 0)
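Two changes are needed in the code.

1. Use the feature-vector version of MobileNet V2 instead of the classification model. The classification model ends in an ImageNet classifier head, so it is only meant to be used as-is on the ImageNet Dataset; for transfer learning to your own 120 classes, use the feature vector and put your own Dense layer on top.
2. Add an Input Layer with the 224×224 size the module expects, as shown below: tf.keras.layers.InputLayer(input_shape=(224,224,3)). The TensorSpec entries in the error show that every signature of the SavedModel requires inputs of shape (None, 224, 224, 3), whereas INPUT_SHAPE in the question is 244×244.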
The complete working code is shown below:
import tensorflow as tf
import tensorflow_hub as hub

# Setup input shape to the model (the MobileNet V2 module expects 224x224 images)
INPUT_SHAPE = [None, 224, 224, 3] # batch, height, width, colour channels
# Setup output shape of the model
OUTPUT_SHAPE = 120
# Setup model URL from TensorFlow Hub (feature-vector variant, no ImageNet classifier head)
MODEL_URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"

# Create a function which builds a Keras model
def create_model(input_shape=INPUT_SHAPE, output_shape=OUTPUT_SHAPE, model_url=MODEL_URL):
    print("Building model with:", model_url)
    # Setup the model layers
    model = tf.keras.Sequential([
        # Explicit input layer; input_shape[1:] drops the batch dimension -> (224, 224, 3)
        tf.keras.layers.InputLayer(input_shape=tuple(input_shape[1:])),
        # Frozen MobileNet V2 feature extractor: one 1280-dim vector per image
        hub.KerasLayer(model_url, output_shape=[1280],
                       trainable=False),
        # New classification head for our own classes
        tf.keras.layers.Dense(units=output_shape,
                              activation="softmax")
    ])
    # Compile the model
    model.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=["accuracy"]
    )
    # Build the model
    model.build(input_shape)
    return model

model = create_model()
model.summary()
The output of the code above is:
Building model with: https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4
Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
keras_layer_8 (KerasLayer)   (None, 1280)              2257984
_________________________________________________________________
dense_8 (Dense)              (None, 120)               153720
=================================================================
Total params: 2,411,704
Trainable params: 153,720
Non-trainable params: 2,257,984
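The parameter counts in this summary can be checked by hand: a Dense layer has one weight per input/output pair plus one bias per output unit, and because the hub layer is created with trainable=False, only the Dense head counts as trainable. A small sketch of that arithmetic, where dense_params is a hypothetical helper and the hub layer's 2,257,984 parameters are simply copied from the summary above:

```python
def dense_params(input_units: int, output_units: int, use_bias: bool = True) -> int:
    """Parameter count of a fully connected layer: weight matrix plus optional biases."""
    weights = input_units * output_units
    biases = output_units if use_bias else 0
    return weights + biases


FEATURE_DIM = 1280       # output_shape=[1280] of the feature-vector KerasLayer
NUM_CLASSES = 120        # OUTPUT_SHAPE
HUB_PARAMS = 2_257_984   # keras_layer_8 params as reported by model.summary()

head = dense_params(FEATURE_DIM, NUM_CLASSES)  # 1280 * 120 + 120
total = HUB_PARAMS + head

print(head)        # 153720  -> "Trainable params" (only the Dense head is trainable)
print(HUB_PARAMS)  # 2257984 -> "Non-trainable params" (frozen hub layer)
print(total)       # 2411704 -> "Total params"
```

If the hub layer were created with trainable=True instead, its weights would move into the trainable count, which makes fine-tuning far more expensive than training only the 153,720-parameter head.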
For details on using MobileNet_V2 with your own Dataset, see this comprehensive TF Hub Tutorial.