Set up TensorFlow tf.data to process the dev set after every epoch

Time: 2019-05-19 22:04:39

Tags: python tensorflow machine-learning tensorflow-datasets

import numpy as np
import tensorflow as tf

batch_size = 2
x_dim = 2
m = 5
m_dev = 4
epochs = 2

# Toy data
X_train = np.random.randn(m, x_dim)
Y_train = np.random.randint(0, 5, size=m).reshape(-1, 1)
X_dev = np.random.randn(m_dev, x_dim)
Y_dev = np.random.randint(0, 5, size=m_dev).reshape(-1, 1)

X = tf.placeholder(X_train.dtype, shape=[None, x_dim], name='X')
Y = tf.placeholder(Y_train.dtype, shape=[None, 1], name='Y')

# Create two separate datasets
train_dataset = tf.data.Dataset.from_tensor_slices((X, Y)).batch(batch_size)
dev_dataset = tf.data.Dataset.from_tensor_slices((X, Y)).batch(X_dev.shape[0])

# Create a generic Iterator
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                           train_dataset.output_shapes)

# Create two init ops
train_init_op = iterator.make_initializer(train_dataset)
dev_init_op = iterator.make_initializer(dev_dataset)

next_data = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(epochs):
        # Training data
        sess.run(train_init_op, feed_dict={X: X_train, Y: Y_train})
        while True:
            try:
                X_batch, Y_batch = sess.run(next_data)
                # process training data
            except tf.errors.OutOfRangeError:
                break

        # Epoch done: process the dev data
        sess.run(dev_init_op, feed_dict={X: X_dev, Y: Y_dev})
        X_dev_all, Y_dev_all = sess.run(next_data)

I am using tf.data with a reinitializable iterator to handle the training and dev set data. At every epoch I re-initialize the training dataset. The official documentation uses a similar structure. I don't think this is very efficient, especially when the training set is large, since the whole training set is pushed through feed_dict on every initialization.

Some resources I found online place sess.run(train_init_op, feed_dict={X: X_train, Y: Y_train}) before the for loop to avoid this problem (sketched below). But then we cannot process the dev set after each epoch; we can only process it after the loop over all epochs epochs has finished.
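For concreteness, here is my own minimal sketch of that pattern, not taken from those resources: it reuses the iterator, next_data and dev_init_op defined above and assumes .repeat(epochs) is added to the training dataset so that tf.errors.OutOfRangeError is only raised once every epoch has been consumed.

# Sketch: initialize once, before the epoch loop (assumes .repeat(epochs)).
train_dataset = tf.data.Dataset.from_tensor_slices((X, Y)).batch(batch_size).repeat(epochs)
train_init_op = iterator.make_initializer(train_dataset)

with tf.Session() as sess:
    # Training data is fed into the pipeline only once.
    sess.run(train_init_op, feed_dict={X: X_train, Y: Y_train})
    while True:
        try:
            X_batch, Y_batch = sess.run(next_data)
            # process training data
        except tf.errors.OutOfRangeError:
            break

    # The dev set can only be processed here, after all epochs have run.
    sess.run(dev_init_op, feed_dict={X: X_dev, Y: Y_dev})
    X_dev_all, Y_dev_all = sess.run(next_data)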

Is there a way to efficiently process the dev set after each epoch?
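For reference, the TF 1.x API also provides feedable iterators (tf.data.Iterator.from_string_handle), which can switch between a training and a dev iterator without re-initializing either one. What follows is only a rough sketch under my own assumptions (an endless .repeat() on the training dataset and a fixed steps_per_epoch count), not a pattern taken from the official docs for this exact use case; the toy data, placeholders and constants defined above are assumed.

import numpy as np
import tensorflow as tf

# Training dataset repeats indefinitely; epochs are counted in steps instead.
train_dataset = tf.data.Dataset.from_tensor_slices((X, Y)).batch(batch_size).repeat()
dev_dataset = tf.data.Dataset.from_tensor_slices((X, Y)).batch(X_dev.shape[0])

# Feedable iterator: a string handle selects which concrete iterator to read from.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_dataset.output_types, train_dataset.output_shapes)
next_data = iterator.get_next()

train_iterator = train_dataset.make_initializable_iterator()
dev_iterator = dev_dataset.make_initializable_iterator()

steps_per_epoch = int(np.ceil(m / batch_size))

with tf.Session() as sess:
    train_handle = sess.run(train_iterator.string_handle())
    dev_handle = sess.run(dev_iterator.string_handle())

    # Training data is fed into the pipeline only once.
    sess.run(train_iterator.initializer, feed_dict={X: X_train, Y: Y_train})

    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            X_batch, Y_batch = sess.run(next_data, feed_dict={handle: train_handle})
            # process training data

        # Re-initializing the dev iterator does not reset the training iterator.
        sess.run(dev_iterator.initializer, feed_dict={X: X_dev, Y: Y_dev})
        X_dev_all, Y_dev_all = sess.run(next_data, feed_dict={handle: dev_handle})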

0 Answers:

There are no answers yet.