我很困惑:tf.estimator 和 tf.keras 模型的训练速度存在明显差异。例如,在 MNIST 数据集上训练时,使用 tf.estimator 的训练速度约为 tf.keras 的 1.5 倍。
在我自己的自定义数据(1080×1920×3 的图像)上,差距更大,约为 2.5 倍。TensorFlow 文档中似乎没有任何相关说明。下面依次给出 tf.estimator 版本和 tf.keras 版本的 main 函数、两者共用的数据与模型函数,以及两次运行的训练日志。
def main(_):
    """Train the Keras CNN through the Estimator API (tf.estimator path).

    The compiled Keras model is wrapped with ``model_to_estimator`` and
    trained for 100 * 20 = 2000 steps. That matches the 20 epochs x 100
    steps of the ``model.fit`` variant.

    NOTE(review): the Estimator training log shows only the loss, not the
    compiled 'accuracy' metric. Presumably that metric is only evaluated in
    ``Estimator.evaluate``, so training skips per-step metric work. This is
    a plausible part of the speed gap; confirm by profiling.
    """
    images, labels, shape = get_input_datasets()
    keras_model = get_model(shape)
    keras_model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.001),
        metrics=['accuracy'],
    )

    # Checkpoints and event files are written to the current directory.
    config = tf.estimator.RunConfig(model_dir="./")
    estimator = tf.keras.estimator.model_to_estimator(keras_model, config=config)

    total_steps = 100 * 20
    estimator.train(input_fn=lambda: input_fn(images, labels), steps=total_steps)
def main(_):
    """Train the Keras CNN directly with ``model.fit`` (tf.keras path).

    Runs 20 epochs of 100 steps each, i.e. 2000 optimizer steps, on the
    infinite batched dataset returned by ``input_fn``.

    NOTE(review): ``model.fit`` reports the compiled 'accuracy' metric and
    refreshes a progress bar on every step ("acc" appears in each epoch
    line of the log). That per-step overhead is a plausible contributor to
    the gap versus Estimator training. Confirm by compiling without
    ``metrics`` and/or passing ``verbose=2``.
    """
    images, labels, shape = get_input_datasets()
    keras_model = get_model(shape)
    sgd = tf.train.GradientDescentOptimizer(learning_rate=0.001)
    keras_model.compile(optimizer=sgd,
                        loss=tf.keras.losses.categorical_crossentropy,
                        metrics=['accuracy'])

    train_data = input_fn(images, labels)
    keras_model.fit(x=train_data, steps_per_epoch=100, epochs=20)
def get_input_datasets(use_bfloat16=False, num_classes=10):
    """Load the MNIST training split and prepare it for the CNN.

    Args:
        use_bfloat16: Accepted for signature compatibility but not used here.
            Any bfloat16 cast happens later, in ``input_fn``.
        num_classes: Number of label classes used for one-hot encoding
            (defaults to 10, the previous hard-coded value).

    Returns:
        A tuple ``(x_train, y_train, input_shape)``:
        - ``x_train``: float32 images divided by 255, with a single channel
          axis placed according to ``tf.keras.backend.image_data_format()``.
          Pixels are presumably uint8 0-255, so values land in [0, 1].
        - ``y_train``: the one-hot encoded label matrix.
        - ``input_shape``: the per-sample shape for the first model layer.
    """
    img_rows, img_cols = 28, 28

    # Only the training split is returned. The test split used to be
    # reshaped, cast and normalised here and then thrown away; that was a
    # wasted full-size array copy, so it is now discarded immediately.
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()

    if tf.keras.backend.image_data_format() == 'channels_first':
        input_shape = (1, img_rows, img_cols)
    else:
        input_shape = (img_rows, img_cols, 1)
    # Derive the array layout from input_shape so the two cannot disagree.
    x_train = x_train.reshape((x_train.shape[0],) + input_shape)

    x_train = x_train.astype('float32')
    x_train /= 255  # in place: avoids allocating another float copy

    # convert class vectors to binary class matrices
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    return x_train, y_train, input_shape
def input_fn(x_train, y_train, use_bfloat16=False, batch_size=256,
             prefetch_batches=1):
    """Build an infinite, batched ``tf.data`` pipeline over in-memory arrays.

    Args:
        x_train: Feature array sliced along its first axis (presumably the
            float32 images from ``get_input_datasets``).
        y_train: Label array aligned with ``x_train`` along the first axis.
        use_bfloat16: If true, features are cast to ``tf.bfloat16``;
            otherwise to ``tf.float32``. Labels are not cast.
        batch_size: Examples per batch (default 256, the previous value).
            Incomplete trailing batches are dropped.
        prefetch_batches: How many batches to prepare ahead of the
            consumer. This lets input work overlap the training step.

    Returns:
        A ``tf.data.Dataset`` of ``(features, labels)`` batches that repeats
        forever. Examples are not shuffled.

    NOTE(review): ``from_tensor_slices`` on NumPy arrays presumably embeds
    them in the graph as constants. For large image datasets this can hit
    the 2 GB GraphDef limit and slow graph construction, so a file- or
    generator-based pipeline may be needed there. Confirm against the
    tf.data guide.
    """
    cast_dtype = tf.bfloat16 if use_bfloat16 else tf.float32
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset = dataset.repeat()
    # drop_remainder=True keeps a static batch dimension.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # Casting whole batches runs one op per batch instead of one per example.
    # The values produced are identical to casting before batching.
    dataset = dataset.map(lambda x, y: (tf.cast(x, cast_dtype), y))
    dataset = dataset.prefetch(prefetch_batches)
    return dataset
def get_model(input_shape):
    """Build the (uncompiled) CNN classifier used for MNIST digits.

    Architecture:
    - two stages of [5x5 conv ("same" padding, ReLU) -> 2x2/stride-2 max
      pool], with 32 and then 64 filters;
    - flatten, a 1024-unit ReLU dense layer, 40% dropout;
    - a softmax output layer with ``NUM_CLASSES`` units.

    Args:
        input_shape: Per-sample input shape handed to the first conv layer.

    Returns:
        A ``tf.keras.models.Sequential`` model (not compiled).

    NOTE(review): for 1080x1920x3 inputs the two pools leave 270x480x64 ≈
    8.3M features. The Flatten -> Dense(1024) kernel then has roughly 8.5e9
    weights, which likely dominates any API speed comparison. Verify before
    reusing this architecture on large images.
    """
    layers = tf.keras.layers
    stack = [
        layers.Conv2D(32, kernel_size=(5, 5),
                      activation='relu',
                      input_shape=input_shape,
                      padding="same"),
        layers.MaxPooling2D(pool_size=(2, 2), strides=2),
        layers.Conv2D(64, (5, 5), activation='relu', padding="same"),
        layers.MaxPooling2D(pool_size=(2, 2), strides=2),
        layers.Flatten(),
        layers.Dense(1024, activation='relu'),
        layers.Dropout(0.4),
        layers.Dense(NUM_CLASSES, activation='softmax'),
    ]
    # Sequential(list) adds the layers in order, like repeated .add() calls.
    return tf.keras.models.Sequential(stack)
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into ./model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into ./model.ckpt.
INFO:tensorflow:loss = 2.301029, step = 0
INFO:tensorflow:loss = 2.301029, step = 0
INFO:tensorflow:global_step/sec: 70.8963
INFO:tensorflow:global_step/sec: 70.8963
INFO:tensorflow:loss = 2.2951665, step = 100 (1.411 sec)
INFO:tensorflow:loss = 2.2951665, step = 100 (1.411 sec)
INFO:tensorflow:global_step/sec: 72.8193
INFO:tensorflow:global_step/sec: 72.8193
INFO:tensorflow:loss = 2.2687101, step = 200 (1.373 sec)
INFO:tensorflow:loss = 2.2687101, step = 200 (1.373 sec)
INFO:tensorflow:global_step/sec: 72.9338
INFO:tensorflow:global_step/sec: 72.9338
INFO:tensorflow:loss = 2.2557936, step = 300 (1.371 sec)
INFO:tensorflow:loss = 2.2557936, step = 300 (1.371 sec)
tf.estimator 平均每步耗时:0.0136 秒
Epoch 1/20
100/100 [==============================] - 3s 28ms/step - loss: 2.3003 - acc: 0.1002
Epoch 2/20
100/100 [==============================] - 2s 20ms/step - loss: 2.2842 - acc: 0.1329
Epoch 3/20
100/100 [==============================] - 2s 20ms/step - loss: 2.2691 - acc: 0.1694
Epoch 4/20
100/100 [==============================] - 2s 19ms/step - loss: 2.2536 - acc: 0.2091
Epoch 5/20
100/100 [==============================] - 2s 19ms/step - loss: 2.2325 - acc: 0.2601
Epoch 6/20
100/100 [==============================] - 2s 20ms/step - loss: 2.2122 - acc: 0.3142
Epoch 7/20
100/100 [==============================] - 2s 20ms/step - loss: 2.1829 - acc: 0.3760
Epoch 8/20
100/100 [==============================] - 2s 20ms/step - loss: 2.1488 - acc: 0.4144
tf.keras 平均每步耗时:0.02 秒