合并之前无法编译两个模型

时间:2019-07-19 14:08:59

标签: tensorflow keras

我想实现 GAN,这意味着需要单独训练鉴别器模型,再把生成器与(冻结的)鉴别器组合成一个模型来训练生成器,如此交替重复。

但是,如果预先编译了 a 或 b,在训练组合模型 c(调用 c.train_on_batch)时似乎就会出现错误?

谢谢。

def build_a():
    """Return a new, uncompiled Sequential model for 1-feature inputs.

    Layer stack: Dense(10) (no activation argument given) -> Dense(1) ->
    sigmoid, so each output lies in (0, 1). In the surrounding script this
    model's output is fed into ``b``, so it presumably plays the generator
    role — confirm against the caller.
    """
    layers = tf.keras.layers
    model = tf.keras.models.Sequential()
    model.add(layers.Dense(10, input_shape=(1,)))
    model.add(layers.Dense(1))
    model.add(layers.Activation('sigmoid'))
    return model

def build_b():
    """Return a new, uncompiled Sequential model for 1-feature inputs.

    Same architecture as the other builder in this snippet:
    Dense(10) -> Dense(1) -> sigmoid, giving one value in (0, 1) per sample.
    In the surrounding script it receives the first model's output, so it
    presumably plays the discriminator role — confirm against the caller.
    """
    stack = [
        tf.keras.layers.Dense(10, input_shape=(1,)),
        tf.keras.layers.Dense(1),
        tf.keras.layers.Activation('sigmoid'),
    ]
    return tf.keras.models.Sequential(stack)

def build_c(a, b, x_train):
    """Wire ``a`` followed by ``b`` into one functional Keras model.

    Args:
        a: first model; it is applied directly to ``x_train``.
        b: second model; it is applied to ``a``'s output.
        x_train: despite its name, this is the symbolic input tensor of the
            combined model (the caller passes ``tf.keras.layers.Input``),
            not training data.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping ``x_train`` to ``b(a(x_train))``.
    """
    combined_output = b(a(x_train))
    return tf.keras.Model(x_train, combined_output)

# The original snippet used `tf` and `np` without importing them.
import numpy as np
import tensorflow as tf

# One learning rate for all three optimizers (value from the original repro).
# Each model still gets its own Adam instance so optimizer state isn't shared.
LEARNING_RATE = 0.0002

a = build_a()
a.compile(loss='binary_crossentropy',
          optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
          metrics=['accuracy'])
b = build_b()
b.compile(loss='binary_crossentropy',
          optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
          metrics=['accuracy'])

# The stated goal is to train the generator through a *frozen* discriminator.
# Keras records the `trainable` state when a model is compiled, so `b` keeps
# training normally through its own compile above, while the combined model
# `c` compiled below treats `b`'s weights as frozen.
b.trainable = False

c = build_c(a, b, tf.keras.layers.Input((1,)))
c.compile(loss='binary_crossentropy',
          optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
          metrics=['accuracy'])

x = np.ones((100, 1), dtype=np.float32)
y = np.ones((100, 1), dtype=np.float32)

a.train_on_batch(x, y)  # works
b.train_on_batch(x, y)  # works
# NOTE(review): the reported "You must feed a value for placeholder tensor
# '..._target'" error only appears when a/b were compiled first. It looks like
# c's train function picks up the compiled metrics of the nested sub-models,
# which depend on their own target placeholders. This appears to be a
# TF 1.14 / 2.0-beta issue; upgrading TF, or compiling a/b without `metrics=`,
# are the likely workarounds — confirm on the TF version in use.
c.train_on_batch(x, y)

print("Done")

这是错误:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-11-135606be04bb> in <module>
     36 a.train_on_batch(x, y)
     37 b.train_on_batch(x, y)
---> 38 c.train_on_batch(x, y)
     39 
     40 print("Done")

c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\keras\engine\training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
    916       self._update_sample_weight_modes(sample_weights=sample_weights)
    917       self._make_train_function()
--> 918       outputs = self.train_function(ins)  # pylint: disable=not-callable
    919 
    920     if reset_metrics:

c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
   3508         value = math_ops.cast(value, tensor.dtype)
   3509       converted_inputs.append(value)
-> 3510     outputs = self._graph_fn(*converted_inputs)
   3511 
   3512     # EagerTensor.numpy() will often make a copy to ensure memory safety.

c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
    570       raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
    571           list(kwargs.keys()), list(self._arg_keywords)))
--> 572     return self._call_flat(args)
    573 
    574   def _filtered_call(self, args, kwargs):

c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args)
    669     # Only need to override the gradient in graph mode and when we have outputs.
    670     if context.executing_eagerly() or not self.outputs:
--> 671       outputs = self._inference_function.call(ctx, args)
    672     else:
    673       self._register_gradient()

c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args)
    443             attrs=("executor_type", executor_type,
    444                    "config_proto", config),
--> 445             ctx=ctx)
    446       # Replace empty list with None
    447       outputs = outputs or None

c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     65     else:
     66       message = e.message
---> 67     six.raise_from(core._status_to_exception(e.code, message), None)
     68   except TypeError as e:
     69     if any(ops._is_keras_symbolic_tensor(x) for x in inputs):

c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\six.py in raise_from(value, from_value)

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  You must feed a value for placeholder tensor 'activation_9_target' with dtype float and shape [?,?]
     [[node activation_9_target (defined at <ipython-input-11-135606be04bb>:38) ]]
  (1) Invalid argument:  You must feed a value for placeholder tensor 'activation_9_target' with dtype float and shape [?,?]
     [[node activation_9_target (defined at <ipython-input-11-135606be04bb>:38) ]]
     [[sequential_9_target/_1]]
0 successful operations.
0 derived errors ignored. [Op:__inference_keras_scratch_graph_7356]

Function call stack:
keras_scratch_graph -> keras_scratch_graph

0 个答案:

没有答案