如何在颤振中使用经过训练的张量流模型?

时间:2021-03-04 11:59:43

标签: python flutter tensorflow tensorflow2.0 flutter-web

我已经训练了一个 tensorflow 模型来预测输入文本的下一个单词。我将其保存为 .h5 文件。

我可以在另一个 python 代码中使用该模型来预测单词,如下所示:

import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from keras.models import load_model

model = load_model('model.h5')
model.compile(
    loss = "categorical_crossentropy",
    optimizer = "adam",
    metrics = ["accuracy"]
)

data = open("dataset.txt").read()
corpus = data.lower().split("\n")
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)

seed_text = input()

sequence_text = tokenizer.texts_to_sequences([seed_text])[0]
padded_sequence = np.array(pad_sequences([sequence_text], maxlen = 11 -1))
predicted = np.argmax(model.predict(padded_sequence))
<块引用>

有没有办法直接在里面使用那个模型 颤振,我可以从 TextField() 获取输入,然后按 按钮,显示预测词??

2 个答案:

答案 0 :(得分:0)

您不能直接在 Flutter 中使用 .h5 文件。 您需要将其转换为 .tflite 文件并使用该文件,或者创建一个 REST API。

将其转换为 .tflite 文件是最简单的。 您可以查看以下文章了解更多详情: https://medium.com/analytics-vidhya/run-cnn-model-in-flutter-10c944cadcba

如果您想创建 REST API,请查看这篇文章: https://medium.com/analytics-vidhya/deploy-ml-models-using-flask-as-rest-api-and-access-via-flutter-app-7ce63d5c1f3b

答案 1 :(得分:0)

步骤

  1. 将模型转换为 .tflite 模型。
# https://www.tensorflow.org/lite/convert/#convert_a_savedmodel_recommended_

import tensorflow as tf

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)
  1. 将 tflite 模型添加到 App 目录。我通常将模型添加到 assets/ 目录中。
android/
assets/
    model.tflite
ios/
lib/
  1. 将 tflite 作为依赖项添加到 pubspec.yaml
dependencies:
  flutter:
    sdk: flutter
  tflite: ^1.0.5
  .
  .
  1. 在您的 dart 脚本中运行推理。例如,以下代码片段是关于如何对图像运行推理的示例脚本,其中 labels.txt 是包含类的文本文件:
import 'package:tflite/tflite.dart';
.
.
.

class _MyAppState extends State<MyApp> {
  . . .
  @override
  void initState() {
    super.initState();
    _loading = true;

    loadModel().then((value) {
      setState(() {
        _loading = false;
      });
    });
  }

  classifyImage(File image) async {
    var output = await Tflite.runModelOnImage(
      path: image.path,
      numResults: 2,
      threshold: 0.5,
      imageMean: 127.5,
      imageStd: 127.5,
    );
    setState(() {
      _loading = false;
      _outputs = output;
    });
  }

  loadModel() async {
    await Tflite.loadModel(
      model: "assets/model_unquant.tflite",
      labels: "assets/labels.txt",
    );
  }
  @override
  void dispose() {
    Tflite.close();
    super.dispose();
  }
 . . .
}


旁注

tflite 插件不支持文本分类 AFAIK,如果您想专门进行文本分类,我建议使用 tflite_flutter 插件。以下是使用文本分类插件的文章链接。

Text Classification using TensorFlow Lite Plugin for Flutter

相关问题