Don't save a checkpoint for the Estimator's last step

Date: 2017-03-03 22:01:31

Tags: tensorflow

I am using an Estimator and training the model in a loop in order to feed it data. Since every fit call runs for a single step, every step is the last step of a run, and every last step saves a checkpoint. I want to avoid saving a checkpoint on every iteration to improve training speed. I could not find any information on how to do this. Do you have any ideas, suggestions, or solutions?

classifier = Estimator(
    model_fn=cnn_model_fn,
    model_dir="./temp_model_Adam",
    config=tf.contrib.learn.RunConfig(
        save_checkpoints_secs=None,    # disable time-based checkpointing
        save_checkpoints_steps=100,    # checkpoint every 100 steps instead
        save_summary_steps=None
    )
)



# Train the model

for e in range(0, 10):
    numbers = np.arange(10000)
    np.random.shuffle(numbers)
    for step in range(0, 2000):
        # Each fit() call runs exactly one step, so every call ends a
        # training run and writes a checkpoint.
        classifier.fit(
            input_fn=lambda: read_images_for_training_as_batch(step, path, 5, numbers),
            steps=1
        )

1 answer:

Answer 0 (score: 0)

The API has changed a bit since then, but from what I can see you are using the fit (now train) method incorrectly: you should pass steps=2000 and let your input function return an iterator over your dataset. Since a checkpoint is only written at the configured interval and once at the end of a train call, one long call saves far fewer checkpoints than 2000 single-step calls. Today there is tf.estimator.inputs.numpy_input_fn, which can help when you have a small dataset; otherwise you have to use the tf.data.Dataset API.
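For the small-dataset case, a minimal sketch of that approach (the arrays, shapes, and batch size below are hypothetical placeholders, not taken from the question) replaces the per-step loop with a single train call:

import numpy as np
import tensorflow as tf

# Hypothetical in-memory features and labels, only to illustrate the call
x_train = np.random.rand(10000, 28, 28, 1).astype(np.float32)
y_train = np.random.randint(0, 10, size=10000).astype(np.int32)

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': x_train},
    y=y_train,
    batch_size=64,
    num_epochs=None,  # repeat indefinitely; training length comes from steps=
    shuffle=True)

classifier.train(input_fn=train_input_fn, steps=2000)  # one call, not 2000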

The tf.data version looks like this (it loads .wav files):

from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.python.ops import io_ops
# ...
def input_fn(num_epochs, batch_size, shuffle=False, mode='training'):

    def input_fn_bound():
        # get_files/get_classes, desired_samples and num_map_threads are
        # assumed to be defined elsewhere, as in the original answer.
        def _read_file(fn, label):
            return io_ops.read_file(fn), label

        def _decode(data, label):
            pcm = contrib_audio.decode_wav(data,
                                           desired_channels=1,
                                           desired_samples=desired_samples)
            return pcm.audio, label

        filenames = get_files(mode)
        classes = get_classes(mode)
        labels = {'class': np.array(classes)}
        dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

        if shuffle:
            # Shuffle across the whole file list; len(labels) would always be 1
            # here, because labels is a single-key dict.
            dataset = dataset.shuffle(buffer_size=len(filenames))
        dataset = dataset.map(_read_file, num_parallel_calls=num_map_threads)
        dataset = dataset.map(_decode, num_parallel_calls=num_map_threads)
        dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))

        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(2)  # load the next batch while the current one is processed on the GPU
        iterator = dataset.make_one_shot_iterator()  # renamed to avoid shadowing the builtin iter
        features, labels = iterator.get_next()
        return features, labels

    return input_fn_bound

# ....

estimator.train(input_fn=input_fn(
        num_epochs=None,
        batch_size=64,
        shuffle=True,
        mode='training'),
    steps=10000)
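
The checkpoint frequency itself is set through RunConfig, not through the input function. A minimal sketch with the newer tf.estimator API, reusing cnn_model_fn and the model directory from the question (the interval values are arbitrary examples):

import tensorflow as tf

run_config = tf.estimator.RunConfig(
    model_dir="./temp_model_Adam",
    save_checkpoints_steps=1000,  # write a checkpoint every 1000 global steps
    save_summary_steps=500)       # write summaries every 500 steps

estimator = tf.estimator.Estimator(model_fn=cnn_model_fn, config=run_config)

A checkpoint is still written once when train() returns, but with a single long train call that happens once per run instead of once per step.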