我想使用 tf.distribute.MirroredStrategy 训练变分自编码器,但会出现 'InvalidArgumentError'。
但是,如果不使用 tf.distribute.MirroredStrategy 来训练该模型,则不会发生任何错误。
我需要修改哪里?
模型
# Body of the model-building function (invoked as ConvAE(input_shape=...) in
# the training snippet); builds a convolutional autoencoder with the Keras
# functional API and returns it as a tf.keras.Model.
#
# embedding_size is used both as the filter count of the last encoder
# convolution and as the element count the decoder reshapes into a square map.
# NOTE(review): `/` is true division in Python 3, so for input_shape[0] == 32
# this evaluates to int(28.44...) == 28, which is not a perfect square, while
# the Reshape target below holds 5 * 5 * 1 == 25 elements. Confirm whether
# floor division (`//`) was intended.
embedding_size = int(((input_shape[0] - 4) / 3 - 4) ** 2)
print(embedding_size)
inputs = tf.keras.Input(shape=input_shape, dtype=tf.float32)
# Encoder: two 3x3 convolutions, a 3x3 max-pool, then two more 3x3
# convolutions; the last one emits embedding_size feature maps.
encode = tf.keras.layers.Conv2D(
filters=16, kernel_size=3, strides=(1, 1), activation='relu')(inputs)
encode = tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(1, 1), activation='relu')(encode)
encode = tf.keras.layers.MaxPooling2D(3)(encode)
encode = tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(1, 1), activation='relu')(encode)
encode = tf.keras.layers.Conv2D(
filters=embedding_size, kernel_size=3, strides=(1, 1), activation='relu')(encode)
# Global average pooling collapses the spatial axes, leaving one value per
# feature map: a vector of length embedding_size.
encode = tf.keras.layers.GlobalAveragePooling2D(name='encoder_output')(encode)
# Side length of the square, single-channel map the code vector is reshaped to.
# int() truncates the float square root.
decode_input_size = (
int(embedding_size ** (1 / 2)),
int(embedding_size ** (1 / 2)),
1
)
decode = tf.keras.layers.Reshape(decode_input_size)(encode)
# Decoder: transposed convolutions and a 3x upsampling, roughly mirroring the
# encoder.
# NOTE(review): with the default padding each 3x3 Conv2DTranspose adds 2 to the
# spatial size, so a 5x5 map appears to end at 31x31 rather than 32x32.
# TODO confirm the output size matches the input, as the 'mse' loss needs.
decode = tf.keras.layers.Conv2DTranspose(16, 3, activation='relu')(decode)
decode = tf.keras.layers.Conv2DTranspose(32, 3, activation='relu')(decode)
decode = tf.keras.layers.UpSampling2D(3)(decode)
decode = tf.keras.layers.Conv2DTranspose(16, 3, activation='relu')(decode)
# The final layer restores the channel count of the input (input_shape[2]).
outputs = tf.keras.layers.Conv2DTranspose(input_shape[2], 3, activation='relu', name='decoder_output')(decode)
return tf.keras.Model(inputs=inputs, outputs=outputs)
编译并拟合
# Create the model and its optimizer, and compile, inside the strategy scope
# so that their variables are created under MirroredStrategy.
distribute_strategy = tf.distribute.MirroredStrategy()
with distribute_strategy.scope():
    model = ConvAE(input_shape=input_shape)
    optimizer = tf.keras.optimizers.Adam(begin_learning_rate)
    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=['accuracy']
    )
# do something to create a dataset for training
model.fit(...)
[错误回溯]
Traceback (most recent call last):
File "src/main.py", line 285, in <module>
app.run(main)
File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "src/main.py", line 261, in main
TensorBoardImage(tag='gen_images', data=x_data, data_size=len(x_data))
File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py", line 746, in fit
validation_freq=validation_freq)
File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training_distributed.py", line 131, in fit_distributed
steps_name='steps_per_epoch')
File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 263, in model_iteration
batch_outs = f(actual_inputs)
File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/keras/backend.py", line 3217, in __call__
outputs = self._graph_fn(*converted_inputs)
File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/eager/function.py", line 558, in __call__
return self._call_flat(args)
File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/eager/function.py", line 627, in _call_flat
outputs = self._inference_function.call(ctx, args)
File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/eager/function.py", line 415, in call
ctx=ctx)
File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/eager/execute.py", line 66, in quick_execute
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,32,32,3]
[[{{node input_1}}]] [Op:__inference_keras_scratch_graph_2723]