TensorFlow 2如何在tf.function中使用* args?

时间:2019-12-03 23:57:27

标签: python tensorflow tensorflow2.0

更新:

进行了更多测试,并且无法通过以下方式重现该行为:

import tensorflow as tf
import numpy as np

@tf.function
def tf_being_unpythonic(an_input, another_input):
    return an_input + another_input

@tf.function
def example(*inputs, other_args = True):
    return tf_being_unpythonic(*inputs)

class TestClass(tf.keras.Model):
    def __init__(self, a, b):
        super().__init__()
        self.a= a
        self.b = b

    @tf.function
    def call(self, *inps, some_kwarg=False):
        if some_kwarg:
            return self.a(*inps)
        return self.b(*inps)

class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.inps = tf.keras.layers.Flatten()
        self.hl1 = tf.keras.layers.Dense(5)
        self.hl2 = tf.keras.layers.Dense(4)
        self.out = tf.keras.layers.Dense(1)

    @tf.function
    def call(self,observation):
        x = self.inps(observation)
        x = self.hl1(x)
        x = self.hl2(x)
        return self.out(x)


class Model2(Model):
    def __init__(self):
        super().__init__()
        self.prein = tf.keras.layers.Concatenate()

    @tf.function
    def call(self,b,c):
        x = self.prein([b,c])
        return super().call(x)   

am = Model()
pm = Model2()
test = TestClass(am,pm)

a = np.random.normal(size=(1,2,3))
b = np.random.normal(size=(1,2,4))

test(a,some_kwarg=True)
test(a,b) 

所以这可能是其他地方的错误。

@tf.function
def call(self, *inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    if target:
        return self._target_network(*inp, training)
    return self._network(*inp, training)

我得到:

ValueError: Input 0 of layer flatten is incompatible with the layer: : expected min_ndim=1, found ndim=0. Full shape received: []

但是print(inp)给出:

(<tf.Tensor 'inp_0:0' shape=(1, 3) dtype=float32>,) 

此后,我已经编辑过,只是未提交的玩具代码,因此无法进一步调查。将问题留在这里,以便所有不了解此问题的人都不会读任何东西。

2 个答案:

答案 0 :(得分:1)

我认为使用*args构造不是tf.function的好习惯。如您所见,大多数TF接受可变数量输入的函数都使用元组。

因此,您可以将函数签名重写为:

def call(self, inputs, target=False, training=False)

并通过以下方式调用它:

instance.call((i1, i2, i3), [...])
# instead of instance.call(i1, i2, i3, [...])

编辑

顺便说一句,在将tf.function*args构造一起使用时,我看不到任何错误:

import tensorflow as tf

@tf.function
def call(*inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    return inp[0]

def main():
    print(call(1))
    print(call(2, 2))
    print(call(3, 3, 3))


if __name__ == '__main__':
    main()
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)

因此,您应该向我们提供有关您尝试执行的操作以及错误出处的详细信息。

答案 1 :(得分:0)

这可能是最近已解决的错误。 *args**kwargs应该可以正常工作。