如何从TensorFlow继承类tf.layers.Layer?

时间:2018-09-25 14:44:15

标签: python tensorflow

如何通过tensorflow run()方法继承类tf.layers.Layer并使它的对象可运行?

例如:

class A(tf.layers.Layer):

    def __init__(self, a, b, ...):
        super(A, self).__init__()

    def apply(self):        # Or something else
        return 5 + 5

    # Or some other method
    ...

当我编码时:

a = A()

with tf.Session() as sess:
    print(sess.run(a))

10

或者继承其他一些类,或者通过函数使其自定义和运行。

我有此代码:

class A(tf.layers.Layer):

    def __init__(self, 
                 a,
                 b,
                 trainable=True, 
                 name=None, 
                 dtype=None,
                 **kwargs):

        super(A, self).__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)


    def __call__(self, inputs, *args, **kwargs):
        print(5)

#     def apply(self):
#         print(5)

当我跑步时:

a = A(1, 2)

with tf.Session() as sess:
    sess.run(a)

我有错误:

TypeError: Can not convert a A into a Tensor or Operation.

但是如果我运行此命令:

a.apply(1)

结果正确: 5

0 个答案:

没有答案