我正在尝试构建包装在简单的一层NN类中的通用tensorflow基础结构(请参见下面的代码)。
我将创建许多NNet,所以我想知道什么是管理会话和变量的最佳方法。
通常,我只想为一个网络而不是所有网络(在“显示”功能中)获取tf.trainable_variables(),以便我可以打印所需的网络。
我还必须将会话变量“ sess”传递给每个函数,以便不重新初始化变量。 我想我做的一切都不正确...有人可以帮忙吗?
class oneLayerNN:
"""
Implements a 1 hidden-layer neural network: y = W2 * ([W1 * x + b1]+) + b1
"""
def __init__(self, ...):
...
self.initOp = tf.global_variables_initializer()
def show(self, sess):
tvars = tf.trainable_variables()
tvals = sess.run(tvars)
for var, val in zip(tvars,tvals):
print(var.name, val)
print()
def initializeVariables(self, sess):
sess.run(self.initOp)
def forwardPropagation(self, sess, x):
labels = sess.run(self.yHat, feed_dict={self.x: x})
return labels
def train(self, sess, dataset, epochs, batchSize, debug=False, verbose=False):
dataset = dataset.batch(batchSize)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
for epoch in range(epochs):
sess.run(iterator.initializer)
while True:
try:
batch_x, batch_y = sess.run(next_element)
_, c = sess.run([self.optimizer, self.loss], feed_dict={self.x: batch_x, self.y: batch_y})
except tf.errors.OutOfRangeError:
break
with tf.Session() as sess:
network.initializeVariables(sess)
network.show(sess)
答案 0 :(得分:0)
这可能与口味以及您打算如何使用物品有关。
如果您可以将对象限制为处理单个tf.Session
(例如Keras,应该满足基本需求,并且可能会超出范围),那么您只需实例化单个{{1} },例如您喜欢的类似Singleton的模式(也许只是Keras中的普通旧函数)。
答案 1 :(得分:0)
感谢您的回答。
但是,变量范围仍然存在问题。如何定义变量作为对象的一部分?我希望能够执行以下操作:
vars = network.getTrainableVariables()
那应该只返回该对象中定义的变量(不像tf.trainable_variables()一样)
在同时使用多个网络(例如,范围是网络的名称)的情况下,找不到在范围内干净声明变量的示例。
当我多次运行代码时,它会创建变量W,b,然后是W_1,b_1,然后是W_2,b_2等...
此外,我希望network.initialize()仅初始化此图中定义的变量,而不初始化每个网络中的所有变量...
一种解决方案是在“名称”范围内声明网络变量,然后能够在“名称”范围内重置reset_default_graph,但是我无法做到这一点。