swift coreML:没有“ options”参数的预测函数

时间:2018-09-03 06:54:29

标签: swift lstm coreml

在迅速documentation中,MLModel具有两个预测函数

  1. 功能预测(来自:MLFeatureProvider)-> MLFeatureProvider。根据给定的输入要素值预测输出要素值。
  2. 功能预测(来自:MLFeatureProvider,选项:MLPredictionOptions)-> MLFeatureProvider。根据给定的输入要素值预测输出要素值。

但是,在我自动生成的MLModel类中,没有生成带有options参数的函数。以下代码是我自动生成的预测函数。

func prediction(input: coreML_1denses_80iters_k213_2Input) throws -> coreML_1denses_80iters_k213_2Output {
    let outFeatures = try model.prediction(from: input)
    let result = coreML_1denses_80iters_k213_2Output(output1: outFeatures.featureValue(for: "output1")!.multiArrayValue!, lstm_1_h_out: outFeatures.featureValue(for: "lstm_1_h_out")!.multiArrayValue!, lstm_1_c_out: outFeatures.featureValue(for: "lstm_1_c_out")!.multiArrayValue!)
    return result
}

func prediction(input1: MLMultiArray, input2: MLMultiArray, lstm_1_h_in: MLMultiArray?, lstm_1_c_in: MLMultiArray?) throws -> coreML_1denses_80iters_k213_2Output {
    let input_ = coreML_1denses_80iters_k213_2Input(input1: input1, input2: input2, lstm_1_h_in: lstm_1_h_in, lstm_1_c_in: lstm_1_c_in)
    return try self.prediction(input: input_)
}

注意: 顺便说一句,为什么我想使用“ options”参数找到预测函数的原因是此错误消息:

[coreml] Cannot evaluate a sequence of length 600, which is longer than maximum of 400.

我发现一个solution,它在预测函数中添加了forceCPU标志。可以在MLPredictionOptions中找到名为“ usesCPUOnly”的选项。但是,我找不到放置选项的地方。

2 个答案:

答案 0 :(得分:0)

一种方法是在自动生成的类的extension中添加您自己的预测方法(在另一个源文件中)。

答案 1 :(得分:0)

感谢@Matthijs Hollemans。我找到了解决方案。只需编写我自己的扩展名并覆盖这样的预测函数即可。

func prediction(input: model_1denses_50iters_k213Input) throws -> model_1denses_50iters_k213Output {
    let options = MLPredictionOptions()
    options.usesCPUOnly = true
    let outFeatures = try model.prediction(from: input, options:options)
    let result = model_1denses_50iters_k213Output(output1: outFeatures.featureValue(for: "output1")!.multiArrayValue!, lstm_85_h_out: outFeatures.featureValue(for: "lstm_85_h_out")!.multiArrayValue!, lstm_85_c_out: outFeatures.featureValue(for: "lstm_85_c_out")!.multiArrayValue!)
    return result
}