类内的Tensorflow图-如何管理会话和范围

时间:2018-09-07 14:17:53

标签: python class tensorflow

我正在尝试构建包装在简单的一层NN类中的通用tensorflow基础结构(请参见下面的代码)。

我将创建许多NNet,所以我想知道什么是管理会话和变量的最佳方法。

通常,我只想为一个网络而不是所有网络(在“显示”功能中)获取tf.trainable_variables(),以便我可以打印所需的网络。

我还必须将会话变量“ sess”传递给每个函数,以便不重新初始化变量。 我想我做的一切都不正确...有人可以帮忙吗?

class oneLayerNN: 

"""
Implements a 1 hidden-layer neural network: y = W2 * ([W1 * x + b1]+) + b1
"""

def __init__(self, ...):
    ...
    self.initOp = tf.global_variables_initializer()

def show(self, sess):
    tvars = tf.trainable_variables()
    tvals = sess.run(tvars)
    for var, val in zip(tvars,tvals):
        print(var.name, val)
    print()

def initializeVariables(self, sess):
    sess.run(self.initOp)

def forwardPropagation(self, sess, x):
    labels = sess.run(self.yHat, feed_dict={self.x: x})
    return labels

def train(self, sess, dataset, epochs, batchSize, debug=False, verbose=False):
    dataset = dataset.batch(batchSize)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    for epoch in range(epochs):
        sess.run(iterator.initializer)
        while True:
            try:
                batch_x, batch_y = sess.run(next_element)
                _, c = sess.run([self.optimizer, self.loss], feed_dict={self.x: batch_x, self.y: batch_y})
            except tf.errors.OutOfRangeError:
                break

with tf.Session() as sess:
    network.initializeVariables(sess)
    network.show(sess)  

2 个答案:

答案 0 :(得分:0)

这可能与口味以及您打算如何使用物品有关。

如果您可以将对象限制为处理单个tf.Session(例如Keras,应该满足基本需求,并且可能会超出范围),那么您只需实例化单个{{1} },例如您喜欢的类似Singleton的模式(也许只是Keras中的普通旧函数)。

答案 1 :(得分:0)

感谢您的回答。

但是,变量范围仍然存在问题。如何定义变量作为对象的一部分?我希望能够执行以下操作:

vars = network.getTrainableVariables()

那应该只返回该对象中定义的变量(不像tf.trainable_variables()一样)

在同时使用多个网络(例如,范围是网络的名称)的情况下,找不到在范围内干净声明变量的示例。

当我多次运行代码时,它会创建变量W,b,然后是W_1,b_1,然后是W_2,b_2等...

此外,我希望network.initialize()仅初始化此图中定义的变量,而不初始化每个网络中的所有变量...

一种解决方案是在“名称”范围内声明网络变量,然后能够在“名称”范围内重置reset_default_graph,但是我无法做到这一点。