为了在移动设备中保存训练有素的模型,我写了一个如下的模型进行研究。
import tensorflow as tf
with tf.Session() as sess:
def save_model():
path = tf.train.write_graph(sess.graph_def, '/data/data/com.chelexa.tfandroid', 'mnist_50_mlp.pb',as_text=False)
#path_folder=path[33:36]
return tf.to_float(tf.rank(path))
x = tf.placeholder(tf.float32, shape=[None, 784], name="x")
y = tf.placeholder(tf.float32, [None, 10], name="y")
W = tf.Variable(tf.zeros([784, 10]), name="weights")
b = tf.Variable(tf.zeros([10]))
y_out = tf.matmul(x, W) + b
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_out), reduction_indices=[1]))
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_out))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy, name="train")
correct_prediction = tf.equal(tf.argmax(y_out,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="test")
init = tf.variables_initializer(tf.global_variables(), name="init")
save = tf.add(1.0,save_model(),name="save_model")
save_save = tf.add(1.0,save_model(),name="savesave")
然后我用java保存训练有素的模型
TensorFlowInferenceInterface.run(new String[]{"save_model"}, new String[]{}, logStats);
运行调用后可以返回正确的值save == 1,但没有在/data/data/com.chelexa.tfandroid /
中的apk文件夹中保存的mnist_50_mlp.pb文件谢谢!