如何构建Tensorflow模型代码?

时间:2017-06-30 07:29:57

标签: machine-learning tensorflow

我很难找到如何构建我的Tensorflow模型代码。我想以类的形式构建它以便于将来重用。此外,我目前的结构很乱,张量板图输出里面有多个“模型”。

以下是我目前的情况:

import tensorflow as tf
import os

from utils import Utils as utils

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

class Neural_Network:
    # Neural Network Setup
    num_of_epoch = 50

    n_nodes_hl1 = 500
    n_nodes_hl2 = 500
    n_nodes_hl3 = 500

    def __init__(self):
        self.num_of_classes = utils.get_num_of_classes()
        self.num_of_words = utils.get_num_of_words()

        # placeholders
        self.x = tf.placeholder(tf.float32, [None, self.num_of_words])
        self.y = tf.placeholder(tf.int32, [None, self.num_of_classes])

        with tf.name_scope("model"):
            self.h1_layer = tf.layers.dense(self.x, self.n_nodes_hl1, activation=tf.nn.relu, name="h1")
            self.h2_layer = tf.layers.dense(self.h1_layer, self.n_nodes_hl2, activation=tf.nn.relu, name="h2")
            self.h3_layer = tf.layers.dense(self.h2_layer, self.n_nodes_hl3, activation=tf.nn.relu, name="h3")

            self.logits = tf.layers.dense(self.h3_layer, self.num_of_classes, name="output")

    def predict(self):
        return self.logits

    def make_prediction(self, query):
        result = None

        with tf.Session() as sess:
            saver = tf.train.import_meta_graph('saved_models/testing.meta')
            saver.restore(sess, 'saved_models/testing')

            sess.run(tf.global_variables_initializer())

            prediction = self.predict()
            prediction = sess.run(prediction, feed_dict={self.x : query})
            prediction = prediction.tolist()
            prediction = tf.nn.softmax(prediction)
            prediction = sess.run(prediction)
            print prediction

            return utils.get_label_from_encoding(prediction[0])

    def train(self, data):

        print len(data['values'])
        print len(data['labels'])

        prediction = self.predict()

        cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=self.y))
        optimizer = tf.train.AdamOptimizer().minimize(cost)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            writer = tf.summary.FileWriter("mygraph/logs", tf.get_default_graph())

            for epoch in range(self.num_of_epoch):
                optimised, loss = sess.run([optimizer, cost],
                                           feed_dict={self.x: data['values'], self.y: data['labels']})

                if epoch % 1 == 0:
                    print("Completed Training Cycle: " + str(epoch) + " out of " + str(self.num_of_epoch))
                    print("Current Loss: " + str(loss))

                    saver = tf.train.Saver()
                    saver.save(sess, 'saved_models/testing')
                    print("Model saved")

我在网上找到的是许多人使用更低级别的代码,例如 tf.Variables tf.Constant 因此,他们更能分开他们的码。但是,由于我对Tensorflow相对较新,我想首先使用更高级别的代码。

有人可以告诉我如何构建我的代码吗?

1 个答案:

答案 0 :(得分:6)

如评论所述,对您的初步问题的简短回答是阅读this,但是当您在评论中提出后续问题时,我认为需要更完整的答案。

  

有人可以告诉我如何构建我的代码吗?

显然,构建代码是一种品味问题。但是,为了帮助您制作自己的品味,您需要记住以下主要内容:TensorFlow中有两个不同的层,不要混淆它们。

  • 第一个是Graph层,它包含所有TensorFlow节点,例如
    • tensors(例如tf.placeholdertf.constanttf.Variables等等。)或
    • operationstf.addtf.matmul等。)。 Graph包含模型 本身,可能包含更多内容,例如:损失函数,训练模型的优化器,输入数据管道等。

每个节点都有一个名称,您可以使用该名称直接从图表中检索它(例如,使用tf.get_variable方法或tf.Graph.get_tensor_by_name)。

  • 第二层是使用Python(或C ++或Java,...)API构建TensorFlow Graph的方式。这可能是您在提问时想到的这一层。但是,在某种程度上,这一层实际上更像是模型工厂,而不是模型。
  

格式是否支持保存和恢复模型?

这取决于 model 的含义,即使在这两种情况下答案都是肯定的。

  • 如果您考虑过TensorFlow Graph,答案是,您可以保存并恢复Graph,因为它并不取决于您的身份构建它。只需查看此document保存和恢复部分,即可获得有关如何执行此操作的一些见解,或查看只有Graph还原的answer
  • 如果考虑到Python类,简短的回答是但是你可以弥补一些事情,使其成为 如前一项所述,TensorFlow检查点不保存Python(也不是C ++或Java)对象,而只保存图形。但是作为Python类的模型的结构位于其他地方:它存在于您的代码中。

    因此,如果您重新创建Python类的实例,并确保在Graph中重新创建所有TensorFlow节点(因此获得等效的Graph),那么,当您和#39; ll从检查点恢复TensorFlow Graph,您的模型将作为Python-instance-linked-to-a-TensorFlow - Graph恢复。

    请参阅document恢复变量部分,了解一个简单的示例,其中Python-instances-linked-to-a-TensorFlow - Graph实际上是Python变量(即v1v2)生活在模块范围内的某个地方。

    # Create some variables.
    v1 = tf.Variable(..., name="v1")
    v2 = tf.Variable(..., name="v2")
    ...
    
    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    
    # Later, launch the model, use the saver to restore variables from disk, and
    # do some work with the model.
    with tf.Session() as sess:
      # Restore variables from disk.
      saver.restore(sess, "/tmp/model.ckpt")
      print("Model restored.")
      # Do some work with the model
      ...
    

我只能建议阅读(和upvote :))这个question及其答案,因为您将在TensorFlow中如何保存/恢复工作方面学到很多。

希望现在有点清楚了。