Duplicate nodes in the graph of TensorFlow 2.0 multi-worker mirrored strategy distributed training

Date: 2019-05-21 11:57:34

Tags: python-3.x tensorflow tensorboard tensorflow2.0 tf.keras

I am trying to set up multi-worker mirrored distributed training in TensorFlow 2.0:

    import tensorflow as tf
    import os
    import json
    import numpy as np

    NUM_WORKERS = 1
    WORKER_IP_ADDRS = ['localhost' for i in range(NUM_WORKERS)]
    WORKER_PORTS = [12345 + i for i in range(NUM_WORKERS)]
    INDEX = 0  # index of this worker within the cluster

    # TF_CONFIG must be set before the strategy is created
    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {
            'worker': ['%s:%d' % (WORKER_IP_ADDRS[w], WORKER_PORTS[w]) for w in range(NUM_WORKERS)],
        },
        'task': {'type': 'worker', 'index': INDEX}
    })

    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

    # Build and compile the model inside the strategy scope
    with strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=tf.keras.optimizers.Adam(),
                      metrics=['accuracy'])

    checkpoint_dir = f'{str(__file__)[:-3]}/training_checkpoints'
    # Name of the checkpoint files
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=f'{str(__file__)[:-3]}/tensorboard'),
        tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, save_weights_only=True),
    ]

    def data_func():
        # 100 random 28x28x1 samples, each with a random class label 0-9 stored as a float
        for i in range(100):
            yield np.random.random((28, 28, 1)), [np.random.randint(0, 10) * 1.0]

    train_dataset = tf.data.Dataset.from_generator(
        data_func,
        output_types=(tf.float32, tf.float32),
        output_shapes=(tf.TensorShape((28, 28, 1)), tf.TensorShape((1,)))).batch(12)

    model.fit(train_dataset, epochs=10, callbacks=callbacks)

Here is my TensorBoard graph:

[Image: Model Graph with duplicate nodes]

The whole model graph is duplicated here. I have only one worker, but I can see duplicate nodes in the TensorBoard graph (nodes like Tensor<type: int32 shape: [2] values: 256 256>).
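Because TensorBoard reads the graph from the callback's log directory, it can help to see exactly where the script writes its files. The script derives both output paths by slicing three characters (the ".py" suffix) off `__file__`. The sketch below reproduces only that string logic. `derive_output_paths` and the script path are hypothetical, and `posixpath` is used so the result does not depend on the OS:

```python
import posixpath


def derive_output_paths(script_path: str) -> dict:
    """Reproduce how the question's script derives its output paths from __file__."""
    # [:-3] drops the last three characters, i.e. it assumes the file ends in ".py"
    base = script_path[:-3]
    checkpoint_dir = f"{base}/training_checkpoints"
    return {
        "checkpoint_prefix": posixpath.join(checkpoint_dir, "ckpt_{epoch}"),
        "tensorboard_dir": f"{base}/tensorboard",
    }


paths = derive_output_paths("/home/user/mwms_graph.py")  # hypothetical script path
print(paths["tensorboard_dir"])
# ModelCheckpoint fills in "{epoch}"; str.format shows what epoch 1 would look like
print(paths["checkpoint_prefix"].format(epoch=1))
```

With this hypothetical path, `tensorboard --logdir /home/user/mwms_graph/tensorboard` would point at the directory that the TensorBoard callback writes to.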

Is this expected? Why are these duplicate nodes created?

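The same graph is generated even when I increase the number of workers to any value greater than 1 (with a different INDEX value for each worker).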
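When running more than one worker, each process should receive the same `cluster` spec and differ only in `task.index`. The sketch below factors the question's inline `TF_CONFIG` construction into a hypothetical helper, `build_tf_config`, so the per-worker strings can be compared. It assumes the question's layout of every worker on `localhost` with consecutive ports starting at 12345, and it only builds the JSON without starting TensorFlow:

```python
import json


def build_tf_config(num_workers: int, index: int,
                    host: str = "localhost", base_port: int = 12345) -> str:
    """Build the TF_CONFIG JSON string that the question's script sets, for one worker."""
    workers = [f"{host}:{base_port + w}" for w in range(num_workers)]
    return json.dumps({
        "cluster": {"worker": workers},
        "task": {"type": "worker", "index": index},
    })


if __name__ == "__main__":
    # Every worker process gets the same cluster list and its own task index
    for i in range(3):
        print(build_tf_config(3, i))
```

In practice, each string would be exported as `TF_CONFIG` in a separate process before `MultiWorkerMirroredStrategy()` is constructed.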

0 answers