将张量流模型保存到文件

时间:2016-06-23 19:26:13

标签: python-2.7 tensorflow pickle

我创建了一个张量流模型,我希望将其保存到文件中,以便以后可以预测它。特别是,我需要保存:

  • input_placeholder
    = tf.placeholder(tf.float32, [None, iVariableLen])
  • solution_space
    = tf.nn.sigmoid(tf.matmul(input_placeholder, weight_variable) + bias_variable)
  • 会话
    = tf.Session()

我已经尝试过使用pickle,它可以在sklearn二值化器等其他对象上运行,但不是上面的,我在底部得到了错误。

我如何挑剔:

import pickle
with open(sModelSavePath, 'w') as fiModel:
    pickle.dump(dModel, fiModel)

其中dModel是一个包含我想要保留的所有对象的字典,我用它来进行拟合。

关于如何挑选张量流对象的任何建议?

错误消息

pickle.dump(dModel, fiModel)
...
    raise TypeError, "can't pickle %s objects" % base.__name__
TypeError: can't pickle module objects

2 个答案:

答案 0 :(得分:7)

我解决这个问题的方法是pickleing Sklearn对象,比如二进制化器,并使用tensorflow's inbuilt save functions作为实际模型:

保存张量流模型
1)按照惯例建立模型 2)使用tf.train.Saver()保存会话。例如:

oSaver = tf.train.Saver()

oSess = oSession
oSaver.save(oSess, sModelPath)  #filename ends with .ckpt

3)这会将该会话中的所有可用变量等保存到其变量名称中。

加载张量流模型
1)需要重新初始化整个流程。换句话说,需要声明变量,权重,偏差,损失函数等,然后将tf.initialize_all_variables()传递给oSession.run()进行初始化。 2)该会话现在需要传递给加载器。我抽象了流程,所以我的加载器看起来像这样:

dAlg = tf_training_algorithm()  #defines variables etc and initializes session

oSaver = tf.train.Saver()
oSaver.restore(dAlg['oSess'], sModelPath)

return {
    'oSess': dAlg['oSess'],
    #the other stuff I need from my algorithm, like my solution space etc
}

3)您需要预测的所有对象都需要从初始化中获取,在我的情况下,这些对象位于dAlg

PS:像这样的泡菜:

with open(sSavePathFilename, 'w') as fiModel:
    pickle.dump(dModel, fiModel)

with open(sFilename, 'r') as fiModel:
    dModel = pickle.load(fiModel)

答案 1 :(得分:0)

您应该将项目保存为两个独立的部分,一个用于tensorflow个对象,另一个用于其他对象。我建议您使用以下工具:

  1. tf.saved_model:您要保存的程序并在其中加载tensorflow
  2. dill:基于pickle的更强大的pickle工具,它可以帮助您绕过pickle
  3. 遇到的大多数错误
相关问题