如何从Tensorflow1.13.1中的检查点继续训练GNN模型

时间:2020-06-20 19:28:01

标签: tensorflow deep-learning

我想加载预训练的模型并继续使用该模型进行训练。 用于保存模型的标准代码段(pretrain.py):

# Loads train and test dataset.
test_data = load_data(parameters_for_test_data)
training_data = load_data(parameters_for_training_data)

# Defines wrapper functions, which can directly be passed to the
# tf.data.Dataset.map function.
def _make_graph_from_static_structure(static_structure):
  """Converts static structure to graph, targets and types."""
  return (graph_model.make_graph_from_static_structure(
      static_structure.positions,
      static_structure.types,
      static_structure.box,
      edge_threshold),
          static_structure.targets,
          static_structure.types)

def _apply_random_rotation(graph, targets, types):
  """Applies random rotations to the graph and forwards targets and types."""
  return graph_model.apply_random_rotation(graph), targets, types

# Defines data-pipeline based on tf.data.Dataset 
placeholders = GlassSimulationData._make(
    tf.placeholder(s.dtype, (None,) + s.shape) for s in training_data[0])
dataset = tf.data.Dataset.from_tensor_slices(placeholders)
dataset = dataset.map(_make_graph_from_static_structure)
dataset = dataset.cache()
dataset = dataset.shuffle(400)
# Augments data.
if augment_data_using_rotations:
  dataset = dataset.map(_apply_random_rotation)
dataset = dataset.repeat()
train_iterator = dataset.make_initializable_iterator()

dataset = tf.data.Dataset.from_tensor_slices(placeholders)
dataset = dataset.map(_make_graph_from_static_structure)
dataset = dataset.cache()
dataset = dataset.repeat()
test_iterator = dataset.make_initializable_iterator()

# Creates tensorflow graph.
dataset_handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    dataset_handle, train_iterator.output_types, train_iterator.output_shapes)
graph, targets, types = iterator.get_next()

model = graph_model.GraphBasedModel(
    n_recurrences, mlp_sizes, mlp_kwargs)
prediction = model(graph)
# Adds sigmoid activation
def sigmoid_activate(x):
    return 1.0 / ( 1.0 + tf.exp(-x) )
prediction = sigmoid_activate(prediction)

# Defines loss and minimization operations.
loss_ops = get_loss_ops(prediction, targets)
minimize_op = get_minimize_op(loss_ops.binary_crossentrop, learning_rate, grad_clip)

best_so_far = -1
train_stats = []
test_stats = []

saver = tf.train.Saver()

with tf.train.SingularMonitoredSession() as session:
  # Initializes train and test iterators with the training and test datasets.
  # The obtained training and test string-handles can be passed to the
  # dataset_handle placeholder to select the dataset.
  train_handle = session.run(train_iterator.string_handle())
  test_handle = session.run(test_iterator.string_handle())
  feed_dict = {p: [x[i] for x in training_data]
               for i, p in enumerate(placeholders)}
  session.run(train_iterator.initializer, feed_dict=feed_dict)
  feed_dict = {p: [x[i] for x in test_data]
               for i, p in enumerate(placeholders)}
  session.run(test_iterator.initializer, feed_dict=feed_dict)

  # Trains model using stochatic gradient descent on the training dataset.
  n_training_steps = len(training_data) * n_epochs
  for i in range(n_training_steps):
    feed_dict = {dataset_handle: train_handle}
    train_loss, _ = session.run((loss_ops, minimize_op), feed_dict=feed_dict)
    train_stats.append(train_loss)

    if (i+1) % measurement_store_interval == 0:
      # Evaluates model on test dataset.
      for _ in range(len(test_data)):
        feed_dict = {dataset_handle: test_handle}
        test_stats.append(session.run(loss_ops, feed_dict=feed_dict))

      # Outputs performance statistics on training and test dataset.
      _log_stats_and_return_mean_correlation('Train', train_stats)
      correlation = _log_stats_and_return_mean_correlation('Test', test_stats)
      train_stats = []
      test_stats = []

      # Updates best model 
      if correlation > best_so_far:
        best_so_far = correlation
        if checkpoint_path:
          saver.save(session.raw_session(), checkpoint_path)

我想continue.py应该如下所示。如果是这样,但是如何从graph0到get_tensor_by_name获取loss_ops,maximum_ops和best_so_far?我用graph0.get_operations()检查了所有可用的张量名称,但不明白我应该选择哪个名称。

# Loads train and test dataset.
test_data = load_data(parameters_for_test_data)
training_data = load_data(parameters_for_training_data)


tf.reset_default_graph()
saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
graph0 = tf.get_default_graph()

placeholders = GlassSimulationData(
        positions=graph0.get_tensor_by_name('Placeholder:0'),
        targets=graph0.get_tensor_by_name('Placeholder_1:0'),
        types=graph0.get_tensor_by_name('Placeholder_2:0'),
        box=graph0.get_tensor_by_name('Placeholder_3:0'))

dataset_handle = graph0.get_tensor_by_name('Placeholder_4:0')
train_initalizer = graph0.get_operation_by_name('MakeIterator')
test_initalizer = graph0.get_operation_by_name('MakeIterator_1')
train_string_handle = graph0.get_tensor_by_name('IteratorToStringHandle:0')
test_string_handle = graph0.get_tensor_by_name('IteratorToStringHandle_1:0')

loss_ops = graph0.get_tensor_by_name(???)  # I don't know the tensor name
minimize_op = graph0.get_tensor_by_name(???)  # I don't know the tensor name

best_so_far = graph0.get_tensor_by_name(???)  # I don't know the tensor name
train_stats = []
test_stats = []

with tf.train.SingularMonitoredSession() as session:
  saver.restore(session, checkpoint_path)
  train_handle = session.run(train_string_handle)
  test_handle = session.run(test_string_handle)

  feed_dict = {p: [x[i] for x in training_data]
               for i, p in enumerate(placeholders)}
  session.run(train_initalizer, feed_dict=feed_dict)
  feed_dict = {p: [x[i] for x in test_data]
               for i, p in enumerate(placeholders)}
  session.run(test_initializer, feed_dict=feed_dict)

  # Trains model on the training dataset.
  n_training_steps = len(training_data) * n_epochs
  for i in range(n_training_steps):
    feed_dict = {dataset_handle: train_handle}
    train_loss, _ = session.run((loss_ops, minimize_op), feed_dict=feed_dict)
    train_stats.append(train_loss)

    if (i+1) % measurement_store_interval == 0:
      # Evaluates model on test dataset.
      for _ in range(len(test_data)):
        feed_dict = {dataset_handle: test_handle}
        test_stats.append(session.run(loss_ops, feed_dict=feed_dict))

      # Outputs performance statistics on training and test dataset.
      _log_stats_and_return_mean_correlation('Train', train_stats)
      correlation = _log_stats_and_return_mean_correlation('Test', test_stats)
      train_stats = []
      test_stats = []

      # Updates best model
      if correlation > best_so_far:
        best_so_far = correlation
        if new_checkpoint_path:
          saver.save(session.raw_session(), new_checkpoint_path)

还是有更好的方法可以在不更改pretrain.py的情况下继续训练? 任何建议表示赞赏。

0 个答案:

没有答案