如何在Google Colab上下载TensorFlow训练模型?

时间:2018-05-04 04:02:58

标签: python tensorflow machine-learning model

我是新手,正在研究机器学习,我在Google Colab创建了一个模型。

我的目标是在Android应用中使用该模型进行离线预测。所以我需要下载经过培训的模型。

我唯一知道的是我需要将我的模型保存为.pb文件才能制作我的Android应用。我一直在寻找,也许有答案,但它们太短暂,我无法理解 所以需要非常详细的答案。

This是我的Test.ipynb文件,有人可以花一点时间训练该模型,看看我们是否可以将其下载到本地驱动器。

2 个答案:

答案 0 :(得分:0)

这是我从Collab保存和下载模型文件的方法。

  1. 只保存可训练变量(这些是我的存储/恢复fns):
  2. 代码如下:

    def store(sess_var, model_path):        
        if model_path is not None:
            saver = tf.train.Saver(var_list=tf.trainable_variables())
            save_path = saver.save(sess_var, model_path)
            print("Model saved in path: %s" % save_path)
        else:
            print("Model path is None - Nothing to store")
    
    
    
    def restore(sess_var, model_path):
        if model_path is not None:
            if os.path.exists("{}.index".format(model_path)):
                saver = tf.train.Saver(var_list=tf.trainable_variables())
                saver.restore(sess_var, model_path)
                print("Model at %s restored" % model_path)
            else:
                print("Model path does not exist, skipping...")
        else:
            print("Model path is None - Nothing to restore")
    
    1. 压缩存储模型的目录 - 确保其中不包含其他内容:!tar -czvf model.tar.gz models/

    2. 下载型号:

      from google.colab import files files.download('model.tar.gz')

    3. 由于您只存储可训练变量而不是整个会话,因此模型的大小很小,因此可以下载。请务必使用Chrome - 我无法在Firefox上使用最后一个代码段。

答案 1 :(得分:0)

应该这样做:

import tensorflow as tf
from google.colab import files


# Specify export directory and use tensorflow to save your_model
export_dir = './saved_model'
tf.saved_model.save(your_model, export_dir=export_dir)

请注意,导出目录包含多个文件,但如果您只想下载 .pb 文件,请执行以下操作。

# Download the model
files.download(export_dir + '/saved_model.pb')