使用tensorflow.keras.estimator.model_to_estimator时,MNIST的准确率降低了10倍

时间:2019-09-17 11:33:50

标签: tensorflow keras tensorflow-estimator

相同模型,model_to_estimator的准确率比keras模型拟合的准确率低10倍

model_to_estimator:精度= 0.10129999 keras模型拟合:精度= 0.9694

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def dense_model() :
    inputs = keras.Input(shape=(784,), name='input_1')
    x = layers.Dense(64, activation='relu')(inputs)
    x = layers.Dense(32, activation='relu')(x)
    outputs = layers.Dense(10, activation='softmax')(x)

    return keras.Model(inputs = inputs, outputs = outputs, name='mnist_model')

def get_estimator(config) :
    model = dense_model()
    model.summary()
    model.compile(loss='sparse_categorical_crossentropy',
             optimizer=keras.optimizers.RMSprop(),
             metrics=['accuracy'])

    return model, keras.estimator.model_to_estimator(keras_model=model, model_dir='./model/', config=config)

def train_input_fn(features, labels, epochs, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    return dataset.repeat(epochs).batch(batch_size)

def main() :
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    x_train = x_train.reshape(-1, 784).astype('float32')/255
    x_test= x_test.reshape(-1, 784).astype('float32')/255

    config = tf.estimator.RunConfig()
    config.mode_dir = './model/'

    model, estimator = get_estimator(config)
    tf.logging.set_verbosity(tf.logging.INFO)

    estimator.train(input_fn=lambda:train_input_fn({'input_1' : x_train}, y_train, 5, 64)
        , steps=5000)

    estimator.evaluate(input_fn=lambda:train_input_fn({'input_1' : x_test}, y_test, 1, 100))

    ##########

    history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.2)
    test_scores = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', test_scores[0])
    print('Test accuracy:', test_scores[1])

if __name__ == '__main__':

    main()

INFO:tensorflow:全局步骤4689的保存格:精度= 0.10129999 ,global_step = 4689,损失= 0.106062405

INFO:tensorflow:为全局步骤4689保存'checkpoint_path'摘要:./model/model.ckpt-4689

训练48000个样本,验证12000个样本

测试损失:0.10144322782349773 测试准确性: 0.9694

0 个答案:

没有答案