当我看到模块'tensorflow._api.v2.train
'有很多问题时,我找不到能解决我问题的东西。这里有什么想法吗? TensorFlow的2.0
版本有什么问题吗?我是否应该使用其他版本,例如1.15
?
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Flatten, LSTM
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
model = Sequential()
model.add(LSTM(64, return_sequences=True, recurrent_regularizer=l2(0.0015), input_shape=(timesteps, input_dim)))
# model.add(Dropout(0.5))
model.add(LSTM(64, recurrent_regularizer=l2(0.0015), input_shape=(timesteps,input_dim)))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(n_classes, activation='softmax'))
model.summary()
model.compile(optimizer=tf.optimizers.Adam(learning_rate = 0.0025), loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])
model.fit(X_train, y_train, batch_size=32, epochs=1)
model.save('model1.h5')
print("Saved model to disk")
from tensorflow.keras.models import load_model
model=load_model('model1.h5')
model.summary()
from keras import backend as k
from tensorflow.python.tools import freeze_graph, optimize_for_inference_lib
import tensorflow as tf
input_node_name = ['lstm_1_input']
output_node_name = 'output/Softmax'
model_name='har_model
tf.train.write_graph(k.get_session().graph_def, 'models', model_name + '_graph.pbtxt')
saver=tf.train.Saver()
saver.save(k.get_session(), 'models/'+model_name + '.chkp')
freeze_graph.freeze_graph('models/'+model_name + '_graph.pbtxt', None, False, 'models/'+model_name+'.chkp',
output_node_name, 'save/restore_all', 'save/Const:0', 'models/frozen_' + model_name + '.pb', True, "")