Tensorflow 2.0中不推荐使用Tensorflow Dataset.from_generator吗?抛出tf.py_func弃用错误

时间:2019-05-08 11:11:17

标签: python tensorflow tensorflow-datasets tensorflow2.0 tf.keras

当我从生成器创建tf数据集并尝试运行tf2.0代码时,它会以贬值消息警告我。

代码:

import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model


def my_function():
    import numpy as np
    for i in range(1000):
        yield np.random.random(size=(28, 28, 1)), [1.0]


train_ds = tf.data.Dataset.from_generator(my_function, output_types=(tf.float32, tf.float32)).batch(32)


class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

    # def __call__(self, *args, **kwargs):
    #     return super().__call(*args,**kwargs)


model = MyModel()

loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

optimizer = tf.keras.optimizers.Adam()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')


@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)


EPOCHS = 5

for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    template = 'Epoch {}, Loss: {}, Accuracy: {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100))

警告消息:

........
Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, there are two
    options available in V2. ........

我想使用数据集API(带有prefetch)从流输入中馈送数据以进行建模。即使在当前的Alpha版本中仍然可以使用,以后还会将其删除吗?

tensorflow会将生成器数据集中使用的tf.py_func替换为新的东西吗,还是会从生成器API中删除整个数据集?

1 个答案:

答案 0 :(得分:1)

否,tf.data.Dataset.from_generator在TensorFlow 2.0中不会被弃用。您看到的是一条警告消息,用于通知用户将来的更改。如果您需要直接使用py_func,最直接的方法是使用tf.compat.v1.py_func。 TF2.0有自己的包装器,称为tf.py_function