使用Android上

时间:2018-06-07 14:58:36

标签: android python tensorflow

TLDR:如何在Android上使用冻结张量流图中的变量?

1。我想做什么

我有一个Tensorflow模型,它将内部状态保存在多个变量中,使用:state_var = tf.Variable(tf.zeros(shape, dtype=tf.float32), name='state', trainable=False)创建。

在推理期间修改此状态:

tf.assign(state_var, new_value)

我现在想在Android上部署该模型。我能够运行Tensorflow示例App。在那里,加载了一个冻结模型,工作正常。

2。从冻结图中恢复变量不起作用

但是,使用freeze_graph script冻结图形时,所有变量都将转换为常量。这适用于网络的权重,但不适用于内部状态。推断失败,并显示以下消息。我将此解释为“分配对常数张量不起作用”

java.lang.RuntimeException: Failed to load model from 'file:///android_asset/model.pb'
at org.tensorflow.contrib.android.TensorFlowInferenceInterface.<init>(TensorFlowInferenceInterface.java:113)
...
Caused by: java.io.IOException: Not a valid TensorFlow Graph serialization: Input 0 of node layer_1/Assign was passed float from layer_1/state:0 incompatible with expected float_ref.

幸运的是,您可以将变量黑名单转换为常量。但是,这也不起作用,因为冻结图现在包含未初始化的变量。

java.lang.IllegalStateException: Attempting to use uninitialized value layer_7/state

3。恢复SavedModel不适用于Android

我尝试过的最后一个版本是使用SavedModel格式,该格式应包含冻结图和变量。不幸的是,调用恢复方法在Android上不起作用。

SavedModelBundle bundle = SavedModelBundle.load(modelFilename, modelTag);

// produces error:

E/AndroidRuntime: FATAL EXCEPTION: main
Process: org.tensorflow.demo, PID: 27451
     java.lang.UnsupportedOperationException: Loading a SavedModel is not supported in Android. File a bug at https://github.com/tensorflow/tensorflow/issues if this feature is important to you at org.tensorflow.SavedModelBundle.load(Native Method)

4。我怎样才能做到这一点?

我不知道还能尝试什么。这是我想象的,但我不知道如何使它工作:

  1. 找出一种在Android上初始化变量的方法
  2. 找出一种冻结模型的不同方法,这样初始化程序op也可能是冻结图的一部分,可以从Android运行
  3. 了解RNN / LSTM是否在内部实施,因为这些在推理过程中也应该具有相同的使用变量的要求(我假设LSTM可以在Android上部署)。
  4. ???

1 个答案:

答案 0 :(得分:1)

我已经通过不同的路线解决了这个问题。据我所知,“变量”概念不能像我在Python中习惯的那样在Android上使用(例如,你不能初始化变量,然后在推理期间更新网络的内部状态)。 / p>

相反,您可以使用placehlder和output节点来保留Java代码中的状态,并在每次推理调用时将其提供给网络。

  • tf.Variable替换所有tf.placeholder次出现。形状保持不变。
  • 我还定义了一个用于读取输出的附加节点。 (也许你可以简单地阅读占位符本身,我没有尝试过。)tf.identity(inputs, name='state_output')
  • 在Android上进行推理时,您可以将初始状态提供给网络。

    float[] values = {0, 0, 0, ...}; // zeros of the correct shape inferenceInterface.feed('state', values, ...);

  • 推断后,您将读取生成的网络内部状态

    float[] values = new float[output_shape]; inferenceInterface.fetch('state_output', values);

    然后记住Java中的这个输出,将其传递到'state'占位符以进行下一个推理调用。