我喜欢围绕我的Tensorflow模型构建python类,以使它们更易于使用(至少在我看来)。我采取的方法是编写像
这样的东西class MyAwesomeModel(object):
def __init__(self, some_graph_params):
# bunch of code that defines the tensors, optimizer, etc...
# e.g. self.mytensor = tf.placeholder(tf.float32, [1])
def Train(self, tfsession, input_val):
# some code that calls the run() method on tfsession, etc.
def other_methods(self):
# other things like testing, plotting, etc. all managed nicely
# by the state that MyAwesomeModel instances maintain
我有两个非常相似的模型。唯一的区别在于计算图结构中的一些地方 - 我想创建一个具有所有常用功能的基类,并且只有子类可以覆盖基类中的一些东西。以下是它在我脑海中的运作方式
说我的基类看起来像这样
import tensorflow as tf
class BaseClass(object):
def __init__(self, multiplier):
self.multiplier = multiplier
# This is where I construct the graph
self.inputnode = tf.placeholder(tf.float32, [1])
self.tensor1 = tf.constant(self.multiplier, dtype=tf.float32) * self.inputnode
self.tensor2 = self.tensor1 # this is where the two
# child classes will differ
self.tensoroutput = 10*self.tensor2
def forward_pass(self, tfsession, input_val):
return tfsession.run(self.tensoroutput,
feed_dict={self.inputnode: [input_val]})
def other_methods(self):
print("doing something here...")
print(self.multiplier)
然后我弹出两个子类,它们重新定义了self.tensor2
和self.tensor1
之间的关系:
class ChildClass1(BaseClass):
def __init__(self, multiplier):
BaseClass.__init__(self, multiplier)
self.tensor2 = self.tensor1 + tf.constant(5.0, dtype=tf.float32)
class ChildClass2(BaseClass):
def __init__(self, multiplier):
BaseClass.__init__(self, multiplier)
self.tensor2 = self.tensor1 + tf.constant(4.0, dtype=tf.float32)
我的目标是运行以下内容:
cc1 = ChildClass1(2) # multiplier is 2
mysession = tf.Session()
mysession.run(tf.global_variables_initializer())
print(cc1.forward_pass(mysession, 5))
如果我想这样做,那么结果将是((5 * 2)+5)* 10 = 150.如果对象cc1是ChildClass2(2)类型,那么我想要结果是((5 * 2)+4)* 10 = 140.
但是,当我运行上面的代码时,结果是100,这与子类永远不会覆盖在基类中首次遇到的self.tensor2 的定义一致。我认为我需要有那条不稳定的行self.tensor2 = self.tensor1
,否则以下行会抱怨self.tensor2
不存在。我真正想要的是让子类覆盖self.tensor2
的定义,而不是别的。这样做的正确方法是什么?
非常感谢!
答案 0 :(得分:1)
self.tensoroutput永远不会被覆盖,因此它的值不依赖于你拥有的任何基类。使它成为一种方法,然后它就可以了。