我想加载预训练的模型并继续使用该模型进行训练。 用于保存模型的标准代码段(pretrain.py):
# Loads train and test dataset.
test_data = load_data(parameters_for_test_data)
training_data = load_data(parameters_for_training_data)
# Defines wrapper functions, which can directly be passed to the
# tf.data.Dataset.map function.
def _make_graph_from_static_structure(static_structure):
"""Converts static structure to graph, targets and types."""
return (graph_model.make_graph_from_static_structure(
static_structure.positions,
static_structure.types,
static_structure.box,
edge_threshold),
static_structure.targets,
static_structure.types)
def _apply_random_rotation(graph, targets, types):
"""Applies random rotations to the graph and forwards targets and types."""
return graph_model.apply_random_rotation(graph), targets, types
# Defines data-pipeline based on tf.data.Dataset
placeholders = GlassSimulationData._make(
tf.placeholder(s.dtype, (None,) + s.shape) for s in training_data[0])
dataset = tf.data.Dataset.from_tensor_slices(placeholders)
dataset = dataset.map(_make_graph_from_static_structure)
dataset = dataset.cache()
dataset = dataset.shuffle(400)
# Augments data.
if augment_data_using_rotations:
dataset = dataset.map(_apply_random_rotation)
dataset = dataset.repeat()
train_iterator = dataset.make_initializable_iterator()
dataset = tf.data.Dataset.from_tensor_slices(placeholders)
dataset = dataset.map(_make_graph_from_static_structure)
dataset = dataset.cache()
dataset = dataset.repeat()
test_iterator = dataset.make_initializable_iterator()
# Creates tensorflow graph.
dataset_handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
dataset_handle, train_iterator.output_types, train_iterator.output_shapes)
graph, targets, types = iterator.get_next()
model = graph_model.GraphBasedModel(
n_recurrences, mlp_sizes, mlp_kwargs)
prediction = model(graph)
# Adds sigmoid activation
def sigmoid_activate(x):
return 1.0 / ( 1.0 + tf.exp(-x) )
prediction = sigmoid_activate(prediction)
# Defines loss and minimization operations.
loss_ops = get_loss_ops(prediction, targets)
minimize_op = get_minimize_op(loss_ops.binary_crossentrop, learning_rate, grad_clip)
best_so_far = -1
train_stats = []
test_stats = []
saver = tf.train.Saver()
with tf.train.SingularMonitoredSession() as session:
# Initializes train and test iterators with the training and test datasets.
# The obtained training and test string-handles can be passed to the
# dataset_handle placeholder to select the dataset.
train_handle = session.run(train_iterator.string_handle())
test_handle = session.run(test_iterator.string_handle())
feed_dict = {p: [x[i] for x in training_data]
for i, p in enumerate(placeholders)}
session.run(train_iterator.initializer, feed_dict=feed_dict)
feed_dict = {p: [x[i] for x in test_data]
for i, p in enumerate(placeholders)}
session.run(test_iterator.initializer, feed_dict=feed_dict)
# Trains model using stochatic gradient descent on the training dataset.
n_training_steps = len(training_data) * n_epochs
for i in range(n_training_steps):
feed_dict = {dataset_handle: train_handle}
train_loss, _ = session.run((loss_ops, minimize_op), feed_dict=feed_dict)
train_stats.append(train_loss)
if (i+1) % measurement_store_interval == 0:
# Evaluates model on test dataset.
for _ in range(len(test_data)):
feed_dict = {dataset_handle: test_handle}
test_stats.append(session.run(loss_ops, feed_dict=feed_dict))
# Outputs performance statistics on training and test dataset.
_log_stats_and_return_mean_correlation('Train', train_stats)
correlation = _log_stats_and_return_mean_correlation('Test', test_stats)
train_stats = []
test_stats = []
# Updates best model
if correlation > best_so_far:
best_so_far = correlation
if checkpoint_path:
saver.save(session.raw_session(), checkpoint_path)
我想continue.py应该如下所示。如果是这样,但是如何从graph0到get_tensor_by_name获取loss_ops,maximum_ops和best_so_far?我用graph0.get_operations()
检查了所有可用的张量名称,但不明白我应该选择哪个名称。
# Loads train and test dataset.
test_data = load_data(parameters_for_test_data)
training_data = load_data(parameters_for_training_data)
tf.reset_default_graph()
saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
graph0 = tf.get_default_graph()
placeholders = GlassSimulationData(
positions=graph0.get_tensor_by_name('Placeholder:0'),
targets=graph0.get_tensor_by_name('Placeholder_1:0'),
types=graph0.get_tensor_by_name('Placeholder_2:0'),
box=graph0.get_tensor_by_name('Placeholder_3:0'))
dataset_handle = graph0.get_tensor_by_name('Placeholder_4:0')
train_initalizer = graph0.get_operation_by_name('MakeIterator')
test_initalizer = graph0.get_operation_by_name('MakeIterator_1')
train_string_handle = graph0.get_tensor_by_name('IteratorToStringHandle:0')
test_string_handle = graph0.get_tensor_by_name('IteratorToStringHandle_1:0')
loss_ops = graph0.get_tensor_by_name(???) # I don't know the tensor name
minimize_op = graph0.get_tensor_by_name(???) # I don't know the tensor name
best_so_far = graph0.get_tensor_by_name(???) # I don't know the tensor name
train_stats = []
test_stats = []
with tf.train.SingularMonitoredSession() as session:
saver.restore(session, checkpoint_path)
train_handle = session.run(train_string_handle)
test_handle = session.run(test_string_handle)
feed_dict = {p: [x[i] for x in training_data]
for i, p in enumerate(placeholders)}
session.run(train_initalizer, feed_dict=feed_dict)
feed_dict = {p: [x[i] for x in test_data]
for i, p in enumerate(placeholders)}
session.run(test_initializer, feed_dict=feed_dict)
# Trains model on the training dataset.
n_training_steps = len(training_data) * n_epochs
for i in range(n_training_steps):
feed_dict = {dataset_handle: train_handle}
train_loss, _ = session.run((loss_ops, minimize_op), feed_dict=feed_dict)
train_stats.append(train_loss)
if (i+1) % measurement_store_interval == 0:
# Evaluates model on test dataset.
for _ in range(len(test_data)):
feed_dict = {dataset_handle: test_handle}
test_stats.append(session.run(loss_ops, feed_dict=feed_dict))
# Outputs performance statistics on training and test dataset.
_log_stats_and_return_mean_correlation('Train', train_stats)
correlation = _log_stats_and_return_mean_correlation('Test', test_stats)
train_stats = []
test_stats = []
# Updates best model
if correlation > best_so_far:
best_so_far = correlation
if new_checkpoint_path:
saver.save(session.raw_session(), new_checkpoint_path)
还是有更好的方法可以在不更改pretrain.py的情况下继续训练? 任何建议表示赞赏。