Hyperparameter tuning with TensorFlow

Time: 2019-05-22 17:58:04

Tags: tensorflow hyperparameters

I am using the pretrained model from the Tensorflow-for-poets colab to train a model on my own data. I would like to run a hyperparameter search to improve the model, but every example I can find uses Keras, and I don't quite see how to apply those examples to this script. A rough sketch of the kind of loop I have in mind is included after the retraining code below.
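For reference, the examples I keep finding look roughly like the following (a minimal sketch with a dummy model and made-up data, not my actual setup): a Keras model is wrapped with KerasClassifier so that scikit-learn's GridSearchCV can search over its hyperparameters.

import numpy as np
from sklearn.model_selection import GridSearchCV
from tensorflow import keras
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier

def build_model(learning_rate=0.01):
  # Tiny stand-in model; the point is only that build_model exposes the
  # hyperparameters as keyword arguments for GridSearchCV to vary.
  model = keras.Sequential([
      keras.layers.Dense(64, activation='relu', input_shape=(20,)),
      keras.layers.Dense(2, activation='softmax'),
  ])
  model.compile(optimizer=keras.optimizers.Adam(lr=learning_rate),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
  return model

# Made-up data just so the snippet runs end to end.
x, y = np.random.rand(200, 20), np.random.randint(0, 2, 200)

param_grid = {'learning_rate': [0.001, 0.01, 0.1], 'batch_size': [16, 32]}
search = GridSearchCV(KerasClassifier(build_fn=build_model, epochs=3, verbose=0),
                      param_grid, cv=3)
search.fit(x, y)
print(search.best_params_)

This works because Keras exposes a model-building function; the retraining script below builds its graph and session directly, which is where I get stuck.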

The relevant part of the retraining script, which adds the final layer and trains the model:

with tf.Session(graph=graph) as sess:
  # Initialize all weights: for the module to their pretrained values,
  # and for the newly added retraining layer to random initial values.
  init = tf.global_variables_initializer()
  sess.run(init)

  # Set up the image decoding sub-graph.
  jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)

  if do_distort_images:
    # We will be applying distortions, so set up the operations we'll need.
    (distorted_jpeg_data_tensor,
     distorted_image_tensor) = add_input_distortions(
         FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
         FLAGS.random_brightness, module_spec)
  else:
    # We'll make sure we've calculated the 'bottleneck' image summaries and
    # cached them on disk.
    cache_bottlenecks(sess, image_lists, FLAGS.image_dir,
                      FLAGS.bottleneck_dir, jpeg_data_tensor,
                      decoded_image_tensor, resized_image_tensor,
                      bottleneck_tensor, FLAGS.tfhub_module)

  # Create the operations we need to evaluate the accuracy of our new layer.
  evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input)

  # Merge all the summaries and write them out to the summaries_dir
  merged = tf.summary.merge_all()
  train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                       sess.graph)

  validation_writer = tf.summary.FileWriter(
      FLAGS.summaries_dir + '/validation')

  # Create a train saver that is used to restore values into an eval graph
  # when exporting models.
  train_saver = tf.train.Saver()

  # Run the training for as many cycles as requested on the command line.
  for i in range(FLAGS.how_many_training_steps):
    # Get a batch of input bottleneck values, either calculated fresh every
    # time with distortions applied, or from the cache stored on disk.
    if do_distort_images:
      (train_bottlenecks,
       train_ground_truth) = get_random_distorted_bottlenecks(
           sess, image_lists, FLAGS.train_batch_size, 'training',
           FLAGS.image_dir, distorted_jpeg_data_tensor,
           distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
    else:
      (train_bottlenecks,
       train_ground_truth, _) = get_random_cached_bottlenecks(
           sess, image_lists, FLAGS.train_batch_size, 'training',
           FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
           decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
           FLAGS.tfhub_module)
    # Feed the bottlenecks and ground truth into the graph, and run a training
    # step. Capture training summaries for TensorBoard with the `merged` op.
    train_summary, _ = sess.run(
        [merged, train_step],
        feed_dict={bottleneck_input: train_bottlenecks,
                   ground_truth_input: train_ground_truth})
    train_writer.add_summary(train_summary, i)

    # Every so often, print out how well the graph is training.
    is_last_step = (i + 1 == FLAGS.how_many_training_steps)
    if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
      train_accuracy, cross_entropy_value = sess.run(
          [evaluation_step, cross_entropy],
          feed_dict={bottleneck_input: train_bottlenecks,
                     ground_truth_input: train_ground_truth})
      tf.logging.info('%s: Step %d: Train accuracy = %.1f%%' %
                      (datetime.now(), i, train_accuracy * 100))
      tf.logging.info('%s: Step %d: Cross entropy = %f' %
                      (datetime.now(), i, cross_entropy_value))
      # TODO: Make this use an eval graph, to avoid quantization
      # moving averages being updated by the validation set, though in
      # practice this makes a negligible difference.
      validation_bottlenecks, validation_ground_truth, _ = (
          get_random_cached_bottlenecks(
              sess, image_lists, FLAGS.validation_batch_size, 'validation',
              FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
              decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
              FLAGS.tfhub_module))
      # Run a validation step and capture training summaries for TensorBoard
      # with the `merged` op.
      validation_summary, validation_accuracy = sess.run(
          [merged, evaluation_step],
          feed_dict={bottleneck_input: validation_bottlenecks,
                     ground_truth_input: validation_ground_truth})
      validation_writer.add_summary(validation_summary, i)
      tf.logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
                      (datetime.now(), i, validation_accuracy * 100,
                       len(validation_bottlenecks)))

    # Store intermediate results
    intermediate_frequency = FLAGS.intermediate_store_frequency

    if (intermediate_frequency > 0 and (i % intermediate_frequency == 0)
        and i > 0):
      # If we want to do an intermediate save, save a checkpoint of the train
      # graph, to restore into the eval graph.
      train_saver.save(sess, CHECKPOINT_NAME)
      intermediate_file_name = (FLAGS.intermediate_output_graphs_dir +
                                'intermediate_' + str(i) + '.pb')
      tf.logging.info('Save intermediate result to : ' +
                      intermediate_file_name)
      save_graph_to_file(intermediate_file_name, module_spec,
                         class_count)

  # After training is complete, force one last save of the train checkpoint.
  train_saver.save(sess, CHECKPOINT_NAME)

  # We've completed all our training, so run a final test evaluation on
  # some new images we haven't used before.
  run_final_eval(sess, module_spec, class_count, image_lists,
                 jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
                 bottleneck_tensor)
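To be concrete, what I'm hoping to end up with is something like the loop below. This is only a rough sketch: `train_once` is a hypothetical wrapper around the session code above that would rebuild the graph, train, and return the final validation accuracy, and I'm assuming the `--learning_rate` and `--train_batch_size` flags that the retrain script already defines.

import itertools

# Candidate values to try; chosen arbitrarily for illustration.
learning_rates = [0.001, 0.005, 0.01]
batch_sizes = [32, 64, 128]

best_params, best_accuracy = None, -1.0
for lr, batch_size in itertools.product(learning_rates, batch_sizes):
  # Overwrite the parsed command-line flags before each run.
  FLAGS.learning_rate = lr
  FLAGS.train_batch_size = batch_size
  # train_once is hypothetical: it would wrap the graph construction and the
  # training session shown above and return the last validation accuracy.
  validation_accuracy = train_once(FLAGS)
  if validation_accuracy > best_accuracy:
    best_params, best_accuracy = (lr, batch_size), validation_accuracy

print('Best: lr=%f, batch_size=%d, accuracy=%.3f' %
      (best_params[0], best_params[1], best_accuracy))

I'm not sure whether restructuring the script this way is the right approach, or whether there is a standard tool for doing this kind of search with plain TensorFlow graphs and sessions.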

Any help would be greatly appreciated!

0 Answers:

There are no answers yet.