TF对象检测API默认情况下会捕获所有GPU内存,因此很难确定我可以进一步增加批处理大小的数量。通常,我只是继续增加它,直到出现CUDA OOM错误。
PyTorch在默认情况下不会捕获所有GPU内存,因此很容易看出我剩下的工作百分比,而无需进行反复试验。
是否有更好的方法来确定我缺少的TF对象检测API的批量大小?像allow-growth
的{{1}}标志一样?
答案 0 :(得分:1)
我一直在寻找源代码,但没有发现与此相关的标志。
但是,在https://github.com/tensorflow/models/blob/master/research/object_detection/model_main.py的文件model_main.py
中
您可以找到以下主要功能定义:
def main(unused_argv):
flags.mark_flag_as_required('model_dir')
flags.mark_flag_as_required('pipeline_config_path')
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
train_and_eval_dict = model_lib.create_estimator_and_inputs(
run_config=config,
...
想法是采用类似的方式对其进行修改,例如以下方式:
config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, session_config=config_proto)
因此,添加config_proto
并更改config
,但保持其他所有条件不变。
此外,allow_growth
使程序使用所需的GPU内存。因此,取决于您的GPU,您可能最终会消耗掉所有内存。在这种情况下,您可能要使用
config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
定义要使用的内存部分。
希望这有所帮助。
如果您不想修改文件,似乎应该打开一个问题,因为我看不到任何标志。除非标志
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
'file.')
表示与此相关的内容。但是我不认为是因为model_lib.py
中的内容与火车,评估和推断配置有关,而不与GPU使用情况配置有关。