我有一个训练循环,我想要注入一个函数,例如:
def example():
input("Press the <ENTER> key to continue...")
这是我的模特
def main(unused_argv):
assert FLAGS.input_file_pattern, "--input_file_pattern is required"
assert FLAGS.train_dir, "--train_dir is required"
model_config = configuration.ModelConfig()
model_config.input_file_pattern = FLAGS.input_file_pattern
model_config.inception_checkpoint_file = FLAGS.inception_checkpoint_file
training_config = configuration.TrainingConfig()
# Create training directory.
train_dir = FLAGS.train_dir
if not tf.gfile.IsDirectory(train_dir):
tf.logging.info("Creating training directory: %s", train_dir)
tf.gfile.MakeDirs(train_dir)
# Build the TensorFlow graph.
g = tf.Graph()
with g.as_default():
# Build the model.
model = show_and_tell_model.ShowAndTellModel(
model_config, mode="train", train_inception=FLAGS.train_inception)
model.build()
# Set up the learning rate.
learning_rate_decay_fn = None
if FLAGS.train_inception:
learning_rate = tf.constant(training_config.train_inception_learning_rate)
else:
learning_rate = tf.constant(training_config.initial_learning_rate)
if training_config.learning_rate_decay_factor > 0:
num_batches_per_epoch = (training_config.num_examples_per_epoch /
model_config.batch_size)
decay_steps = int(num_batches_per_epoch *
training_config.num_epochs_per_decay)
def _learning_rate_decay_fn(learning_rate, global_step):
return tf.train.exponential_decay(
learning_rate,
global_step,
decay_steps=decay_steps,
decay_rate=training_config.learning_rate_decay_factor,
staircase=True)
learning_rate_decay_fn = _learning_rate_decay_fn
# Set up the training ops.
train_op = tf.contrib.layers.optimize_loss(
loss=model.total_loss,
global_step=model.global_step,
learning_rate=learning_rate,
optimizer=training_config.optimizer,
clip_gradients=training_config.clip_gradients,
learning_rate_decay_fn=learning_rate_decay_fn)
# Set up the Saver for saving and restoring model checkpoints.
saver = tf.train.Saver(max_to_keep=training_config.max_checkpoints_to_keep)
# Run training.
tf.contrib.slim.learning.train(
train_op,
train_dir,
log_every_n_steps=FLAGS.log_every_n_steps,
graph=g,
global_step=model.global_step,
number_of_steps=FLAGS.number_of_steps,
init_fn=model.init_fn,
saver=saver)
if __name__ == "__main__":
tf.app.run()
我尝试添加一个名为tester的函数:
def tester(x):
print('test')
return
就在#run训练之前,我把
inp = tf.placeholder(tf.float32)
tf.py_func(tester, [inp], tf.float32)