tf.train.batch()和tf.data.Datasets.from_tensor_slices.batch()有什么区别?

时间:2018-11-24 02:06:39

标签: python tensorflow

最近,我尝试使用ENAS代码在自己的数据集上自动设计网络。

代码首先在main.py中将数据加载为numpy,然后例如将数据分配给model.py

# main.py
images, labels = read_data(path)

然后在model.py中按如下所示初始化self.x_trainself.y_train

# model.py
class Model(object):
    ...
    with tf.device("/cpu:0"):
    # training data
    self.num_train_examples = np.shape(images["train"])[0]
    self.num_train_batches = (
        self.num_train_examples + self.batch_size - 1) // self.batch_size

    x_train, y_train = tf.train.shuffle_batch(
        [images["train"], labels["train"]], # images['train'] and labels['train'] are both numpy。array
        batch_size=self.batch_size,
        capacity=50000,
        enqueue_many=True,
        num_threads=16,
        allow_smaller_final_batch=True,
    )

然后在main.py中,运行图的部分如下:

# main.py
with tf.train.SingularMonitoredSession(
        config=config, hooks=hooks, checkpoint_dir=FLAGS.output_dir) as sess:
    start_time = time.time()
    while True:
        #####################################
        ######  calculate child ops  ########
        #####################################

        run_ops = [
            child_ops["loss"],
            child_ops["lr"],
            child_ops["grad_norm"],
            child_ops["train_acc"],
            child_ops["train_op"],
        ]
        loss, lr, gn, tr_acc, _ = sess.run(run_ops)
        global_step = sess.run(child_ops["global_step"])
        print(sess.run(child_ops['y_train']))
        if FLAGS.child_sync_replicas:
            actual_step = global_step * FLAGS.num_aggregate
        else:
            actual_step = global_step
        epoch = actual_step // ops["num_train_batches"] # ops["num_train_batches"] 
        print('Epoch:{}, step:{}'.format(epoch, actual_step))
        curr_time = time.time()

让我感到困惑的是,代码没有定义操作,例如self.x_train_next=self.x_train.get_next()tf.train.Coordinator()来加载任何.py文件中的下一个iter数据。

以下是我的问题:

1。tf.train.shuffle_batch是否会自动加载下一批?

2。tf.train.batch()tf.data.Datasets.from_tensor_slices.batch()有什么区别?

3。原始代码使用CIFAR10,当我尝试使用自己的数据集时,图像大小只能设置为小于160 * 160,否则将提高ValueError: GraphDef cannot be larger than 2GB。我曾尝试使用占位符或TFRecord加载数据,但是我不知道何时加载下一个批处理数据,所以我不知道如何更改代码。那么有什么建议加载数据吗?

非常感谢!

0 个答案:

没有答案