Module 'tensorflow._api.v2.train' has no attribute 'write_graph'

Date: 2020-09-21 17:16:00

Tags: python python-3.x tensorflow deep-learning tensorflow2.0

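I have seen many questions about the module 'tensorflow._api.v2.train', but none of them solves my problem. Any ideas? Is something wrong with TensorFlow 2.0, or should I use another version, such as 1.15?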
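A quick way to see where the 1.x graph utilities used in the code below live in a given install. This is a minimal sketch assuming a TensorFlow 2.x install; it only inspects the namespace and changes nothing. In 2.x, `write_graph` is exposed as `tf.io.write_graph`, while the 1.x-style `write_graph` and `Saver` remain reachable under `tf.compat.v1.train`.

```python
import tensorflow as tf

print("TensorFlow", tf.__version__)

# Names used in the question, looked up in the installed namespace layout.
checks = {
    "tf.train.write_graph": hasattr(tf.train, "write_graph"),
    "tf.train.Saver": hasattr(tf.train, "Saver"),
    "tf.io.write_graph": hasattr(tf.io, "write_graph"),
    "tf.compat.v1.train.write_graph": hasattr(tf.compat.v1.train, "write_graph"),
    "tf.compat.v1.train.Saver": hasattr(tf.compat.v1.train, "Saver"),
}

for name, present in checks.items():
    print(f"{name:32} {'present' if present else 'missing'}")
```

On a 2.x install the first two entries should report missing and the last three present, which is consistent with the error in the title.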

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, LSTM
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam

# timesteps, input_dim, n_classes, X_train and y_train come from data preparation not shown here
model = Sequential()
model.add(LSTM(64, return_sequences=True, recurrent_regularizer=l2(0.0015), input_shape=(timesteps, input_dim)))
# model.add(Dropout(0.5))
# input_shape is only needed on the first layer
model.add(LSTM(64, recurrent_regularizer=l2(0.0015)))


model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))

model.add(Dense(n_classes, activation='softmax'))
model.summary()

model.compile(optimizer=Adam(learning_rate=0.0025), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

model.fit(X_train, y_train, batch_size=32, epochs=1)

model.save('model1.h5')
print("Saved model to disk")

from tensorflow.keras.models import load_model
model=load_model('model1.h5')
model.summary()

from keras import backend as k
from tensorflow.python.tools import freeze_graph, optimize_for_inference_lib
import tensorflow as tf

input_node_name = ['lstm_1_input']
output_node_name = 'output/Softmax'
model_name = 'har_model'

tf.train.write_graph(k.get_session().graph_def, 'models', model_name + '_graph.pbtxt')
saver=tf.train.Saver()
saver.save(k.get_session(), 'models/'+model_name + '.chkp')

freeze_graph.freeze_graph('models/'+model_name + '_graph.pbtxt', None, False, 'models/'+model_name+'.chkp',
                         output_node_name, 'save/restore_all', 'save/Const:0', 'models/frozen_' + model_name + '.pb', True, "")
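For comparison, a minimal TF 2.x-native sketch of the same "freeze to a single .pb" step, which avoids sessions and `Saver` entirely. It assumes TensorFlow 2.x and relies on `convert_variables_to_constants_v2` from the private module `tensorflow.python.framework.convert_to_constants`, which is not a stable public API and may change between releases. The helper `freeze_keras_model`, the tiny Dense stand-in model, and the `models/frozen_demo.pb` path are hypothetical; the question's LSTM model would be passed in place of the stand-in.

```python
import os

import tensorflow as tf
# Private module: not a stable public API (assumption: present in your TF 2.x release).
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


def freeze_keras_model(model, logdir, name):
    """Hypothetical helper: freeze a Keras model into one GraphDef file."""
    # Trace the model's forward pass as a concrete function for its input signature.
    spec = tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
    concrete = tf.function(lambda x: model(x)).get_concrete_function(spec)

    # Replace variable reads with constants holding the current weights.
    frozen = convert_variables_to_constants_v2(concrete)
    graph_def = frozen.graph.as_graph_def()

    # tf.io.write_graph is the TF 2.x location of the old tf.train.write_graph;
    # it creates logdir if needed and returns the written file's path.
    path = tf.io.write_graph(graph_def, logdir, name, as_text=False)

    input_names = [t.name for t in frozen.inputs]
    output_names = [t.name for t in frozen.outputs]
    return path, input_names, output_names, graph_def


# Tiny stand-in model (hypothetical), used only to exercise the helper.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])

path, input_names, output_names, graph_def = freeze_keras_model(
    model, "models", "frozen_demo.pb"
)
print(path, input_names, output_names)
```

The printed input and output tensor names are what a downstream loader would likely reference, instead of hand-written names such as `lstm_1_input` and `output/Softmax`. Whether the LSTM layers from the question freeze cleanly this way may depend on the TensorFlow version and is not verified here.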

0 answers