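I'm developing an LSTM for my final project. I've been following TensorFlow's tutorial here: https://www.tensorflow.org/tutorials/sequences/text_generation, for most of it, and especially for the parts on how to save and load the model. However, I get the following error: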
Traceback (most recent call last):
  File "D:\xxx\Documents\Class Coding\Artificial Intelligence\Shelley\Writerbot.py", line 187, in <module>
    restore_progress()
  File "D:\xxx\Documents\Class Coding\Artificial Intelligence\Shelley\Writerbot.py", line 141, in restore_progress
    shelley.load_weights(weights)
  File "C:\Users\xxx\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\keras\engine\network.py", line 1508, in load_weights
    if _is_hdf5_filepath(filepath):
  File "C:\Users\xxx\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\keras\engine\network.py", line 1648, in _is_hdf5_filepath
    return filepath.endswith('.h5') or filepath.endswith('.keras')
AttributeError: 'NoneType' object has no attribute 'endswith'
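The last two frames show what goes wrong mechanically: `load_weights` received `None` instead of a path string, and `None` has no `endswith` method. `tf.train.latest_checkpoint` is documented to return `None` when it finds no checkpoint, so that is the likely source of the `None`. The function below is a hypothetical stand-in that only mirrors the check shown in the traceback, so the failure can be reproduced without TensorFlow.

```python
def is_hdf5_filepath(filepath):
    """Hypothetical stand-in mirroring the check in the traceback (network.py, line 1648)."""
    return filepath.endswith('.h5') or filepath.endswith('.keras')


# tf.train.latest_checkpoint returns None when it finds no checkpoint.
weights = None

try:
    is_hdf5_filepath(weights)
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'endswith'
```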
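Here is my code related to saving, loading, and restoring the weights. As far as I can tell, this is the relevant part, since the rest of the traceback comes from inside Keras: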
def create_shelley(vocab, embedding, numunits, batch):
    """This is what actually creates a neural network."""
    shelley = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab, embedding,
                                  batch_input_shape=[batch, None]),
        lstm(numunits,
             return_sequences=True,
             recurrent_initializer='glorot_uniform',
             stateful=True),
        tf.keras.layers.Dense(vocab)
    ])
    return shelley


def train():
    """We create weight checkpoints as we train our neural network on files fed into it."""
    checkpoints = 'D:\\xxx\\Documents\\Class Coding\\Artificial Intelligence\\Shelley\\trainingcheckpoints'
    prefix = os.path.join(checkpoints, "ckpt_{epoch}")
    callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=prefix,
        save_weights_only=True)
    print(epochsteps)
    history = shelley.fit(botfeed.repeat(), epochs=epochs, steps_per_epoch=epochsteps, callbacks=[callback])


def restore_progress():
    """Load the most recent weight checkpoint."""
    trainingcheckpoints = "D:\\Robin Pegau\\Documents\\Class Coding\\Artificial Intelligence\\Shelley\\trainingcheckpoints\\checkpoint"
    weights = tf.train.latest_checkpoint(trainingcheckpoints)
    shelley = create_shelley(vocab, embed, totalunits, batch=1)
    shelley.load_weights(weights)
    shelley.build(tf.TensorShape([1, None]))


restore_progress()
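A small aside on what `train()` writes: the `filepath` given to `ModelCheckpoint` is a template, and my understanding is that the callback fills `{epoch}` with the 1-based epoch number via `str.format`, so each epoch gets its own prefix such as `ckpt_1`, `ckpt_2`, and so on. The sketch reproduces only that string formatting, using a relative directory name as an assumption for illustration; it does not run the callback.

```python
import os

# Same pattern as in train(); the relative directory is only for illustration.
checkpoints = "trainingcheckpoints"
prefix = os.path.join(checkpoints, "ckpt_{epoch}")

# Assumed behavior: ModelCheckpoint substitutes the 1-based epoch number into {epoch}.
for epoch in range(1, 4):
    print(prefix.format(epoch=epoch))
```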
The folder contains a file named "checkpoint" with no file extension. There are also files named like "ckpt_[x].index" and "ckpt_[x].data-00000-of-00001".
Thanks for any help, everyone.
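For comparison, here is a rough stdlib-only sketch of what "the latest checkpoint" means for this layout: the prefix to load is the shared stem (for example `ckpt_3`) inside the checkpoint directory, without the `.index` or `.data-...` suffix. `latest_prefix` is a hypothetical helper, not a TensorFlow API, and it assumes the `ckpt_<N>` naming used in `train()`. Note also that the documented argument of `tf.train.latest_checkpoint` is the directory the checkpoints were saved in, whereas `restore_progress()` passes the path of the `checkpoint` file itself.

```python
import os
import re


def latest_prefix(checkpoint_dir):
    """Hypothetical helper (not part of TensorFlow).

    Returns os.path.join(checkpoint_dir, "ckpt_<N>") for the highest N that has
    a "ckpt_<N>.index" file, or None if no such file exists.
    """
    pattern = re.compile(r"^ckpt_(\d+)\.index$")
    best = None
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        # Compare numerically so that ckpt_10 beats ckpt_9.
        if match and (best is None or int(match.group(1)) > best):
            best = int(match.group(1))
    if best is None:
        return None
    return os.path.join(checkpoint_dir, "ckpt_{}".format(best))
```

Comparing the epoch numbers as integers rather than as strings matters once training passes nine epochs, since `"ckpt_10"` sorts before `"ckpt_9"` alphabetically.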