进行了更多测试,并且无法通过以下方式重现该行为:
import tensorflow as tf
import numpy as np
@tf.function
def tf_being_unpythonic(an_input, another_input):
return an_input + another_input
@tf.function
def example(*inputs, other_args = True):
return tf_being_unpythonic(*inputs)
class TestClass(tf.keras.Model):
def __init__(self, a, b):
super().__init__()
self.a= a
self.b = b
@tf.function
def call(self, *inps, some_kwarg=False):
if some_kwarg:
return self.a(*inps)
return self.b(*inps)
class Model(tf.keras.Model):
def __init__(self):
super().__init__()
self.inps = tf.keras.layers.Flatten()
self.hl1 = tf.keras.layers.Dense(5)
self.hl2 = tf.keras.layers.Dense(4)
self.out = tf.keras.layers.Dense(1)
@tf.function
def call(self,observation):
x = self.inps(observation)
x = self.hl1(x)
x = self.hl2(x)
return self.out(x)
class Model2(Model):
def __init__(self):
super().__init__()
self.prein = tf.keras.layers.Concatenate()
@tf.function
def call(self,b,c):
x = self.prein([b,c])
return super().call(x)
am = Model()
pm = Model2()
test = TestClass(am,pm)
a = np.random.normal(size=(1,2,3))
b = np.random.normal(size=(1,2,4))
test(a,some_kwarg=True)
test(a,b)
所以这可能是其他地方的错误。
@tf.function
def call(self, *inp, target=False, training=False):
if not len(inp):
raise ValueError("Call requires some input")
if target:
return self._target_network(*inp, training)
return self._network(*inp, training)
我得到:
ValueError: Input 0 of layer flatten is incompatible with the layer: : expected min_ndim=1, found ndim=0. Full shape received: []
但是print(inp)给出:
(<tf.Tensor 'inp_0:0' shape=(1, 3) dtype=float32>,)
此后,我已经编辑过,只是未提交的玩具代码,因此无法进一步调查。将问题留在这里,以便所有不了解此问题的人都不会读任何东西。
答案 0 :(得分:1)
我认为使用*args
构造不是tf.function
的好习惯。如您所见,大多数TF接受可变数量输入的函数都使用元组。
因此,您可以将函数签名重写为:
def call(self, inputs, target=False, training=False)
并通过以下方式调用它:
instance.call((i1, i2, i3), [...])
# instead of instance.call(i1, i2, i3, [...])
顺便说一句,在将tf.function
与*args
构造一起使用时,我看不到任何错误:
import tensorflow as tf
@tf.function
def call(*inp, target=False, training=False):
if not len(inp):
raise ValueError("Call requires some input")
return inp[0]
def main():
print(call(1))
print(call(2, 2))
print(call(3, 3, 3))
if __name__ == '__main__':
main()
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
因此,您应该向我们提供有关您尝试执行的操作以及错误出处的详细信息。
答案 1 :(得分:0)
这可能是最近已解决的错误。 *args
和**kwargs
应该可以正常工作。