我想实现一个 GAN。这意味着需要先单独训练鉴别器模型，再把生成器与（冻结的）鉴别器组合成一个模型来训练生成器，如此交替重复。
但是，如果事先编译了 a 或 b，那么在训练（调用 train_on_batch，而不是编译）组合模型 c 时就会报错，这是为什么？
谢谢。
def build_a():
    """Build the first small model: Dense(10) -> Dense(1) -> sigmoid.

    Inputs have a single feature (``input_shape=(1,)``). The returned
    ``tf.keras.models.Sequential`` model is not compiled.
    """
    net = tf.keras.models.Sequential()
    net.add(tf.keras.layers.Dense(10, input_shape=(1,)))
    net.add(tf.keras.layers.Dense(1))
    # Separate activation layer squashes the single output into (0, 1).
    net.add(tf.keras.layers.Activation('sigmoid'))
    return net
def build_b():
    """Build the second small model: Dense(10) -> Dense(1) -> sigmoid.

    Inputs have a single feature (``input_shape=(1,)``). The returned
    ``tf.keras.models.Sequential`` model is not compiled.
    """
    stack = [
        tf.keras.layers.Dense(10, input_shape=(1,)),
        tf.keras.layers.Dense(1),
        # Separate activation layer squashes the single output into (0, 1).
        tf.keras.layers.Activation('sigmoid'),
    ]
    return tf.keras.models.Sequential(stack)
def build_c(a, b, x_train, *, freeze_b=False):
    """Chain model ``a`` into model ``b`` and wrap the result as one Keras model.

    In the GAN setup this is the "combined" model: ``a`` is the generator and
    ``b`` the discriminator. That role assignment comes from the usage, not
    from anything enforced here.

    Args:
        a: Keras model applied first to the input.
        b: Keras model applied to the output of ``a``.
        x_train: Symbolic input tensor for the combined model, e.g.
            ``tf.keras.layers.Input((1,))``. Despite its name, this is not
            training data.
        freeze_b: If True, set ``b.trainable = False`` before building the
            combined model, so that compiling and training the returned model
            leaves ``b``'s weights unchanged. Defaults to False, which keeps
            the original behavior.

    Returns:
        A ``tf.keras.Model`` mapping ``x_train`` to ``b(a(x_train))``.
    """
    if freeze_b:
        # Keras documents that `trainable` is taken into account when a model
        # is compiled, so this must be set before the returned model is
        # compiled. A `b` compiled earlier keeps the state it was compiled with.
        b.trainable = False
    a_out = a(x_train)
    b_out = b(a_out)
    return tf.keras.Model(x_train, b_out)
# Learning rate shared by the three optimizers below.
LEARNING_RATE = 0.0002


def _make_optimizer():
    """Return a new Adam optimizer using LEARNING_RATE.

    Each ``compile()`` call gets its own instance, as in the original code.
    ``learning_rate=`` replaces the deprecated ``lr=`` alias, which newer
    Keras optimizers reject.
    """
    return tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)


a = build_a()
a.compile(loss='binary_crossentropy', optimizer=_make_optimizer(), metrics=['accuracy'])
b = build_b()
b.compile(loss='binary_crossentropy', optimizer=_make_optimizer(), metrics=['accuracy'])
c = build_c(a, b, tf.keras.layers.Input((1,)))
c.compile(loss='binary_crossentropy', optimizer=_make_optimizer(), metrics=['accuracy'])

# Dummy batch: 100 single-feature samples with all-ones labels.
x = np.ones((100, 1), dtype=np.float32)
y = np.ones((100, 1), dtype=np.float32)

a.train_on_batch(x, y)  # works
b.train_on_batch(x, y)  # works
# NOTE(review): this call is reported to raise InvalidArgumentError
# ("You must feed a value for placeholder tensor '..._target'") when a and/or b
# were compiled before c. c trains fine if a.compile and b.compile are removed.
# This looks TensorFlow-version specific. TODO: confirm against the installed version.
c.train_on_batch(x, y)
print("Done")
完整的错误信息如下：
InvalidArgumentError Traceback (most recent call last)
<ipython-input-11-135606be04bb> in <module>
36 a.train_on_batch(x, y)
37 b.train_on_batch(x, y)
---> 38 c.train_on_batch(x, y)
39
40 print("Done")
c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\keras\engine\training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
916 self._update_sample_weight_modes(sample_weights=sample_weights)
917 self._make_train_function()
--> 918 outputs = self.train_function(ins) # pylint: disable=not-callable
919
920 if reset_metrics:
c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
3508 value = math_ops.cast(value, tensor.dtype)
3509 converted_inputs.append(value)
-> 3510 outputs = self._graph_fn(*converted_inputs)
3511
3512 # EagerTensor.numpy() will often make a copy to ensure memory safety.
c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
570 raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
571 list(kwargs.keys()), list(self._arg_keywords)))
--> 572 return self._call_flat(args)
573
574 def _filtered_call(self, args, kwargs):
c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args)
669 # Only need to override the gradient in graph mode and when we have outputs.
670 if context.executing_eagerly() or not self.outputs:
--> 671 outputs = self._inference_function.call(ctx, args)
672 else:
673 self._register_gradient()
c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args)
443 attrs=("executor_type", executor_type,
444 "config_proto", config),
--> 445 ctx=ctx)
446 # Replace empty list with None
447 outputs = outputs or None
c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
65 else:
66 message = e.message
---> 67 six.raise_from(core._status_to_exception(e.code, message), None)
68 except TypeError as e:
69 if any(ops._is_keras_symbolic_tensor(x) for x in inputs):
c:\users\noah\pycharmprojects\untitled\venv\lib\site-packages\six.py in raise_from(value, from_value)
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: You must feed a value for placeholder tensor 'activation_9_target' with dtype float and shape [?,?]
[[node activation_9_target (defined at <ipython-input-11-135606be04bb>:38) ]]
(1) Invalid argument: You must feed a value for placeholder tensor 'activation_9_target' with dtype float and shape [?,?]
[[node activation_9_target (defined at <ipython-input-11-135606be04bb>:38) ]]
[[sequential_9_target/_1]]
0 successful operations.
0 derived errors ignored. [Op:__inference_keras_scratch_graph_7356]
Function call stack:
keras_scratch_graph -> keras_scratch_graph