属性错误:创建Keras模型时,“ NoneType”对象没有属性“ _inbound_nodes”错误

时间:2020-08-14 11:13:44

标签: python tensorflow keras generative-adversarial-network

当我尝试为WGAN_GP建立对抗图时,出现此错误。

def _create_adversarial_graph(self):
        real_img = Input(shape=self.input_dims)
        latent_space = Input(shape=(self.latent_space_size,))
        fake_img = self.generator(latent_space)
        # interpolate real and fake images
        alpha = K.random_uniform((self.batch_size, 1, 1, 1))
        self.interpolated_img = Add()([Multiply()([alpha, real_img]), Multiply()([1-alpha, fake_img])])
        # pass it through discriminator
        real_critic = self.discriminator(real_img)
        fake_critic = self.discriminator(fake_img)
        interpolated_critic = self.discriminator(self.interpolated_img)
        #--------------------------------------------
        # discriminator (critic) computational graph
        #--------------------------------------------
        set_trainable(self.generator, False) # freeze weights for generator while training discriminator
        self.discriminator_model = Model(inputs=[real_img, latent_space], outputs=[real_critic, fake_critic, interpolated_critic])
        self.discriminator_model.compile(
            loss=[self.Wasserstein_loss, self.Wasserstein_loss, self.GP_loss],
            optimizer=self.optimizer,
            loss_weights=self.discriminator_loss_weights
            )

self.generatorself.discriminator也是Keras模型。 完整错误:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-35-66b73f0e6b92> in <module>()
     24         generator_padding=["same","same","same","same"],
     25         optimizer=OPTIMIZER,
---> 26         batch_size=BATCH_SIZE
     27         )
     28 print("discriminator_model and generator_model summary:")

10 frames
<ipython-input-34-1ccb34bddd33> in __init__(self, input_dims, latent_space_size, discriminator_filters, discriminator_kernel_size, discriminator_strides, discriminator_padding, discriminator_loss_weights, generator_init_dense, generator_batch_norm_momentum, generator_filters, generator_kernel_size, generator_strides, generator_padding, optimizer, batch_size)
     40         self.batch_size = batch_size
     41         self.discriminator_loss_weights = discriminator_loss_weights
---> 42         self._create_model()
     43     def _create_model(self):
     44         self._create_discriminator()

<ipython-input-34-1ccb34bddd33> in _create_model(self)
     45         self._create_generator()
     46         # now that we have generator and discriminator we can make computational graph
---> 47         self._create_adversarial_graph()
     48 
     49     def _create_discriminator(self):

<ipython-input-34-1ccb34bddd33> in _create_adversarial_graph(self)
     98         #--------------------------------------------
     99         set_trainable(self.generator, False) # freeze weights for generator while training discriminator
--> 100         self.discriminator_model = Model(inputs=[real_img, latent_space], outputs=[real_critic, fake_critic, interpolated_critic])
    101         self.discriminator_model.compile(
    102             loss=[self.Wasserstein_loss, self.Wasserstein_loss, self.GP_loss],

/tensorflow-1.15.2/python3.6/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/tensorflow-1.15.2/python3.6/keras/engine/network.py in __init__(self, *args, **kwargs)
     92                 'inputs' in kwargs and 'outputs' in kwargs):
     93             # Graph network
---> 94             self._init_graph_network(*args, **kwargs)
     95         else:
     96             # Subclassed network

/tensorflow-1.15.2/python3.6/keras/engine/network.py in _init_graph_network(self, inputs, outputs, name, **kwargs)
    239         # Keep track of the network's nodes and layers.
    240         nodes, nodes_by_depth, layers, layers_by_depth = _map_graph_network(
--> 241             self.inputs, self.outputs)
    242         self._network_nodes = nodes
    243         self._nodes_by_depth = nodes_by_depth

/tensorflow-1.15.2/python3.6/keras/engine/network.py in _map_graph_network(inputs, outputs)
   1432                   layer=layer,
   1433                   node_index=node_index,
-> 1434                   tensor_index=tensor_index)
   1435 
   1436     for node in reversed(nodes_in_decreasing_depth):

/tensorflow-1.15.2/python3.6/keras/engine/network.py in build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)
   1419             tensor_index = node.tensor_indices[i]
   1420             build_map(x, finished_nodes, nodes_in_progress, layer,
-> 1421                       node_index, tensor_index)
   1422 
   1423         finished_nodes.add(node)

/tensorflow-1.15.2/python3.6/keras/engine/network.py in build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)
   1419             tensor_index = node.tensor_indices[i]
   1420             build_map(x, finished_nodes, nodes_in_progress, layer,
-> 1421                       node_index, tensor_index)
   1422 
   1423         finished_nodes.add(node)

/tensorflow-1.15.2/python3.6/keras/engine/network.py in build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)
   1419             tensor_index = node.tensor_indices[i]
   1420             build_map(x, finished_nodes, nodes_in_progress, layer,
-> 1421                       node_index, tensor_index)
   1422 
   1423         finished_nodes.add(node)

/tensorflow-1.15.2/python3.6/keras/engine/network.py in build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)
   1391             ValueError: if a cycle is detected.
   1392         """
-> 1393         node = layer._inbound_nodes[node_index]
   1394 
   1395         # Prevent cycles.

AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

0 个答案:

没有答案