How to fix 'InvalidArgumentError' when using tf.distribute.MirroredStrategy in TensorFlow 2.0 alpha

Time: 2019-06-01 17:57:11

Tags: tensorflow tensorflow2.0

I have been trying to train a variational autoencoder with tf.distribute.MirroredStrategy, but an 'InvalidArgumentError' is raised.

However, if I train the same model without tf.distribute.MirroredStrategy, no error occurs.

What do I need to fix?

Model

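import tensorflow as tf


def ConvAE(input_shape):
    # Side length of the encoder's last feature map: two valid 3x3 convs (-4),
    # MaxPooling2D(3) (// 3), then two more valid 3x3 convs (-4). Floor division
    # matches the pooling layer; true division gives int(28.44...) = 28 for a
    # 32x32 input, which is not a perfect square and cannot be reshaped to (5, 5, 1).
    embedding_size = ((input_shape[0] - 4) // 3 - 4) ** 2
    print(embedding_size)

    # Encoder: convolutions and pooling, then one value per filter via global average pooling.
    inputs = tf.keras.Input(shape=input_shape, dtype=tf.float32)
    encode = tf.keras.layers.Conv2D(
        filters=16, kernel_size=3, strides=(1, 1), activation='relu')(inputs)
    encode = tf.keras.layers.Conv2D(
        filters=32, kernel_size=3, strides=(1, 1), activation='relu')(encode)
    encode = tf.keras.layers.MaxPooling2D(3)(encode)
    encode = tf.keras.layers.Conv2D(
        filters=32, kernel_size=3, strides=(1, 1), activation='relu')(encode)
    encode = tf.keras.layers.Conv2D(
        filters=embedding_size, kernel_size=3, strides=(1, 1), activation='relu')(encode)
    encode = tf.keras.layers.GlobalAveragePooling2D(name='encoder_output')(encode)

    # Decoder: reshape the embedding vector into a sqrt x sqrt x 1 map and upsample it.
    decode_input_size = (
        int(embedding_size ** (1 / 2)),
        int(embedding_size ** (1 / 2)),
        1
    )
    decode = tf.keras.layers.Reshape(decode_input_size)(encode)
    decode = tf.keras.layers.Conv2DTranspose(16, 3, activation='relu')(decode)
    decode = tf.keras.layers.Conv2DTranspose(32, 3, activation='relu')(decode)
    decode = tf.keras.layers.UpSampling2D(3)(decode)
    decode = tf.keras.layers.Conv2DTranspose(16, 3, activation='relu')(decode)
    # Each valid 3x3 Conv2DTranspose adds 2 pixels, so for a 32x32 input the
    # decoder goes 5 -> 7 -> 9 -> 27 -> 29 -> 31 and outputs 31x31.
    outputs = tf.keras.layers.Conv2DTranspose(input_shape[2], 3, activation='relu', name='decoder_output')(decode)
    return tf.keras.Model(inputs=inputs, outputs=outputs)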
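The shape bookkeeping behind embedding_size can be checked without TensorFlow. The plain-Python sketch below assumes input_shape = (32, 32, 3) (the placeholder shape reported in the traceback), Keras's default 'valid' padding, and MaxPooling2D's default strides equal to pool_size; the helper names are hypothetical.

```python
def conv_valid(size, kernel=3, stride=1):
    """Output length of a 'valid' Conv2D along one spatial axis."""
    return (size - kernel) // stride + 1


def pool_valid(size, pool=3):
    """Output length of MaxPooling2D(pool) with default strides=pool and 'valid' padding."""
    return (size - pool) // pool + 1


def conv_transpose_valid(size, kernel=3, stride=1):
    """Output length of a 'valid' Conv2DTranspose along one spatial axis."""
    return (size - 1) * stride + kernel


def encoder_side(size):
    """Side length of the last encoder feature map in ConvAE."""
    size = conv_valid(size)   # Conv2D(16, 3)
    size = conv_valid(size)   # Conv2D(32, 3)
    size = pool_valid(size)   # MaxPooling2D(3)
    size = conv_valid(size)   # Conv2D(32, 3)
    return conv_valid(size)   # Conv2D(embedding_size, 3)


def decoder_side(size):
    """Side length of the decoder output, starting from the reshaped embedding."""
    size = conv_transpose_valid(size)  # Conv2DTranspose(16, 3)
    size = conv_transpose_valid(size)  # Conv2DTranspose(32, 3)
    size = size * 3                    # UpSampling2D(3)
    size = conv_transpose_valid(size)  # Conv2DTranspose(16, 3)
    return conv_transpose_valid(size)  # Conv2DTranspose(channels, 3)


height = 32  # assumed from the traceback's placeholder shape [?, 32, 32, 3]
side = encoder_side(height)
true_div = int(((height - 4) / 3 - 4) ** 2)   # formula as originally posted
floor_div = ((height - 4) // 3 - 4) ** 2      # floor-division variant
print(side, true_div, floor_div, decoder_side(side))  # prints: 5 28 25 31
```

Under these assumptions the true-division formula yields 28, which cannot be reshaped into (5, 5, 1), while floor division yields 25 = 5 x 5. The same arithmetic gives a 31x31 decoder output for a 32x32 input, so the reconstruction and the mse target would differ in shape. Neither observation is claimed to be the cause of the placeholder error below.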

Compile and fit

distribute_strategy = tf.distribute.MirroredStrategy()
with distribute_strategy.scope():
    # Variables created inside the scope (e.g. the model's weights) are mirrored across replicas.
    model = ConvAE(input_shape=input_shape)
    optimizer = tf.keras.optimizers.Adam(begin_learning_rate)
    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=['accuracy']
    )

    # do something to create a dataset for training

    model.fit(...)
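For comparison, the pattern shown in the TensorFlow distributed-training guide builds and compiles the model inside strategy.scope() and passes fit a tf.data.Dataset batched with the global batch size (per-replica batch times num_replicas_in_sync). The sketch below is a minimal, hypothetical instance of that pattern: the tiny model, random data, and batch sizes are placeholders rather than the asker's setup, it assumes a TensorFlow 2.x install with NumPy, and it is not a confirmed fix for the error above.

```python
import numpy as np
import tensorflow as tf

# Uses all GPUs visible to TensorFlow, or the CPU if there are none.
strategy = tf.distribute.MirroredStrategy()

# Hypothetical sizes: the dataset is batched with the global batch size,
# which MirroredStrategy splits across its replicas.
per_replica_batch = 4
global_batch = per_replica_batch * strategy.num_replicas_in_sync

# Random images standing in for real data; an autoencoder uses x as both input and target.
x = np.random.rand(16, 32, 32, 3).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices((x, x)).batch(global_batch)

with strategy.scope():
    # Model and optimizer are created inside the scope so their variables are mirrored.
    inputs = tf.keras.Input(shape=(32, 32, 3))
    hidden = tf.keras.layers.Conv2D(8, 3, activation="relu")(inputs)            # 32 -> 30
    outputs = tf.keras.layers.Conv2DTranspose(3, 3, activation="relu")(hidden)  # 30 -> 32
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss="mse")

# fit consumes the batched dataset directly; each element is an (input, target) pair.
history = model.fit(dataset, epochs=1, verbose=0)
print(history.history["loss"])
```

Here the output shape matches the input shape (32x32x3), so the mse loss compares tensors of equal size.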

[Error traceback]

Traceback (most recent call last):
  File "src/main.py", line 285, in <module>
    app.run(main)
  File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "src/main.py", line 261, in main
    TensorBoardImage(tag='gen_images', data=x_data, data_size=len(x_data))
  File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py", line 746, in fit
    validation_freq=validation_freq)
  File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training_distributed.py", line 131, in fit_distributed
    steps_name='steps_per_epoch')
  File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 263, in model_iteration
    batch_outs = f(actual_inputs)
  File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/keras/backend.py", line 3217, in __call__
    outputs = self._graph_fn(*converted_inputs)
  File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/eager/function.py", line 558, in __call__
    return self._call_flat(args)
  File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/eager/function.py", line 627, in _call_flat
    outputs = self._inference_function.call(ctx, args)
  File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/eager/function.py", line 415, in call
    ctx=ctx)
  File "/Users/rangkim/projects/yodanjedi/ai/tf/vae/venv/lib/python3.5/site-packages/tensorflow/python/eager/execute.py", line 66, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,32,32,3]
     [[{{node input_1}}]] [Op:__inference_keras_scratch_graph_2723]

0 answers:

No answers yet.