如何继承典型的Tensorflow模型类以重用大多数图形构建代码?

时间:2018-11-01 12:21:04

标签: python tensorflow

典型的tensorflow模型类如下:

class Model:
    def __init__(self):
        build()
    def build(self):
        self.x = tf.placeholder()
        self.y = f(self.x)
        self.z = g(self.y)

如果需要稍作修改(即将self.y=f(self.x)更改为slef.y=h(self.x)),我们真的很想继承这个Model类并添加一些代码来实现这一点。

但是,一旦调用build函数,便会构建完整的图形。覆盖属性不会改变图形结构。有什么办法可以整齐地完成这项工作吗?

1 个答案:

答案 0 :(得分:1)

您可以参数化fg(或您拥有的任何参数)并将它们传递给构造函数:

class Model:
    def __init__(self, f=default_f, g=default_g):
        self.f = f
        self.g = g
        self.build()

    def build(self):
        self.x = tf.placeholder()
        self.y = self.f(self.x)
        self.z = self.g(self.y)

或者您可以使它们成为可重写的类级变量,以避免构造函数的签名变得肿,但是您不能在构造函数中隐式调用.build()

class Model:
    f = default_f
    g = default_g

    def build(self):
        self.x = tf.placeholder()
        self.y = self.f(self.x)
        self.z = self.g(self.y)

# ...

m = Model()
m.f = some_other_f
m.build()