使用TensorFlow对象检测API确定最大批处理大小

时间:2019-04-05 06:01:29

标签: tensorflow object-detection-api batchsize

TF对象检测API默认情况下会捕获所有GPU内存,因此很难确定我可以进一步增加批处理大小的数量。通常,我只是继续增加它,直到出现CUDA OOM错误。

PyTorch在默认情况下不会捕获所有GPU内存,因此很容易看出我剩下的工作百分比,而无需进行反复试验。

是否有更好的方法来确定我缺少的TF对象检测API的批量大小?像allow-growth的{​​{1}}标志一样?

1 个答案:

答案 0 :(得分:1)

我一直在寻找源代码,但没有发现与此相关的标志。

但是,在https://github.com/tensorflow/models/blob/master/research/object_detection/model_main.py的文件model_main.py中 您可以找到以下主要功能定义:

def main(unused_argv):
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
...

想法是采用类似的方式对其进行修改,例如以下方式:

config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True

config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, session_config=config_proto)

因此,添加config_proto并更改config,但保持其他所有条件不变。

此外,allow_growth使程序使用所需的GPU内存。因此,取决于您的GPU,您可能最终会消耗掉所有内存。在这种情况下,您可能要使用

config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9

定义要使用的内存部分。

希望这有所帮助。

如果您不想修改文件,似乎应该打开一个问题,因为我看不到任何标志。除非标志

flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')

表示与此相关的内容。但是我不认为是因为model_lib.py中的内容与火车,评估和推断配置有关,而不与GPU使用情况配置有关。