当尝试以面向对象的方式构建TensorFlow模型时,尤其是在使用tf.data.Dataset时,我发现了许多问题。理想情况下,我的模型公开以下界面:
class MyModel:
def fit(self, x):
pass
def transform(self, x):
pass
挺直的。使用tf.data.Dataset时会出现问题。
tf.data.Dataset从各种来源读取数据并产生张量作为输出。然后可以将这些张量相乘,相加等,以构成您的计算图的其他张量。问题在于这会将您的数据集产生的张量耦合到计算图中。这打破了这个面向对象的接口,因为理想情况下,我想在不同的数据集上调用fit
和transform
方法,而不必每次都构造一个新的计算图。
有人对此问题有解决方案吗?