How do I export the recommendation Keras model from the TensorFlow examples to tflite?

Time: 2020-03-29 08:44:21

Tags: python tensorflow machine-learning keras tf-lite

I'm very new to TensorFlow and machine learning. I'm using one of the examples from TensorFlow's GitHub page: tensorflow/models/official/recommendation/

https://github.com/tensorflow/models/tree/master/official/recommendation

I run movielens.py and then, following the instructions, ncf_keras_main.py. Everything works, and I do get output in the target folder ~/ml-small-model/:

ml-small-model/
    checkpoint
    checkpoint.data-0000-of-00001
    checkpoint.index
    summaries/
        train/
            events.out.tfevents.1585469662.DESKTOP-CAEB4HN.24292.515.v2
            plugins/
                2020_03_29_08_14_24/
                    DESKTOP-CAEB4HN.input_pipeline.pb
                    DESKTOP-CAEB4HN.kernel_stats.pb
                    DESKTOP-CAEB4HN.overview_page.pb
                    DESKTOP-CAEB4HN.tensorflow_stats.pb
                    DESKTOP-CAEB4HN.trace.json.gz
        validation/
            events.out.tfevents.1585469676.DESKTOP-CAEB4HN.24292.7424.v2
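For context: the file names above (`checkpoint`, `checkpoint.index`, `checkpoint.data-*`) look like a standard TensorFlow weights checkpoint, and the `.pb` files under `summaries/train/plugins/` appear to be TensorFlow Profiler data for TensorBoard rather than model graphs. A checkpoint holds variable values only, with no graph or input signature, so it is not something the TFLite converter can consume directly. The sketch below writes a tiny checkpoint and lists its contents; `TinyModel` is a hypothetical stand-in, not the NCF model.

```python
import os
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for a trained model; only its variables matter here."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.zeros([8, 4]), name="w")
        self.b = tf.Variable(tf.zeros([4]), name="b")


ckpt_dir = tempfile.mkdtemp()
# write() produces <prefix>.index plus <prefix>.data-* shard files.
ckpt_path = tf.train.Checkpoint(model=TinyModel()).write(os.path.join(ckpt_dir, "checkpoint"))

# Each entry is (variable key, shape): saved values only, no graph or signature.
variables = dict(tf.train.list_variables(ckpt_path))
for key, shape in sorted(variables.items()):
    print(key, shape)
```

Running `tf.train.list_variables` on the real `~/ml-small-model/` directory should show the NCF weights in the same way.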

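The problem is that I want to deploy this trained model to Firebase ML Kit as a tflite file. I've looked at the examples on Google's site and at other posts on Stack Overflow, and there seem to be two ways to go about this: use a SavedModel-style .pb file, or use the TFLiteConverter. I've only seen examples that convert a single .pb file, but I have several, so I chose the latter. I put this code in ncf_keras_main.py, in the run_ncf(_) method, on the line just before the return statement, since that seems to be the last thing the program runs: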

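  # Convert and export to tflite.
  # Assumes `import pathlib` and `import tensorflow as tf` at the top of ncf_keras_main.py.
  tflite_model_files = pathlib.Path("C:/ml-small-model/ml-small.tflite")
  converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
  tflite_model = converter.convert()
  tflite_model_files.write_bytes(tflite_model)
  # Also save a copy to the current working directory.
  with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)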

However, I get a ValueError:

Traceback (most recent call last):
  File ".\ncf_keras_main.py", line 578, in <module>
    app.run(main)
  File "!\python\python37\lib\site-packages\absl\app.py", line 299, in run
    _run_main(main, args)
  File "!\python\python37\lib\site-packages\absl\app.py", line 250, in _run_main
    sys.exit(main(argv))
  File ".\ncf_keras_main.py", line 573, in main
    run_ncf(FLAGS)
  File ".\ncf_keras_main.py", line 347, in run_ncf
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\lite\python\lite.py", line 503, in from_keras_model
    concrete_func = func.get_concrete_function()
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 958, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 864, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 506, in _initialize
    *args, **kwds))
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\eager\function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\eager\function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\eager\function.py", line 2667, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\saving\saving_utils.py", line 132, in _wrapped_model
    outputs = model(inputs, training=False)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 933, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\engine\network.py", line 719, in call
    convert_kwargs_to_constants=base_layer_utils.call_context().saving)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\engine\network.py", line 888, in _run_internal_graph
    output_tensors = layer(computed_tensors, **kwargs)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 933, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 309, in wrapper
    return func(*args, **kwargs)
  File ".\ncf_keras_main.py", line 80, in call
    self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1360, in add_metric
    self._add_metric(value, aggregation, name)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 2148, in _add_metric
    metric_obj(value)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\metrics.py", line 214, in __call__
    replica_local_fn, *args, **kwargs)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\distribute\distributed_training_utils.py", line 1133, in call_replica_local_fn
    return fn(*args, **kwargs)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\metrics.py", line 194, in replica_local_fn
    update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\utils\metrics_utils.py", line 90, in decorated
    update_op = update_state_fn(*args, **kwargs)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\keras\metrics.py", line 356, in update_state
    update_total_op = self.total.assign_add(value_sum)
  File "~\movie_data\models\official\recommendation\venv\lib\site-packages\tensorflow\python\distribute\values.py", line 1004, in assign_add
    "SyncOnReadVariable does not support `assign_add` in "
ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`.

I censored the directory paths because they contained some personal info, so apologies if they look odd.

I've searched for this error, but it doesn't seem to have come up for anyone else.
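Reading the traceback, the failure happens while `from_keras_model` traces `model(inputs, training=False)`: that runs the custom layer's `call` at line 80 of `ncf_keras_main.py`, whose `self.add_metric(hr_sum, ...)` tries to update a metric variable. The message says that variable is a `SyncOnReadVariable`, which presumably means it was created under a `tf.distribute` strategy scope, and it cannot be updated in the cross-replica context the converter traces in. One workaround worth trying is to convert only the prediction network (user/item inputs to logits) and leave out the training wrapper that adds metric layers. The sketch below assumes such a sub-model can be obtained; `base_model` and its layers are a hypothetical stand-in built from scratch, not the actual NCF architecture from the script.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the NCF prediction network: (user_id, item_id) -> logit.
# It deliberately contains no layers that call add_metric() or add_loss().
user_id = tf.keras.Input(shape=(1,), dtype="int32", name="user_id")
item_id = tf.keras.Input(shape=(1,), dtype="int32", name="item_id")
user_vec = tf.keras.layers.Flatten()(tf.keras.layers.Embedding(100, 8)(user_id))
item_vec = tf.keras.layers.Flatten()(tf.keras.layers.Embedding(50, 8)(item_id))
logit = tf.keras.layers.Dense(1)(tf.keras.layers.Concatenate()([user_vec, item_vec]))
base_model = tf.keras.Model(inputs=[user_id, item_id], outputs=logit)

# Convert only the inference network, not a training wrapper with metric layers.
converter = tf.lite.TFLiteConverter.from_keras_model(base_model)
tflite_model = converter.convert()

# Sanity check: run the converted model once on all-zero dummy inputs.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
for detail in interpreter.get_input_details():
    interpreter.set_tensor(detail["index"], np.zeros(detail["shape"], dtype=detail["dtype"]))
interpreter.invoke()
prediction = interpreter.get_tensor(interpreter.get_output_details()[0]["index"])
print(prediction.shape)
```

Whether the real script exposes a metric-free sub-model like this is an assumption; if it does not, one would need to rebuild the prediction network and load the trained weights into it.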

So, in the end, I have two questions:

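  1. How can I get a tflite output file from the TensorFlow models recommendation example? (Maybe I can work with the .pb files instead?)
  2. If getting a tflite file and deploying it to Firebase is more complicated than the alternatives, how else could I deploy this ML model to get recommendations?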
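On question 1, the "SavedModel .pb" route usually means exporting a SavedModel directory (a `saved_model.pb` plus a `variables/` folder) and converting that, rather than converting the `.pb` files under `summaries/`, which do not appear to be model graphs. A minimal sketch, assuming the scoring logic can be wrapped in a `tf.function` with a fixed input signature; `TinyScorer` and its embedding-dot-product scoring are hypothetical stand-ins, not the NCF model.

```python
import tempfile

import numpy as np
import tensorflow as tf


class TinyScorer(tf.Module):
    """Hypothetical stand-in: scores (user, item) pairs by an embedding dot product."""

    def __init__(self, num_users=100, num_items=50, dim=8):
        super().__init__()
        self.user_emb = tf.Variable(tf.random.normal([num_users, dim]))
        self.item_emb = tf.Variable(tf.random.normal([num_items, dim]))

    @tf.function(input_signature=[
        tf.TensorSpec([None], tf.int32, name="user_id"),
        tf.TensorSpec([None], tf.int32, name="item_id"),
    ])
    def score(self, user_id, item_id):
        user_vec = tf.gather(self.user_emb, user_id)
        item_vec = tf.gather(self.item_emb, item_id)
        return {"score": tf.reduce_sum(user_vec * item_vec, axis=-1)}


scorer = TinyScorer()
export_dir = tempfile.mkdtemp()
# Writes export_dir/saved_model.pb plus export_dir/variables/.
tf.saved_model.save(scorer, export_dir, signatures={"serving_default": scorer.score})

converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model = converter.convert()

# Call the converted model through its named signature and compare with TensorFlow.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
runner = interpreter.get_signature_runner("serving_default")
users = np.array([1, 2, 3], dtype=np.int32)
items = np.array([4, 5, 6], dtype=np.int32)
tflite_scores = runner(user_id=users, item_id=items)["score"]
tf_scores = scorer.score(tf.constant(users), tf.constant(items))["score"].numpy()
print(np.allclose(tflite_scores, tf_scores, atol=1e-5))
```

The explicit `input_signature` is what gives the SavedModel (and the resulting tflite file) stable, named inputs, which is also convenient when calling the model from a mobile client.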

Any other tips would also be appreciated.
