大家好,我正在尝试用分布式 TensorFlow 运行一个神经网络模型(2 个 ps + 2 个 worker,每个都在一台独立的机器上),但性能不好,而且 worker 的 CPU 使用率在 50%~800% 之间波动(机器有 40 个核心)。
# My code (TensorFlow 1.x, between-graph replication with 2 ps + 2 workers):
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_id,
        cluster=cluster)):
    global_step = tf.contrib.framework.get_or_create_global_step()

    indexes, values, labels = inputs.inputs()
    logits = models.inference([indexes, values])

    with tf.name_scope('loss'):
        # NOTE(review): tf.log(logits) assumes models.inference returns
        # probabilities (e.g. a softmax output); any exact 0 yields -inf and a
        # NaN loss. If it returns unnormalised logits, prefer
        # tf.nn.softmax_cross_entropy_with_logits instead -- TODO confirm.
        diff = labels * tf.log(logits)
        with tf.name_scope('total'):
            loss = -tf.reduce_mean(diff)
    tf.summary.scalar('loss', loss)

    with tf.name_scope('train'):
        print("learning_rate = %f" % FLAGS.learning_rate)
        # Plain Adam: each worker applies its own updates to the ps variables
        # (asynchronous training). Wrap it in tf.train.SyncReplicasOptimizer
        # if synchronous gradient aggregation is wanted.
        sync_opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
        train_step = sync_opt.minimize(loss, global_step=global_step)

    with tf.name_scope('accuracy'):
        with tf.name_scope('correct_prediction'):
            correct_prediction = tf.equal(tf.argmax(logits, 1),
                                          tf.argmax(labels, 1))
        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', accuracy)

# Dump a Chrome-trace timeline once every this many local steps.
TRACE_EVERY_N_STEPS = 10

hooks = [tf.train.StopAtStepHook(num_steps=6000)]
# log_device_placement=True logs the placement of every op, which is noisy
# and slows things down; keep it off while measuring performance.
sess_config = tf.ConfigProto(allow_soft_placement=True,
                             log_device_placement=False)
with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=if_chief,
                                       checkpoint_dir=None,
                                       hooks=hooks,
                                       config=sess_config,
                                       stop_grace_period_secs=20) as session:
    step = 0
    while not session.should_stop():
        # FULL_TRACE adds significant per-step overhead, so request it only
        # for the run whose timeline is actually written below (previously
        # every step was traced and 9 of 10 traces were thrown away, which
        # itself degraded the throughput being investigated).
        trace_this_step = (step + 1) % TRACE_EVERY_N_STEPS == 0
        if trace_this_step:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
        else:
            run_options = None
            run_metadata = None

        start_time = time.time()
        _, global_step_value, loss_value = session.run(
            [train_step, global_step, loss],
            options=run_options,
            run_metadata=run_metadata)
        duration = time.time() - start_time

        # The first step (step 0) is not reported.
        if step > 0:
            # %.3f (not %d) so sub-second step times are not shown as 0 s.
            print("After %d training steps (%d global steps), "
                  "loss on training batch is %.3f. "
                  "(cost %.3f s)"
                  % (step, global_step_value, loss_value, duration))
        step += 1

        if trace_this_step:
            tl = timeline.Timeline(run_metadata.step_stats)
            with open('timeline.json', 'w') as f:
                f.write(tl.generate_chrome_trace_format())
有什么建议吗?谢谢!