加载预先训练的keras模型,以便在Google云上继续进行训练

时间:2020-01-27 09:07:04

标签: python keras google-cloud-platform google-cloud-ai

我正在尝试加载预先训练的Keras模型,以便在Google云上继续进行训练。只需在

上加载鉴别器和生成器,即可在本地运行
 model = load_model('myPretrainedModel.h5')

但是显然,这在Google Cloud上不起作用,我尝试使用与从Google存储桶中读取训练数据相同的方法,

fil = "gs://mygcbucket/myPretrainedModel.h5"    
f = BytesIO(file_io.read_file_to_string(fil, binary_mode=True))
return np.load(f)

但是,这似乎不适用于加载模型,但在运行作业时出现以下错误。

ValueError:allow_pickle = False时,无法加载包含腌制数据的文件

添加allow_pickle=True会引发另一个错误:

OSError:无法将文件0x7fdf2bb42620>的<_io.BytesIO对象解释为泡菜

然后我尝试了我发现的类似问题的建议,因为我了解它可以暂时从存储桶中本地(相对于作业的运行位置)保存模型,然后使用以下方式加载模型:

fil = "gs://mygcbucket/myPretrainedModel.h5"  
model_file = file_io.FileIO(fil, mode='rb')
file_stream = file_io.FileIO(model_file, mode='r')
temp_model_location = './temp_model.h5'
temp_model_file = open(temp_model_location, 'wb')
temp_model_file.write(file_stream.read())
temp_model_file.close()
file_stream.close()
model = load_model(temp_model_location)
return model

但是,这会引发以下错误:

TypeError:预期的二进制或Unicode字符串,得到tensorflow.python.lib.io.file_io.FileIO对象

我必须承认,我不太确定从存储桶中实际加载经过预先​​训练的keras模型所需的操作以及在Google Cloud的培训工作中的使用情况。任何帮助深表感谢。

1 个答案:

答案 0 :(得分:0)

我建议使用AI Platform Notebooks这样做。使用this method下载经过训练的模型。检查“代码示例”选项卡下的Python代码。将模型放在运行Notebook的VM中后,就可以像在本地进行加载一样加载它。 Here,您有一个使用 tf.keras.models.load_model 的示例。