保存一个非常简单的tensorFlow模型

时间:2018-09-24 16:50:50

标签: python tensorflow

作为Django / React本机Web应用程序的一部分,我正在创建一个非常简单的张量流网络,我设法创建了一个训练集,该训练集在使用plt运行时可以正常工作并产生以下输出:

Training output

但是,就神经网络和Tensor Flow而言,我发现保存这种训练好的集非常困难,我尝试过使用saver()进行保存,并查看了会话,但到目前为止都没有真正起作用。我想要保存此经过训练的模型,以便我可以在我的应用程序中使用它来匹配用户输入的问题(由整数表示)和预设答案。另外,我知道数据是非常基础的,可以通过对响应进行硬编码轻松地完成,该项目更多地是关于学习,而不是以最有效的方式获得结果。非常感谢任何反馈,下面列出了源代码!

machine_learn.py

def loss(self, predicted_y, desired_y):
    return tf.reduce_mean(tf.square(predicted_y - desired_y))    

def train(self, model, inputs, outputs, learning_rate):
    with tf.GradientTape() as t:
        current_loss = self.loss(model(inputs), outputs)

    dQ = t.gradient(current_loss, model.Q)
    model.Q.assign_sub(learning_rate * dQ)

def train_network(self, value_set):
    model = TrainingModel(value_set)
    desired_list = [4.00, 3.00, 2.00, 1.00]
    num_examples = 10000
    desired_ans = desired_list[0]
    inputs = tf.random_normal(shape=[num_examples])
    Qs = []
    epochs = range(150)
    for _ in epochs:
        Qs.append(model.Q.numpy())
        current_loss = self.loss(model.Q, desired_list)
        self.train(model, inputs, desired_list, 0.1)
        print(current_loss)
    plt.plot(epochs, Qs, 'r')
    plt.plot([desired_ans] * len(epochs), 'r--')
    plt.legend(['Q', 'true q'])

    plt.show()

training_model.py

class TrainingModel(object):

def __init__(self, questions):
    self.questions = questions
    self.Q = tfe.Variable(questions)

def __call__(self, inputs):
    return self.Q

1 个答案:

答案 0 :(得分:0)

我同意Prune的评论,但是作为另一个noobie,他在为saver()找到有用的教程时遇到很多麻烦,请在此处查看示例程序: Useful example code

保存和恢复代码应该是您想要的。我也鼓励您查看tensorboard示例,它将向您展示如何在不使用plt库的情况下绘制变量图形。希望这会有所帮助!