我在急切执行模式下使用tensorflow 1.7.0。我有模型工作,但我找到的用于保存模型的示例都没有。
这是我正在使用的代码:
checkpoint_directory ='./JokeWords/'
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
checkpoint = tfe.Checkpoint(model=model,optimizer=optimizer) # save as "x"
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
evaluate(model,jokes,2,32)
....
checkpoint.save(file_prefix=checkpoint_prefix)
我已经训练了模型并使用evaluate来检查从重新启动加载时的结果。每次我从evaluate获得一个随机结果时,意味着模型没有从数据中加载,而是只有随机权重。
如何保存模型?培训其中一个可能需要几天时间。
编辑。这是模型:
class EagerRNN(tfe.Network):
def __init__(self,embedding, hidden_dim, num_layers, keep_ratio):
super(EagerRNN, self).__init__()
self.keep_ratio = keep_ratio
self.cells = self._add_cells([
tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)
for _ in range(num_layers)
])
self.backcells = self._add_cells([
tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)
for _ in range(num_layers)
])
self.linear = layers.Dense(embedding. vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))
self.backlinear = layers.Dense(embedding. vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))
self.attension = layers.Dense(hidden_dim, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))
def call(self, input_seq,seq_lengths, training):
lengths=[i[0] for i in seq_lengths]
nRotations=max(lengths)
batchSize=input_seq.shape[0]
input_seq2 = tf.unstack(input_seq, num=int(input_seq.shape[1]), axis=1)
atten = None
state = self.cells[0].zero_state(batchSize, tf.float32)
for i in range(0,nRotations):
for j in range(0,len(self.cells)):
c=self.cells[j]
inp=input_seq2[i]
output, state = c(inp, state)
#input_seq2[i]=(output)
if atten==None:
atten =self.linear(output)
else:
atten=atten+self.linear(output)
for i in range(nRotations-1,-1,-1):
for j in range(0,len(self.backcells)):
c=self.backcells[j]
inp=input_seq2[i]
output, state = c(inp, state)
#input_seq2[i]=(output)
atten=atten+self.backlinear(output)
#input_seq = tf.stack(input_seq2[0:nRotations], axis=1)
atten=self.attension(atten)
if training:
input_seq = tf.nn.dropout(input_seq, self.keep_ratio)
# Returning a list instead of a single tensor so that the line:
# y = self.rnn(y, ...)[0]
# in PTBModel.call works for both this RNN and CudnnLSTM (which returns a
# tuple (output, output_states).
return input_seq,state,atten
def _add_cells(self, cells):
# "Magic" required for keras.Model classes to track all the variables in
# a list of Layer objects.
# TODO(ashankar): Figure out API so user code doesn't have to do this.
for i, c in enumerate(cells):
setattr(self, "cell-%d" % i, c)
return cells
class EagerLSTM_Model(tfe.Network):
"""LSTM for word language modeling.
Model described in:
(Zaremba, et. al.) Recurrent Neural Network Regularization
http://arxiv.org/abs/1409.2329
See also:
https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb
"""
def __init__(self,
embedding,
hidden_dim,
num_layers,
dropout_ratio,
use_cudnn_rnn=True):
super(EagerLSTM_Model, self).__init__()
self.keep_ratio = 1 - dropout_ratio
self.use_cudnn_rnn = use_cudnn_rnn
self.embedding = embedding
if self.use_cudnn_rnn:
self.rnn = cudnn_rnn.CudnnLSTM(
num_layers, hidden_dim, dropout=dropout_ratio)
else:
self.rnn = EagerRNN(embedding,hidden_dim, num_layers, self.keep_ratio)
self.unrnn = EagerUnRNN(embedding,hidden_dim, num_layers, self.keep_ratio)
def callRNN(self, input_seq,seq_lengths, training):
y = self.embedding.callbatchword(input_seq)
if training:
y = tf.nn.dropout(y, self.keep_ratio)
y,state,atten = self.rnn.call(y,seq_lengths, training=training)
return state,atten
def callUnRNN (self,state,atten,seq_lengths, training ):
x,state = self.unrnn(state,atten,seq_lengths,training=training)
#b=tf.reshape(y, self._output_shape)
#c=self.linear(b)
return x
答案 0 :(得分:1)
tfe.Network
不是(轻松)Checkpointable,很快就会被弃用。更喜欢子类tf.Keras.Model
。因此,如果您将class EagerRNN(tfe.Network)
更改为class EagerRNN(tf.keras.Model)
而将class EagerLSTM_Model(tfe.Network)
更改为class EagerLSTM_Model(tf.keras.Model)
,则checkpoint.save(file_prefix=checkpoint_prefix)
应该实际保存所有变量,checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
应该将其恢复。< / p>