How do I load a checkpoint with TensorFlow in eager execution mode?

Asked: 2018-04-23 20:53:02

Tags: tensorflow eager

I am using TensorFlow 1.7.0 in eager execution mode. I have the model working, but none of the examples I have found for saving a model work for me.

Here is the code I am using:

import os
import tensorflow as tf
import tensorflow.contrib.eager as tfe

checkpoint_directory = './JokeWords/'
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tfe.Checkpoint(model=model, optimizer=optimizer)  # save as "x"
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
evaluate(model, jokes, 2, 32)
....
checkpoint.save(file_prefix=checkpoint_prefix)

I have trained the model and use evaluate to check the results after loading from a restart. Every time, evaluate gives me a random result, which means the model is not being loaded from the checkpoint but starts with random weights instead.

How do I save the model? Training one of these can take days.

Edit: here is the model:

class EagerRNN(tfe.Network):
  def __init__(self, embedding, hidden_dim, num_layers, keep_ratio):
    super(EagerRNN, self).__init__()
    self.keep_ratio = keep_ratio
    self.cells = self._add_cells([
        tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)
        for _ in range(num_layers)
    ])

    self.backcells = self._add_cells([
        tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)
        for _ in range(num_layers)
    ])
    self.linear = layers.Dense(embedding.vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))
    self.backlinear = layers.Dense(embedding.vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))
    self.attension = layers.Dense(hidden_dim, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))

  def call(self, input_seq, seq_lengths, training):

    lengths = [i[0] for i in seq_lengths]

    nRotations = max(lengths)
    batchSize = input_seq.shape[0]

    input_seq2 = tf.unstack(input_seq, num=int(input_seq.shape[1]), axis=1)
    atten = None

    state = self.cells[0].zero_state(batchSize, tf.float32)

    # Forward pass over the sequence, accumulating attention logits.
    for i in range(0, nRotations):
        for j in range(0, len(self.cells)):
            c = self.cells[j]
            inp = input_seq2[i]
            output, state = c(inp, state)
            #input_seq2[i] = (output)
            if atten is None:
                atten = self.linear(output)
            else:
                atten = atten + self.linear(output)


    # Backward pass over the sequence; note this reuses `state` from the
    # forward pass rather than a fresh zero state.
    for i in range(nRotations - 1, -1, -1):
        for j in range(0, len(self.backcells)):
            c = self.backcells[j]
            inp = input_seq2[i]
            output, state = c(inp, state)
            #input_seq2[i] = (output)
            atten = atten + self.backlinear(output)

    #input_seq = tf.stack(input_seq2[0:nRotations], axis=1)

    atten = self.attension(atten)
    if training:
        input_seq = tf.nn.dropout(input_seq, self.keep_ratio)
    # Returning a list instead of a single tensor so that the line:
    # y = self.rnn(y, ...)[0]
    # in PTBModel.call works for both this RNN and CudnnLSTM (which returns a
    # tuple (output, output_states)).
    return input_seq, state, atten

  def _add_cells(self, cells):
    # "Magic" required for keras.Model classes to track all the variables in
    # a list of Layer objects.
    # TODO(ashankar): Figure out API so user code doesn't have to do this.
    for i, c in enumerate(cells):
      setattr(self, "cell-%d" % i, c)
    return cells

class EagerLSTM_Model(tfe.Network):
  """LSTM for word language modeling.
  Model described in:
  (Zaremba et al.) Recurrent Neural Network Regularization
  http://arxiv.org/abs/1409.2329
  See also:
  https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb
  """

  def __init__(self,
               embedding,
               hidden_dim,
               num_layers,
               dropout_ratio,
               use_cudnn_rnn=True):
    super(EagerLSTM_Model, self).__init__()

    self.keep_ratio = 1 - dropout_ratio
    self.use_cudnn_rnn = use_cudnn_rnn
    self.embedding = embedding

    if self.use_cudnn_rnn:
      self.rnn = cudnn_rnn.CudnnLSTM(
          num_layers, hidden_dim, dropout=dropout_ratio)
    else:
      self.rnn = EagerRNN(embedding, hidden_dim, num_layers, self.keep_ratio)

    self.unrnn = EagerUnRNN(embedding, hidden_dim, num_layers, self.keep_ratio)

  def callRNN(self, input_seq, seq_lengths, training):

    y = self.embedding.callbatchword(input_seq)
    if training:
      y = tf.nn.dropout(y, self.keep_ratio)

    y, state, atten = self.rnn.call(y, seq_lengths, training=training)

    return state, atten

  def callUnRNN(self, state, atten, seq_lengths, training):
    x, state = self.unrnn(state, atten, seq_lengths, training=training)
    #b = tf.reshape(y, self._output_shape)
    #c = self.linear(b)
    return x

1 Answer:

Answer 0 (score: 1):

tfe.Network is not (easily) checkpointable, and it will soon be deprecated. Prefer subclassing tf.keras.Model instead. So if you change class EagerRNN(tfe.Network) to class EagerRNN(tf.keras.Model) and class EagerLSTM_Model(tfe.Network) to class EagerLSTM_Model(tf.keras.Model), then checkpoint.save(file_prefix=checkpoint_prefix) should actually save all the variables, and checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) should restore them.
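
For illustration, here is a minimal sketch of that save/restore cycle with a tf.keras.Model subclass. TinyModel and the AdamOptimizer choice are hypothetical stand-ins to keep the example self-contained, not code from the question; only the checkpoint paths are the asker's:

import os
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

class TinyModel(tf.keras.Model):
  # Hypothetical stand-in for EagerRNN / EagerLSTM_Model; the same
  # checkpointing pattern applies once those subclass tf.keras.Model.
  def __init__(self):
    super(TinyModel, self).__init__()
    self.dense = tf.keras.layers.Dense(4)

  def call(self, x):
    return self.dense(x)

model = TinyModel()
optimizer = tf.train.AdamOptimizer()

checkpoint_directory = './JokeWords/'
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
checkpoint = tfe.Checkpoint(model=model, optimizer=optimizer)

# Object-based checkpoints restore lazily: variables created after this
# call are matched against the checkpoint as they come into existence.
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))

# ... train / evaluate ...

checkpoint.save(file_prefix=checkpoint_prefix)

Two follow-up notes. First, the _add_cells trick in the question stays relevant: as its own comment says, assigning each cell as an attribute is what lets the model track variables held in a plain Python list. Second, if your TensorFlow version supports it, the status object returned by checkpoint.restore has an assert_consumed() method; calling it after the variables exist raises an error when nothing in the checkpoint matched the model, which is a quick way to catch the "random weights" symptom described above.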