我正在使用 TensorFlow for Poets 代码实验室（codelab）中的预训练模型，用我自己的数据训练模型。我想通过超参数搜索来改进模型，但我能找到的所有示例都使用 Keras，而我不太清楚如何把这种做法应用到下面的脚本中。
下面是重新训练（retrain）脚本中的相关部分，它负责添加最后一层并训练模型：
with tf.Session(graph=graph) as sess:
    # Initialize all weights: for the module to their pretrained values,
    # and for the newly added retraining layer to random initial values.
    init = tf.global_variables_initializer()
    sess.run(init)

    # Set up the image decoding sub-graph.
    jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)

    if do_distort_images:
        # We will be applying distortions, so set up the operations we'll need.
        (distorted_jpeg_data_tensor,
         distorted_image_tensor) = add_input_distortions(
             FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
             FLAGS.random_brightness, module_spec)
    else:
        # We'll make sure we've calculated the 'bottleneck' image summaries and
        # cached them on disk.
        cache_bottlenecks(sess, image_lists, FLAGS.image_dir,
                          FLAGS.bottleneck_dir, jpeg_data_tensor,
                          decoded_image_tensor, resized_image_tensor,
                          bottleneck_tensor, FLAGS.tfhub_module)

    # Create the operations we need to evaluate the accuracy of our new layer.
    evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input)

    # Merge all the summaries and write them out to the summaries_dir.
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(
        FLAGS.summaries_dir + '/validation')

    # Create a train saver that is used to restore values into an eval graph
    # when exporting models.
    train_saver = tf.train.Saver()

    # Loop-invariant: read the flag once instead of on every training step.
    intermediate_frequency = FLAGS.intermediate_store_frequency

    try:
        # Run the training for as many cycles as requested on the command line.
        for i in range(FLAGS.how_many_training_steps):
            # Get a batch of input bottleneck values, either calculated fresh
            # every time with distortions applied, or from the cache on disk.
            if do_distort_images:
                (train_bottlenecks,
                 train_ground_truth) = get_random_distorted_bottlenecks(
                     sess, image_lists, FLAGS.train_batch_size, 'training',
                     FLAGS.image_dir, distorted_jpeg_data_tensor,
                     distorted_image_tensor, resized_image_tensor,
                     bottleneck_tensor)
            else:
                (train_bottlenecks,
                 train_ground_truth, _) = get_random_cached_bottlenecks(
                     sess, image_lists, FLAGS.train_batch_size, 'training',
                     FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
                     decoded_image_tensor, resized_image_tensor,
                     bottleneck_tensor, FLAGS.tfhub_module)

            # Feed the bottlenecks and ground truth into the graph, and run a
            # training step. Capture training summaries for TensorBoard with
            # the `merged` op.
            train_summary, _ = sess.run(
                [merged, train_step],
                feed_dict={bottleneck_input: train_bottlenecks,
                           ground_truth_input: train_ground_truth})
            train_writer.add_summary(train_summary, i)

            # Every eval_step_interval steps (and on the last step), report
            # how well the graph is training.
            is_last_step = (i + 1 == FLAGS.how_many_training_steps)
            if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
                train_accuracy, cross_entropy_value = sess.run(
                    [evaluation_step, cross_entropy],
                    feed_dict={bottleneck_input: train_bottlenecks,
                               ground_truth_input: train_ground_truth})
                tf.logging.info('%s: Step %d: Train accuracy = %.1f%%' %
                                (datetime.now(), i, train_accuracy * 100))
                tf.logging.info('%s: Step %d: Cross entropy = %f' %
                                (datetime.now(), i, cross_entropy_value))
                # TODO: Make this use an eval graph, to avoid quantization
                # moving averages being updated by the validation set, though
                # in practice this makes a negligible difference.
                validation_bottlenecks, validation_ground_truth, _ = (
                    get_random_cached_bottlenecks(
                        sess, image_lists, FLAGS.validation_batch_size,
                        'validation', FLAGS.bottleneck_dir, FLAGS.image_dir,
                        jpeg_data_tensor, decoded_image_tensor,
                        resized_image_tensor, bottleneck_tensor,
                        FLAGS.tfhub_module))
                # Run a validation step and capture validation summaries for
                # TensorBoard with the `merged` op.
                validation_summary, validation_accuracy = sess.run(
                    [merged, evaluation_step],
                    feed_dict={bottleneck_input: validation_bottlenecks,
                               ground_truth_input: validation_ground_truth})
                validation_writer.add_summary(validation_summary, i)
                tf.logging.info(
                    '%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
                    (datetime.now(), i, validation_accuracy * 100,
                     len(validation_bottlenecks)))

            # Store intermediate results every `intermediate_frequency` steps
            # (never at step 0; disabled when the frequency is <= 0).
            if (intermediate_frequency > 0
                    and i % intermediate_frequency == 0
                    and i > 0):
                # Save a checkpoint of the train graph so it can be restored
                # into the eval graph when exporting.
                train_saver.save(sess, CHECKPOINT_NAME)
                intermediate_file_name = (
                    FLAGS.intermediate_output_graphs_dir +
                    'intermediate_' + str(i) + '.pb')
                tf.logging.info('Save intermediate result to : ' +
                                intermediate_file_name)
                save_graph_to_file(intermediate_file_name, module_spec,
                                   class_count)
    finally:
        # Close both summary writers so buffered events are flushed to disk,
        # even if a training step raises. Previously they were never closed.
        train_writer.close()
        validation_writer.close()

    # After training is complete, force one last save of the train checkpoint.
    train_saver.save(sess, CHECKPOINT_NAME)

    # We've completed all our training, so run a final test evaluation on
    # some new images we haven't used before.
    run_final_eval(sess, module_spec, class_count, image_lists,
                   jpeg_data_tensor, decoded_image_tensor,
                   resized_image_tensor, bottleneck_tensor)
任何帮助都将不胜感激！