TensorFlow 2.0预览-TypeError:使用tf.function时'属性'对象不可迭代

时间:2019-01-25 16:47:34

标签: python tensorflow keras tensorflow-datasets

在尝试在TensorFlow 2.0预览版中实现简单的DCGAN模型时,我偶然发现了一个错误。

如果我用GAN.train()装饰@tf.function函数,则会在启动过程中引发错误。如果我删除装饰器,它将起作用。

我做错了什么/错过了什么吗?或者可能是个错误?

代码实现

"""
Implement DCGAN using the new TF 2.0 API.

Also test tensorflow-datasets.

Celeb-A dataset.
"""

from typing import Dict
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow import keras as k


def bce(x: tf.Tensor, label: tf.Tensor, label_smoothing: float = 0.0) -> tf.Tensor:
    """Returns the discrete binary cross entropy between x and the discrete label
    Args:
        x: a 2D tensor
        label: the discrite label, aka, the distribution to match
        label_smoothing: if greater than zero, smooth the labels

    Returns:
        The binary cros entropy
    """
    # FIXME: Fix the warning
    # assert len(x.shape) == 2 and len(label.shape) == 0

    return k.losses.BinaryCrossentropy()(tf.ones_like(x) * label, x)


def min_max(
    positive: tf.Tensor, negative: tf.Tensor, label_smoothing: float = 0.0
) -> tf.Tensor:
    """Returns the discriminator (min max) loss
    Args:
        positive: the discriminator output for the positive class: 2D tensor
        negative: the discriminator output for the negative class: 2D tensor
        smooth: if greater than zero, appiles one-sided label smoothing
    Returns:
        The sum of 2 BCE
    """

    one = tf.constant(1.0)
    zero = tf.constant(0.0)
    d_loss = bce(positive, one, label_smoothing) + bce(negative, zero)
    return d_loss


class Generator(k.Model):
    def __init__(self) -> None:
        super(Generator, self).__init__()
        self.fc1 = k.layers.Dense(4 * 4 * 1024)
        self.batchnorm1 = k.layers.BatchNormalization()

        self.conv2 = k.layers.Conv2DTranspose(
            filters=512,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding="same",
            use_bias=False,
        )
        self.batchnorm2 = k.layers.BatchNormalization()

        self.conv3 = k.layers.Conv2DTranspose(
            filters=256,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding="same",
            use_bias=False,
        )
        self.batchnorm3 = k.layers.BatchNormalization()

        self.conv4 = k.layers.Conv2DTranspose(
            filters=128,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding="same",
            use_bias=False,
        )
        self.batchnorm4 = k.layers.BatchNormalization()

        self.conv5 = k.layers.Conv2DTranspose(
            filters=3,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding="same",
            use_bias=False,
        )
        self.batchnorm5 = k.layers.BatchNormalization()

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        x = self.fc1(x)
        x = self.batchnorm1(x, training=training)
        x = tf.nn.relu(x)
        x = tf.reshape(x, shape=(-1, 4, 4, 1024))

        x = self.conv2(x)
        x = self.batchnorm2(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv3(x)
        x = self.batchnorm3(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv4(x)
        x = self.batchnorm4(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv5(x)
        x = self.batchnorm5(x, training=training)

        x = tf.nn.tanh(x)
        return x


class Discriminator(k.Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = k.layers.Conv2D(128, (5, 5), strides=(2, 2), padding="same")
        self.conv2 = k.layers.Conv2D(256, (5, 5), strides=(2, 2), padding="same")
        self.batchnorm2 = k.layers.BatchNormalization()
        self.conv3 = k.layers.Conv2D(512, (5, 5), strides=(2, 2), padding="same")
        self.batchnorm3 = k.layers.BatchNormalization()
        self.conv4 = k.layers.Conv2D(1024, (5, 5), strides=(2, 2), padding="same")
        self.batchnorm4 = k.layers.BatchNormalization()
        self.flatten = k.layers.Flatten()
        self.fc5 = k.layers.Dense(1)

    def call(self, x, training=True):
        x = self.conv1(x)
        x = tf.nn.leaky_relu(x)

        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = tf.nn.leaky_relu(x)

        x = self.conv3(x)
        x = self.batchnorm3(x)
        x = tf.nn.leaky_relu(x)

        x = self.conv4(x)
        x = self.batchnorm4(x)
        x = tf.nn.leaky_relu(x)

        x = self.flatten(x)
        x = self.fc5(x)
        return x


class GAN:
    def __init__(self, generator, discriminator, encoder=None):
        """
        GAN initializer.

        Args:
            generator: A ``tensorflow.keras.Model`` to use as Generator.
            discriminator: A ``tensorflow.keras.Model`` to use as Discriminator.
            encoder: A ``tensorflow.keras.Model`` to use as Encoder.

        Returns:
            Trained GAN model (?).

        """
        self.G = generator()
        self.D = discriminator()
        self.E = encoder() if encoder is not None else None
        self.latent_vector_dims = 100

        self.G_opt = k.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)
        self.D_opt = k.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)

    @tf.function()
    def train(self, dataset: tf.data.Dataset):
        """
        Train.
        """
        for step, features in enumerate(dataset, start=1):
            x = features["image"]
            z = tf.random.normal((x.shape[0], self.latent_vector_dims))

            # We record all the operations in the tape
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                G_z = self.G(z, training=True)

                D_x = self.D(x, training=True)
                D_Gz = self.D(G_z, training=True)

                g_loss = bce(D_Gz, tf.constant(1.0))
                d_loss = min_max(D_x, D_Gz, label_smoothing=0.0)

            # We retrieve the gradients from our records
            G_grads = gen_tape.gradient(g_loss, self.G.trainable_variables)
            D_grads = disc_tape.gradient(d_loss, self.D.trainable_variables)

            # Optimize and apply the gradients
            self.G_opt.apply_gradients(zip(G_grads, self.G.trainable_variables))
            self.D_opt.apply_gradients(zip(D_grads, self.D.trainable_variables))

            if step % 10 == 0:
                print(f"--------------------------")
                print(f"STEP: {step}")
                print(f"D_LOSS: {d_loss}")
                print(f"G_LOSS: {g_loss}")


class InputPipeline:
    def __init__(
        self, dataset, batch_size, epochs, shuffle_buffer, prefetched_items, size
    ):
        self.batch_size = batch_size
        self.dataset_name = dataset
        self.epochs = epochs
        self.prefetched_items = prefetched_items
        self.shuffle_buffer = shuffle_buffer
        self.size = size

    def get_input_fn(self) -> tf.data.Dataset:
        """Input fn."""
        return self.input_fn

    def load_public_dataset(self):
        """
        Load one of the publicly available datasets, will merge together all the splits.

        Args:
            chosen_dataset: dataset to use.

        Return:
            The chosen dataset as a ``tf.data.Dataset``

        """
        # Construct a tf.data.Dataset
        datasets = tfds.load(name=self.dataset_name, split=tfds.Split.ALL)
        return datasets

    def resize_images(self, features: Dict) -> Dict:
        """
        Overwrite the \"image\" feature in order to resize them.

        Args:
            features: features dictionary.
            size: desired target size.

        Returns:
            Features with \"image\" resized to the correct shape.

        """
        features["image"] = tf.image.resize(features["image"], self.size)
        return features

    def input_fn(self):
        dataset = self.load_public_dataset()
        dataset = (
            dataset.map(self.resize_images)
            .shuffle(self.shuffle_buffer)
            .batch(self.batch_size)
            .prefetch(self.prefetched_items)
            .repeat(self.epochs)
        )
        return dataset


def main():

    # TODO: replace with CLI
    CHOICE = "celeb_a"
    EPOCHS = 10
    BATCH_SIZE = 64
    PREFETCH = 10
    SHUFFLE_BUFFER = 10000

    # See available datasets
    public_datasets = tfds.list_builders()

    gan = GAN(Generator, Discriminator)
    input_pipeline = InputPipeline(
        dataset=CHOICE,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        prefetched_items=PREFETCH,
        shuffle_buffer=SHUFFLE_BUFFER,
        size=(64, 64),
    )
    dataset = input_pipeline.input_fn()
    gan.train(dataset=dataset)


if __name__ == "__main__":
    main()

完整追溯

Traceback (most recent call last):
  File "dcgan-tf2.py", line 289, in <module>
    main()
  File "dcgan-tf2.py", line 285, in main
    gan.train(dataset=dataset)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 383, in __call__
    self._initialize(args, kwds)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 355, in _initialize
    *args, **kwds))
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1097, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1322, in _maybe_define_function
    arg_names=arg_names), self._function_attributes)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 540, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 298, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1803, in bound_method_wrapper
    return wrapped_fn(weak_instance(), *args, **kwargs)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 533, in wrapper
    ), *args, **kwargs)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 293, in converted_call
    experimental_partial_types=partial_types)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 415, in to_graph
    arg_values, arg_types)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/conversion.py", line 222, in entity_to_graph
    entity_to_graph(candidate, program_ctx, {}, {})
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/conversion.py", line 175, in entity_to_graph
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/conversion.py", line 376, in function_to_graph
    node = node_to_graph(node, context)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/conversion.py", line 435, in node_to_graph
    node = converter.apply_(node, context, call_trees)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/core/converter.py", line 507, in apply_
    node = converter_module.transform(node, context)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/converters/call_trees.py", line 350, in transform
    return CallTreeTransformer(ctx).visit(node)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/core/converter.py", line 440, in visit
    return super(Base, self).visit(node)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/pyct/transformer.py", line 484, in visit
    result = super(Base, self).visit(node)
  File "/usr/lib64/python3.6/ast.py", line 253, in visit
    return visitor(node)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/converters/call_trees.py", line 282, in visit_FunctionDef
    node.returns = self.visit_block(node.returns)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/pyct/transformer.py", line 368, in visit_block
    for node in nodes:
TypeError: 'Attribute' object is not iterable

0 个答案:

没有答案