封装TensorFlow计算图的类中的类继承

时间:2018-02-22 06:15:48

标签: python oop inheritance tensorflow machine-learning

我喜欢围绕我的Tensorflow模型构建python类,以使它们更易于使用(至少在我看来)。我采取的方法是编写像

这样的东西
class MyAwesomeModel(object):
    def __init__(self, some_graph_params):
        # bunch of code that defines the tensors, optimizer, etc...
        # e.g.  self.mytensor = tf.placeholder(tf.float32, [1])
    def Train(self, tfsession, input_val):
        # some code that calls the run() method on tfsession, etc.
    def other_methods(self):
        # other things like testing, plotting, etc. all managed nicely
        # by the state that MyAwesomeModel instances maintain

我有两个非常相似的模型。唯一的区别在于计算图结构中的一些地方 - 我想创建一个具有所有常用功能的基类,并且只有子类可以覆盖基类中的一些东西。以下是它在我脑海中的运作方式

说我的基类看起来像这样

import tensorflow as tf

class BaseClass(object):
    def __init__(self, multiplier):
        self.multiplier = multiplier
        # This is where I construct the graph
        self.inputnode = tf.placeholder(tf.float32, [1])
        self.tensor1 = tf.constant(self.multiplier, dtype=tf.float32) * self.inputnode
        self.tensor2 = self.tensor1  # this is where the two 
                                     # child classes will differ
        self.tensoroutput = 10*self.tensor2
    def forward_pass(self, tfsession, input_val):
        return tfsession.run(self.tensoroutput, 
                             feed_dict={self.inputnode: [input_val]})
    def other_methods(self):
        print("doing something here...")
        print(self.multiplier)

然后我弹出两个子类,它们重新定义了self.tensor2self.tensor1之间的关系:

class ChildClass1(BaseClass):
    def __init__(self, multiplier):
        BaseClass.__init__(self, multiplier)
        self.tensor2 = self.tensor1 + tf.constant(5.0, dtype=tf.float32)

class ChildClass2(BaseClass):
    def __init__(self, multiplier):
        BaseClass.__init__(self, multiplier)
        self.tensor2 = self.tensor1 + tf.constant(4.0, dtype=tf.float32)

我的目标是运行以下内容:

cc1 = ChildClass1(2)   # multiplier is 2
mysession = tf.Session()
mysession.run(tf.global_variables_initializer())
print(cc1.forward_pass(mysession, 5))

如果我想这样做,那么结果将是((5 * 2)+5)* 10 = 150.如果对象cc1是ChildClass2(2)类型,那么我想要结果是((5 * 2)+4)* 10 = 140.

但是,当我运行上面的代码时,结果是100,这与子类永远不会覆盖在基类中首次遇到的self.tensor2 的定义一致。我认为我需要有那条不稳定的行self.tensor2 = self.tensor1,否则以下行会抱怨self.tensor2不存在。我真正想要的是让子类覆盖self.tensor2的定义,而不是别的。这样做的正确方法是什么?

非常感谢!

1 个答案:

答案 0 :(得分:1)

self.tensoroutput永远不会被覆盖,因此它的值不依赖于你拥有的任何基类。使它成为一种方法,然后它就可以了。