如何在“不同功能”中使用tf.dataset馈送训练和验证数据?

时间:2019-07-07 15:49:03

标签: tensorflow-datasets

我正在尝试创建一个函数来创建tf.dataset,并在另一个“火车”中 函数来使用这些数据集输入模型,但是我不知道如何在不同的函数中输入这些数据

假设我们拥有训练和验证数据(数字,高度,宽度,通道)

training_x = np.arange(500).reshape(20,5,5,1)
training_y = np.arange(200,700).reshape(20,5,5,1)

val_x = np.arange(100,300).reshape(8,5,5,1)
val_y = np.arange(400,600).reshape(8,5,5,1)

和另一个使用tf.datasets的get_batch_data函数

def get_batch_data():
    # placeholder is to feed the training and validation data
    input_x = tf.placeholder(tf.float32, shape=[None, height, width, channel])
    input_y = tf.placeholder(tf.float32, shape=[None, height, width, channel])

    dataset = tf.data.Dataset.from_tensor_slices((input_x, input_y))
    dataset = dataset.shuffle(buffer_size = 5)
    dataset = dataset.batch(2)

    iterator = dataset.make_initializable_iterator()
    image, label = iterator.get_next()

    return image, label, iterator

和一个简单的模型函数

def model(input):

    conv1 = tf.layers.conv2d(input, filters = 3 ,kernel_size = [3,3], strides = (1,1), padding = 'same')

    return conv1

我们将使用数据集在“训练”功能中输入模型

def train():

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    image, label, iterator = get_batch_data()

    for epoc in range(10):

        # for training data
        sess.run(iterator.initializer, feed_dict={} )

        #show some evaluation

        # for validation data
        sess.run(iterator.initializer, feed_dict={} )

        # show some evaluation

我确实不知道如何在“ train”函数中提供这些数据,但是如果我将tf.datasets和“ train”写入同一函数中,我就可以做到。只想输入更优雅的

非常感谢

0 个答案:

没有答案