tf.train.write_graph does not save the model when called from Java

Time: 2018-03-22 08:29:46

Tags: android tensorflow graph

To save a trained model on a mobile device, I wrote the following model as an experiment.

import tensorflow as tf

with tf.Session() as sess:
    def save_model():
        path = tf.train.write_graph(sess.graph_def, '/data/data/com.chelexa.tfandroid', 'mnist_50_mlp.pb',as_text=False)
        #path_folder=path[33:36]
        return tf.to_float(tf.rank(path))
    x = tf.placeholder(tf.float32, shape=[None, 784], name="x")
    y = tf.placeholder(tf.float32, [None, 10], name="y")

    W = tf.Variable(tf.zeros([784, 10]), name="weights")
    b = tf.Variable(tf.zeros([10]))

    y_out = tf.matmul(x, W) + b

    #cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_out), reduction_indices=[1]))
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_out))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy, name="train")

    correct_prediction = tf.equal(tf.argmax(y_out,1), tf.argmax(y,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="test")

    init = tf.variables_initializer(tf.global_variables(), name="init")
    save = tf.add(1.0,save_model(),name="save_model")
    save_save = tf.add(1.0,save_model(),name="savesave")

Then I try to save the trained model from Java:

TensorFlowInferenceInterface.run(new String[]{"save_model"}, new String[]{}, logStats);

The call runs and returns the expected value (save == 1), but no mnist_50_mlp.pb file is saved in the app's folder /data/data/com.chelexa.tfandroid/.
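
One possible explanation, offered as an assumption and not a confirmed diagnosis: tf.train.write_graph is an ordinary Python function, so it writes the file at the moment the Python script builds the graph, on whatever machine runs that script. It returns a path string, and the rank of a single string is 0, so the "save_model" node reduces to 1.0 + 0.0. Running that node from Java would then return 1 without touching the file system. The plain-Python sketch below imitates this behavior without TensorFlow; write_graph_stub, build_save_node, and the two temporary directories are hypothetical stand-ins, not part of any library.

```python
import os
import tempfile


def write_graph_stub(logdir, name):
    """Hypothetical stand-in for tf.train.write_graph: writes now, returns the path."""
    os.makedirs(logdir, exist_ok=True)
    path = os.path.join(logdir, name)
    with open(path, "wb") as f:
        f.write(b"serialized graph placeholder")
    return path


def build_save_node(logdir, name):
    """Imitates tf.add(1.0, tf.to_float(tf.rank(path))) from the question."""
    path = write_graph_stub(logdir, name)  # side effect happens here, at build time
    rank_of_path = 0  # a single string is a scalar, so its rank is 0
    # The returned "node" holds only constants; it cannot write a file when run.
    return lambda: 1.0 + float(rank_of_path)


build_dir = tempfile.mkdtemp()   # plays the role of the machine that builds the graph
device_dir = tempfile.mkdtemp()  # plays the role of the app folder on the device

save_node = build_save_node(build_dir, "mnist_50_mlp.pb")
written_at_build = os.path.exists(os.path.join(build_dir, "mnist_50_mlp.pb"))

result = save_node()  # analogous to running "save_model" from Java
written_on_device = os.path.exists(os.path.join(device_dir, "mnist_50_mlp.pb"))

print(result, written_at_build, written_on_device)
```

If this is what is happening, the file would only ever appear where the Python script ran, and the value returned on the device says nothing about whether a file was written there.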

Thanks!

0 answers:

No answers yet