Converting ALBERT to TFLite (ALBERT implemented in Keras via bert-for-tf2)

Posted: 2019-12-11 06:05:11

Tags: tensorflow tensorflow-lite

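I'm having a hard time converting ALBERT (more specifically, the albert_base model) to TFLite. Here is my code, which defines the model with bert-for-tf2 (https://github.com/kpe/bert-for-tf2). Thanks for the excellent implementation, by the way!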

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Flatten, AveragePooling1D
from tensorflow.keras.models import Model
import bert
import sentencepiece as spm


def load_pretrained_albert():
    model_name = "albert_base"
    albert_dir = bert.fetch_tfhub_albert_model(model_name, ".models")
    model_params = bert.albert_params(model_name)
    l_bert = bert.BertModelLayer.from_params(model_params, name="albert")

    # use in Keras Model here, and call model.build()
    max_seq_len = 128

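    # token IDs; note that bert-for-tf2's README examples declare this input with dtype='int32'
    l_input_ids = Input(shape=(max_seq_len,), dtype='float32', name="l_input_ids")

    output = l_bert(l_input_ids)   # sequence output: [batch_size, max_seq_len, hidden_size=768]
    pooled_output = AveragePooling1D(pool_size=max_seq_len, data_format="channels_last")(output)
    pooled_output = Flatten()(pooled_output)   # pooled_output: [batch_size, embedding_dimension=768]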

    model = Model(inputs=[l_input_ids], outputs=[pooled_output])
    model.build(input_shape=(None, max_seq_len))

    bert.load_albert_weights(l_bert, albert_dir)

    return model

But when I try to convert the model to TFLite with the following code,

m = load_pretrained_albert()
converter = tf.lite.TFLiteConverter.from_keras_model(m)
tflite_model = converter.convert()

the following error occurs:

File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\lite\python\lite.py", line 405, in convert
    self._funcs[0], lower_control_flow=False)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\framework\convert_to_constants.py", line 575, in convert_variables_to_constants_v2
    converted_input_indices)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\framework\convert_to_constants.py", line 371, in _construct_concrete_function
    new_output_names)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\eager\wrap_function.py", line 620, in function_from_graph_def
    wrapped_import = wrap_function(_imports_graph_def, [])
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\eager\wrap_function.py", line 598, in wrap_function
    collections={}),
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\eager\wrap_function.py", line 83, in __call__
    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\eager\wrap_function.py", line 89, in wrapped
    return fn(*args, **kwargs)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\eager\wrap_function.py", line 618, in _imports_graph_def
    importer.import_graph_def(graph_def, name="")
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\framework\importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\framework\importer.py", line 505, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: Input 0 of node model/albert/embeddings/word_embeddings/embedding_lookup was passed float from model/albert/embeddings/word_embeddings/embedding_lookup/Read/ReadVariableOp/resource:0 incompatible with expected resource.

So I tried saving the model in the SavedModel format and converting it with the following code:

converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_path')
tflite_model = converter.convert()

However, the same error message appears again.

File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\lite\python\lite.py", line 405, in convert
    self._funcs[0], lower_control_flow=False)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\framework\convert_to_constants.py", line 575, in convert_variables_to_constants_v2
    converted_input_indices)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\framework\convert_to_constants.py", line 371, in _construct_concrete_function
    new_output_names)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\eager\wrap_function.py", line 620, in function_from_graph_def
    wrapped_import = wrap_function(_imports_graph_def, [])
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\eager\wrap_function.py", line 598, in wrap_function
    collections={}),
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\eager\wrap_function.py", line 83, in __call__
    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\eager\wrap_function.py", line 89, in wrapped
    return fn(*args, **kwargs)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\eager\wrap_function.py", line 618, in _imports_graph_def
    importer.import_graph_def(graph_def, name="")
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\framework\importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "C:\Users\hygki\Anaconda3\lib\site-packages\tensorflow_core\python\framework\importer.py", line 505, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: Input 0 of node StatefulPartitionedCall/model/albert/embeddings/word_embeddings/embedding_lookup was passed float from Func/StatefulPartitionedCall/input/_2:0 incompatible with expected resource.

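So my understanding is that embedding_lookup is being fed a float where it expects some other data type. But what is the expected data type? Is there any way to find out? Also, is there a workaround for this problem?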
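One way to see which input types an op such as embedding_lookup expects is to trace a lookup with `tf.function` and list the input dtypes of every gather op in the resulting graph. The sketch below is an assumption-laden stand-in, not the ALBERT model: `table` is a hypothetical 100-by-8 embedding variable and `gather_input_dtypes` is a hypothetical helper. In a TF 2.x graph, a lookup on a variable is typically traced as a `ResourceGather` op whose first input is a `resource` handle (the error text itself says "expected resource"), though the exact op type can differ between versions.

```python
import tensorflow as tf

# Hypothetical stand-in for ALBERT's word-embedding table (vocab 100, dim 8).
table = tf.Variable(tf.random.normal([100, 8]), name="word_embeddings")


@tf.function(input_signature=[tf.TensorSpec([None, 16], tf.int32)])
def lookup(ids):
    # Same kind of lookup the embedding layer performs on its variable.
    return tf.nn.embedding_lookup(table, ids)


def gather_input_dtypes(concrete_fn):
    """Map each gather-like op in the traced graph to (op type, input dtypes)."""
    result = {}
    for op in concrete_fn.graph.get_operations():
        if "Gather" in op.type:
            result[op.name] = (op.type, [t.dtype.name for t in op.inputs])
    return result


for name, (op_type, dtypes) in gather_input_dtypes(lookup.get_concrete_function()).items():
    print(name, op_type, dtypes)
```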

Any help with converting albert_base to the TFLite format would be greatly appreciated!

3 answers:

Answer 0 (score: 1)

Regarding the "IdentityN" error, have you tried converting with SELECT_TF_OPS? https://www.tensorflow.org/lite/guide/ops_select
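A minimal sketch of enabling Select TF ops on a converter. For the question's model, the same `target_spec.supported_ops` assignment would go on the converter returned by `tf.lite.TFLiteConverter.from_keras_model(m)`; here a hypothetical `TinyModel` stands in so the snippet runs on its own. It assumes a recent TF 2.x, where `from_concrete_functions` accepts the trackable object as a second argument. Whether this resolves the IdentityN issue for ALBERT is not verified here, and per the linked guide, a model that actually contains Select TF ops needs the Flex delegate at runtime.

```python
import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical placeholder model standing in for the ALBERT model `m`."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones([4, 2]), name="w")

    @tf.function(input_signature=[tf.TensorSpec([1, 4], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w)


model = TinyModel()
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [model.__call__.get_concrete_function()], model)
# Prefer native TFLite kernels, but keep ops that have no TFLite builtin
# as TensorFlow ops instead of failing the conversion.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()
print(len(tflite_model), "bytes")
```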

Answer 1 (score: 0)

Funnily enough, I had been struggling with this problem for hours, but I only solved it right after posting the question...

The solution is to use TensorFlow 1.15.0! Using TensorFlow 2 seems to cause the problem.

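However, I still can't convert the model to TFLite, because TFLite doesn't support the 'IdentityN' op yet. I don't think I can write a custom op myself, so I'll have to wait for a TFLite update...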

Answer 2 (score: 0)

Use the ALBERT 2.0 (TF 2.0) model from the official repository. Change https://github.com/google-research/ALBERT/blob/master/modeling.py#L516 to tf.gather(tf.identity(embedding_table), input_ids), then try converting with TFLite as before. If that doesn't work, leave a comment here.
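A minimal sketch of the change this answer describes, isolated from the ALBERT code base: a hypothetical `TinyEmbedding` module gathers from `tf.identity(embedding_table)` instead of from the variable directly, is converted with `tf.lite.TFLiteConverter.from_concrete_functions`, and the TFLite output is compared with the TensorFlow output. It assumes a recent TF 2.x (where `from_concrete_functions` accepts the trackable object as a second argument); recent versions may convert the plain lookup as well, so this illustrates the shape of the workaround rather than reproducing the original TF 2.0 failure.

```python
import numpy as np
import tensorflow as tf


class TinyEmbedding(tf.Module):
    """Hypothetical stand-in for ALBERT's word-embedding lookup (vocab 30, dim 8)."""

    def __init__(self):
        super().__init__()
        self.embedding_table = tf.Variable(
            tf.random.normal([30, 8], seed=0), name="word_embeddings")

    @tf.function(input_signature=[tf.TensorSpec([1, 4], tf.int32)])
    def __call__(self, input_ids):
        # The answer's workaround: read the table through tf.identity
        # and gather from that tensor rather than from the variable itself.
        return tf.gather(tf.identity(self.embedding_table), input_ids)


module = TinyEmbedding()
concrete_fn = module.__call__.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn], module)
tflite_model = converter.convert()

# Run the converted model and compare it with the TensorFlow result.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
ids = np.array([[1, 5, 7, 29]], dtype=np.int32)
interpreter.set_tensor(interpreter.get_input_details()[0]["index"], ids)
interpreter.invoke()
tflite_out = interpreter.get_tensor(interpreter.get_output_details()[0]["index"])
print(np.allclose(tflite_out, module(ids).numpy(), atol=1e-6))
```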