TLDR:如何在Android上使用冻结张量流图中的变量?
1。我想做什么
我有一个Tensorflow模型,它将内部状态保存在多个变量中,使用:state_var = tf.Variable(tf.zeros(shape, dtype=tf.float32), name='state', trainable=False)
创建。
在推理期间修改此状态:
tf.assign(state_var, new_value)
我现在想在Android上部署该模型。我能够运行Tensorflow示例App。在那里,加载了一个冻结模型,工作正常。
2。从冻结图中恢复变量不起作用
但是,使用freeze_graph script冻结图形时,所有变量都将转换为常量。这适用于网络的权重,但不适用于内部状态。推断失败,并显示以下消息。我将此解释为“分配对常数张量不起作用”
java.lang.RuntimeException: Failed to load model from 'file:///android_asset/model.pb'
at org.tensorflow.contrib.android.TensorFlowInferenceInterface.<init>(TensorFlowInferenceInterface.java:113)
...
Caused by: java.io.IOException: Not a valid TensorFlow Graph serialization: Input 0 of node layer_1/Assign was passed float from layer_1/state:0 incompatible with expected float_ref.
幸运的是,您可以将变量黑名单转换为常量。但是,这也不起作用,因为冻结图现在包含未初始化的变量。
java.lang.IllegalStateException: Attempting to use uninitialized value layer_7/state
3。恢复SavedModel不适用于Android
我尝试过的最后一个版本是使用SavedModel
格式,该格式应包含冻结图和变量。不幸的是,调用恢复方法在Android上不起作用。
SavedModelBundle bundle = SavedModelBundle.load(modelFilename, modelTag);
// produces error:
E/AndroidRuntime: FATAL EXCEPTION: main
Process: org.tensorflow.demo, PID: 27451
java.lang.UnsupportedOperationException: Loading a SavedModel is not supported in Android. File a bug at https://github.com/tensorflow/tensorflow/issues if this feature is important to you at org.tensorflow.SavedModelBundle.load(Native Method)
4。我怎样才能做到这一点?
我不知道还能尝试什么。这是我想象的,但我不知道如何使它工作:
答案 0 :(得分:1)
我已经通过不同的路线解决了这个问题。据我所知,“变量”概念不能像我在Python中习惯的那样在Android上使用(例如,你不能初始化变量,然后在推理期间更新网络的内部状态)。 / p>
相反,您可以使用placehlder和output节点来保留Java代码中的状态,并在每次推理调用时将其提供给网络。
tf.Variable
替换所有tf.placeholder
次出现。形状保持不变。tf.identity(inputs, name='state_output')
在Android上进行推理时,您可以将初始状态提供给网络。
float[] values = {0, 0, 0, ...}; // zeros of the correct shape
inferenceInterface.feed('state', values, ...);
推断后,您将读取生成的网络内部状态
float[] values = new float[output_shape];
inferenceInterface.fetch('state_output', values);
然后记住Java中的这个输出,将其传递到'state'
占位符以进行下一个推理调用。