我已经在 TensorFlow 上用 AlexNet 训练了一个模型。模型已经训练成功，但我不知道如何保存它，并把它部署到我的 Android 应用中。有人能帮帮我吗？网上有很多相关资料，但我不确定该从哪里入手，也不清楚具体该怎么做。
以下是训练模型的代码：
%%time
#Changing settings for GPU running.
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
config.gpu_options.allocator_type = 'BFC'
# Train the model, save a checkpoint, then report test accuracy and loss.
# Relies on names defined elsewhere in the notebook: tf, np, config, init,
# saver, train, cross_entropy, acc, merged_summary_op, the placeholders x,
# y_true and hold_prob1..hold_prob4, the data X, Y, cv_x, cv_y, test_x, test_y,
# the loop settings epochs, steps, step_size, total_batch, TRAIN_DIR, and the
# history lists acc_train, cross_entropy_list, acc_list (appended to here).

# Dropout keep-probabilities: 0.5 while training, 1.0 (dropout disabled) for
# validation and test evaluation.
train_keep = {hold_prob1: 0.5, hold_prob2: 0.5, hold_prob3: 0.5, hold_prob4: 0.5}
eval_keep = {hold_prob1: 1.0, hold_prob2: 1.0, hold_prob3: 1.0, hold_prob4: 1.0}

# The test set is evaluated in these three fixed slices (presumably to limit
# memory per sess.run call -- TODO confirm).
test_slices = [slice(0, 230), slice(230, 460), slice(460, None)]


def _feed(images, labels, keep):
    """Build a feed_dict for one batch of images/labels plus dropout settings."""
    feed = {x: images, y_true: labels}
    feed.update(keep)
    return feed


with tf.Session(config=config) as sess:
    sess.run(init)
    summary_writer = tf.summary.FileWriter(TRAIN_DIR, graph=tf.get_default_graph())
    for i in range(epochs):
        for j in range(0, steps, step_size):
            _, c, summary, d = sess.run(
                [train, cross_entropy, merged_summary_op, acc],
                feed_dict=_feed(X[j:j + step_size], Y[j:j + step_size], train_keep))
            # NOTE(review): j advances by step_size, so this global step only
            # increases monotonically across epochs if total_batch >= steps --
            # confirm the intended value of total_batch.
            summary_writer.add_summary(summary, i * total_batch + j)
            acc_train.append(d)
        # Validation loss and accuracy, fetched together in one forward pass.
        mean_of_cross_entropy, mean_of_acc = sess.run(
            [cross_entropy, acc], feed_dict=_feed(cv_x, cv_y, eval_keep))
        cross_entropy_list.append(mean_of_cross_entropy)
        acc_list.append(mean_of_acc)
        print(i, mean_of_cross_entropy, mean_of_acc)
    # Flush any buffered TensorBoard events; the writer was never closed before.
    summary_writer.close()

    # NOTE(review): a .ckpt checkpoint is not directly usable on Android; it
    # still has to be frozen/converted (e.g. to TensorFlow Lite) using the
    # graph's input/output tensor names, which are not visible here.
    saver.save(sess, "C:\\Users\\blessie\\Desktop\\LEAF RECOGNITION v5\\Models\\CNN_MC.ckpt")

    test_acc, test_loss = [], []
    for sl in test_slices:
        chunk_acc, chunk_loss = sess.run(
            [acc, cross_entropy], feed_dict=_feed(test_x[sl], test_y[sl], eval_keep))
        test_acc.append(chunk_acc)
        test_loss.append(chunk_loss)
    # NOTE(review): this is an unweighted mean of per-slice values; it equals
    # the whole-set value only if the slices are the same size.
    print("test accuracy = ", np.mean(test_acc))
    print(" cross_entropy loss =", np.mean(test_loss))