tf.data.TFRecordDataset.shard affects the accuracy baseline in TensorFlow

Date: 2018-12-02 05:28:26

Tags: python-2.7 tensorflow tensorboard tensorflow-datasets

I noticed that when I shard my dataset, accuracy_baseline stays constant for the entire training run. However, as soon as I remove the sharding, accuracy_baseline fluctuates.
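For context: as far as I know, the accuracy_baseline metric reported by the binary-classification canned Estimators depends only on the mean of the labels seen during evaluation. It is the accuracy of always predicting the majority class, max(labels_mean, 1 - labels_mean). Below is a minimal pure-Python sketch of that formula. The function name accuracy_baseline is illustrative, not a TensorFlow API, and the sketch assumes 0/1 labels.

```python
def accuracy_baseline(labels):
    """Accuracy of always predicting the majority class for 0/1 labels.

    Mirrors max(labels_mean, 1 - labels_mean); labels are assumed binary.
    """
    if not labels:
        raise ValueError("need at least one label")
    # float() keeps the division correct on Python 2.7 as well.
    labels_mean = float(sum(labels)) / len(labels)
    return max(labels_mean, 1.0 - labels_mean)


# The result depends only on which labels are evaluated, not on the model.
print(accuracy_baseline([1, 0, 0, 0]))  # 0.75
```

Under that assumption, the baseline can only change between evaluations if the set of evaluated labels changes.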

Does anyone have any insight into why sharding keeps the accuracy baseline the same? Below is the function I use.

Thanks

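import tensorflow as tf  # TF 1.x API (make_one_shot_iterator)

def input_fn(filenames, train, batch_size=5, buffer_size=10):
    # Repeat indefinitely while training; make a single pass otherwise.
    epoch = None
    if not train:
        epoch = 1

    # run_config (a tf.estimator.RunConfig) is defined elsewhere in my script.
    if run_config.task_type == "ps":
        worker_num   = None
        worker_index = None
    elif run_config.task_type == "master":
        worker_num   = run_config._num_worker_replicas
        worker_index = 0
    else:
        worker_num   = run_config._num_worker_replicas
        worker_index = run_config.task_id + 1

    d = tf.data.TFRecordDataset(filenames=filenames)
    # shard() requires integer arguments, so skip it when no sharding is set.
    if worker_num is not None:
        d = d.shard(worker_num, worker_index)
    d = d.repeat(epoch)
    d = d.shuffle(buffer_size)
    d = d.map(parse)  # parse is my own record-parsing function
    d = d.batch(batch_size)
    d = d.prefetch(1)

    iterator      = d.make_one_shot_iterator()
    X, label      = iterator.get_next()
    return X, label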
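To make the sharding step concrete, here is a minimal pure-Python model of the documented semantics of tf.data.Dataset.shard(num_shards, index): element i is kept when i % num_shards == index. Because shuffle is applied after shard in the function above, shuffling only reorders records within a shard and never changes which records a task sees. The helpers shard and shard_index are hypothetical names for this sketch. shard_index reproduces the non-"ps" branches of the input_fn above, and 10 records and 3 shards are made-up numbers.

```python
import random


def shard(records, num_shards, index):
    """Model of Dataset.shard: keep element i if i % num_shards == index."""
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return [r for i, r in enumerate(records) if i % num_shards == index]


def shard_index(task_type, task_id):
    """Index chosen by the input_fn above ("ps" tasks are not sharded)."""
    if task_type == "master":
        return 0
    return task_id + 1  # every other task type takes this branch


records = list(range(10))  # stand-in for 10 TFRecords
num_shards = 3             # e.g. one master plus two workers

for task_type, task_id in [("master", 0), ("worker", 0), ("worker", 1)]:
    idx = shard_index(task_type, task_id)
    mine = shard(records, num_shards, idx)
    # Shuffling after sharding permutes the shard but keeps its contents.
    shuffled = random.sample(mine, len(mine))
    print("{} {} -> shard {}: {} (same set after shuffle: {})".format(
        task_type, task_id, idx, mine, sorted(shuffled) == mine))
```

Reading the function as written, any task type other than "ps" and "master" (which would include an "evaluator" task, if one is configured) falls into the else branch and always gets index task_id + 1, so it reads the same fixed subset of records on every pass.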

0 answers