使用Tensorflow事件探查器时,我收到许多以下形式的警告消息:“节点梯度/ resnet_model / IdentityN_9_grad / cond / Pad_1不兼容的形状:形状(?,11,11,64)和(128,64,11 ,11)在训练过程中不兼容。”但是,培训过程不会崩溃。有人可以解释这些消息的性质吗?
答案 0 :(得分:0)
您的tf.placeholder中有一个未定义数量的补丁(dim为“ None”,并且探查器显示为“?”)。 即使使用tensorflow也可以,但探查器不支持此功能。 将该值设为硬编码(因此对您来说是128),这些警告将不再发生。
请注意,此探查器缝不会保留,在TF 2+版本中可能已禁用。
对于注释“ 我应该为Tensorflow使用什么探查器?”中的补充问题,答案有点复杂,因为您没有使用来说明要在TF脚本中探索的内容。探查器,探查器也包含顾问。 假设您想在TF模型中找到瓶颈,检查设备上的计算时间和内存,那么最简单的工具是创建时间轴json文件,然后在Chrome浏览器中读取它。
Python脚本类似于:
# build summary for logs (to be read with tensorboard):
# ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
log_dir = './logs_' + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
if not os.path.exists(log_dir):
os.makedirs(log_dir)
tf.summary.scalar('loss', self.loss)
tf.summary.scalar('lr', self.lr)
tf.summary.scalar('psnr', self.psnr)
writer = tf.summary.FileWriter(log_dir, self.sess.graph)
merged = tf.summary.merge_all()
clip_all_weights = tf.get_collection("max_norm")
# enable full trace and metadata for tensorboard:
# ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
# request one TF iteration (this is a sample with a denoising CNN)
_, loss, summary = self.sess.run([self.train_op, self.loss, merged],
feed_dict={self.Y_: batch_clean, self.X: batch_noisy, self.lr: lr[epoch], self.psnr: self.psnr_tmp, self.is_training: True},
options=run_options,
run_metadata=run_metadata)
self.sess.run(clip_all_weights)
# add metadata and summary to log file:
# ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
writer.add_run_metadata(run_metadata, 'iter_%06d' % iter_num)
writer.add_summary(summary, iter_num)
# Create the Timeline object from metadata, and write it to a json file.
# Point Chrome browser to "chrome://tracing/" to load this json file.
# ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
tl = timeline.Timeline(run_metadata.step_stats)
ctf = tl.generate_chrome_trace_format(show_memory=True) #show_dataflow=True,
with open('json.timeline/timeline_%i.json' % iter_num, 'w') as f:
f.write(ctf)