如何在不写入文件系统的情况下从Google存储桶还原Tensorflow模型?

时间:2018-11-02 20:37:25

标签: google-app-engine tensorflow

我有一个2gb的Tensorflow模型,我想将其添加到App Engine上的Flask项目中,但是我似乎找不到任何文档说明我正在尝试做的事情。

由于App Engine不允许写入文件系统,因此我将模型的文件存储在Google Bucket中,并尝试从那里恢复模型。这些是那里的文件:

  • model.ckpt.data-00000-of-00001
  • model.ckpt.index
  • model.ckpt.meta
  • 检查点

在本地工作,我可以使用

with tf.Session() as sess:
    logger.info("Importing model into TF")
    saver = tf.train.import_meta_graph('model.ckpt.meta')
    saver.restore(sess, model.ckpt)

使用Flask的@before_first_request将模型加载到内存中。

一旦在App Engine上,我就可以做到这一点:

blob = bucket.get_blob('blob_name')
filename = os.path.join(model_dir, blob.name)
blob.download_to_filename(filename)

然后执行相同的还原。但是App Engine不允许。

是否可以将这些文件流式传输到Tensorflow的还原功能中,而不必将文件写入文件系统?

2 个答案:

答案 0 :(得分:1)

我实际上并没有使用Tensorflow,答案基于文档和GAE相关知识。

通常,在GAE中将GCS对象用作文件来避免缺少可写文件系统访问,它依赖于2种替代方法之一,而不仅仅是传递文件名以直接读取/写入(可以您的应用代码(和/或它可能正在使用的任何第三者实用程序/库)对GCS对象进行的处理:

在您的特定情况下,似乎tf.train.import_meta_graph()调用支持传递MetaGraphDef协议缓冲区(即原始数据)而不是应该从中加载文件名的文件名:

  

Args:

     
      
  • meta_graph_or_file:包含MetaGraphDef的{​​{1}}协议缓冲区或文件名(包括路径)。
  •   

因此应该可以从GCS 还原模型,大致如下:

MetaGraphDef

但是,从快速的文档扫描保存/检查点模式回到GCS可能很棘手,save()似乎想将数据写入磁盘本身。但是我并没有挖得太深。

答案 1 :(得分:1)

Dan Cornilescu提出了一些技巧并进行了深入研究后,我发现Tensorflow使用名为MetaGraphDef的函数构建了ParseFromString,所以我最终要做的是:

from google.cloud import storage
from tensorflow import MetaGraphDef

client = storage.Client()
bucket = client.get_bucket(Config.MODEL_BUCKET)
blob = bucket.get_blob('model.ckpt.meta')
model_graph = blob.download_as_string()

mgd = MetaGraphDef()
mgd.ParseFromString(model_graph)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph(mgd)