使用Tensorflow的tf.data API为培训和验证设置不同的批处理大小

时间:2018-05-05 14:24:57

标签: python tensorflow

我正在从较旧的基于队列的数据管道迁移到较新的tf.data API。假设我有如下代码,我如何为我的训练和验证迭代器显式设置不同的批量大小。

# Build one tf.data pipeline whose source filenames are supplied at
# iterator-initialization time through the `filenames` placeholder.
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)  # Parse the record into tensors.
dataset = dataset.repeat()  # Repeat the input indefinitely.
dataset = dataset.batch(32)  # NOTE: batch size is hard-coded here -- the question asks how to vary it.
iterator = dataset.make_initializable_iterator()



# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", 
"/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: 
training_filenames})

# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: 
validation_filenames})

编辑:

谢谢。根据回复,我的实现如下,但我无法弄清楚为什么我收到此错误:

import tensorflow as tf


def _parse(filename, label):
    """Read the JPEG at *filename*, resize it to a 224x224x3 tensor, and pass *label* through."""
    raw = tf.read_file(filename)
    decoded = tf.image.decode_jpeg(raw)
    resized = tf.image.resize_images(decoded, [224, 224])
    resized.set_shape([224, 224, 3])
    return resized, label


def input_pipeline(imglist, labellist, batch_size):
    """Build a batched, endlessly repeating dataset over (image path, label) pairs."""
    ds = tf.data.Dataset.from_tensor_slices((imglist, labellist))
    ds = ds.map(_parse)  # decode and resize each image
    ds = ds.repeat()     # loop over the input forever
    return ds.batch(batch_size)


# NOTE(review): `glob` is used here but only `tensorflow` was imported
# above -- this script also needs `import glob` to run.
imglist = glob.glob('/var/temp/*.jpg')
train_imgs=imglist[0:100]
train_labels = [i for i in range(100)]  # dummy integer labels, one per image

val_imgs=imglist[200:250]
val_labels = [i for i in range(50)]
training_batch_size = 4

validation_batch_size = 1

# Two structurally identical pipelines with different batch sizes.
training_ds = input_pipeline(train_imgs, train_labels, training_batch_size)
validation_ds = input_pipeline(val_imgs, val_labels, validation_batch_size)

# Feedable iterator: the string fed into `handle` selects which dataset
# `input_batch` is drawn from at run time.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
            handle, training_ds.output_types, training_ds.output_shapes)
input_batch = iterator.get_next()

train_iter = training_ds.make_initializable_iterator()
val_iter = validation_ds.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Define training and validation handlers
    training_handle = sess.run(train_iter.string_handle())
    validation_handle = sess.run(val_iter.string_handle())

    # Initialize training and validation dataset
    # BUG (this is the TypeError reported below): an Iterator object is not
    # a fetchable Tensor/Operation. These two calls should be
    # sess.run(train_iter.initializer) / sess.run(val_iter.initializer).
    sess.run(train_iter)
    sess.run(val_iter)


    # If we use training_handle, then input_batch tensor comes from training tfrecords
    training_batch = sess.run(input_batch, feed_dict={handle: training_handle})

    # If we use validation_handle, then input_batch tensor comes from validation tfrecords
    validation_batch = sess.run(input_batch, feed_dict={handle: validation_handle})

但我最终得到以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
    281         self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 282             fetch, allow_tensor=True, allow_operation=True))
    283       except TypeError as e:

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
   3589     with self._lock:
-> 3590       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3591 

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
   3678       raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
-> 3679                                                            types_str))
   3680 

TypeError: Can not convert a Iterator into a Tensor or Operation.

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-31-50c4f3464d03> in <module>()
     47 
     48     # Initialize training and validation dataset
---> 49     sess.run(train_iter)
     50     sess.run(val_iter)
     51 

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    898     try:
    899       result = self._run(None, fetches, feed_dict, options_ptr,
--> 900                          run_metadata_ptr)
    901       if run_metadata:
    902         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1118     # Create a fetch handler to take care of the structure of fetches.
   1119     fetch_handler = _FetchHandler(
-> 1120         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
   1121 
   1122     # Run request and get response.

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
    425     """
    426     with graph.as_default():
--> 427       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    428     self._fetches = []
    429     self._targets = []

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
    251         if isinstance(fetch, tensor_type):
    252           fetches, contraction_fn = fetch_fn(fetch)
--> 253           return _ElementFetchMapper(fetches, contraction_fn)
    254     # Did not find anything.
    255     raise TypeError('Fetch argument %r has invalid type %r' % (fetch,

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
    284         raise TypeError('Fetch argument %r has invalid type %r, '
    285                         'must be a string or Tensor. (%s)' %
--> 286                         (fetch, type(fetch), str(e)))
    287       except ValueError as e:
    288         raise ValueError('Fetch argument %r cannot be interpreted as a '

TypeError: Fetch argument <tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7fa2c0697c88> has invalid type <class 'tensorflow.python.data.ops.iterator_ops.Iterator'>, must be a string or Tensor. (Can not convert a Iterator into a Tensor or Operation.)

2 个答案:

答案 0 :(得分:6)

我会创建2 tf.data.Dataset,一个用于培训,一个用于验证子集。一旦定义了两个数据集管道(您可以定义2个不同的批处理大小),您可以通过创建一个带有处理程序的tf.data.Iterator来加入图表(在我的例子中,tf.placeholder handle)。

import tensorflow as tf


def input_pipeline(filenames, batch_size):
    """Create a batched, infinitely repeating dataset from TFRecord files."""
    ds = tf.data.TFRecordDataset(filenames)
    ds = ds.map(...)   # Parse the record into tensors.
    ds = ds.repeat()   # Repeat the input indefinitely.
    return ds.batch(batch_size)


# Same pipeline function, but different filename lists and batch sizes
# for training vs. validation.
training_filenames = ["/var/data/file1.tfrecord",
                      "/var/data/file2.tfrecord"]
training_batch_size = 32
validation_filenames = ["/var/data/validation1.tfrecord",
                        "/var/data/validation2.tfrecord"]
validation_batch_size = 16

training_ds = input_pipeline(training_filenames, training_batch_size)
validation_ds = input_pipeline(validation_filenames, validation_batch_size)

# Feedable iterator: the string handle fed into `handle` decides which
# dataset (and therefore which batch size) `input_batch` comes from.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
            handle, training_ds.output_types, training_ds.output_shapes)
input_batch = iterator.get_next()

在从两个数据集中的任何一个请求批处理之前,您可以使用string_handle()从每个迭代器获取相应的句柄。之后,当您运行input_batch时,您可以通过在feed_dict中为handle占位符提供相应的句柄,来决定数据来自训练集还是验证集。

train_iter = training_ds.make_initializable_iterator()
val_iter = validation_ds.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Obtain the string handles that identify each iterator at run time.
    training_handle = sess.run(train_iter.string_handle())
    validation_handle = sess.run(val_iter.string_handle())

    # Initialize both iterators. Note that the `.initializer` op is what
    # gets run -- not the Iterator object itself.
    sess.run(train_iter.initializer)
    sess.run(val_iter.initializer)

    # If we use training_handle, then input_batch tensor comes from training tfrecords
    # FIX: variable was misspelled "trainaing_batch" in the original answer.
    training_batch = sess.run(input_batch, feed_dict={handle: training_handle})

    # If we use validation_handle, then input_batch tensor comes from validation tfrecords
    validation_batch = sess.run(input_batch, feed_dict={handle: validation_handle})

希望它有所帮助!

编辑: 您当前的错误似乎是由于尝试在sess.run()上执行tf.data.Iterator。尝试将sess.run(train_iter)替换为sess.run(train_iter.initializer)(对于验证迭代器也是如此)。 train_iter.initializer是用于初始化train_iter迭代器的tf.Operation。现在一切都应该有效。

答案 1 :(得分:0)

需要进行轻微修改才能得到正确答案:

import glob  # FIX: `glob` was used below without being imported

import tensorflow as tf

# First 100 images for training, images 200-249 for validation,
# each with dummy integer labels.
imglist = glob.glob('/var/temp/*.jpg')
train_imgs=imglist[0:100]
train_labels = [i for i in range(100)]

val_imgs=imglist[200:250]
val_labels = [i for i in range(50)]

# Different batch sizes baked into each pipeline: 4 for training, 1 for validation.
training_ds = tf.data.Dataset.from_tensor_slices((train_imgs,train_labels)).batch(4)
validation_ds = tf.data.Dataset.from_tensor_slices((val_imgs,val_labels)).batch(1)

# Feedable iterator selected at run time via the `handle` placeholder.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
            handle, training_ds.output_types, training_ds.output_shapes)
input_batch = iterator.get_next()

train_iter = training_ds.make_initializable_iterator()
val_iter = validation_ds.make_initializable_iterator()



with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())


    # Define training and validation handlers
    training_handle = sess.run(train_iter.string_handle())
    validation_handle = sess.run(val_iter.string_handle())


    # Initialize and pull one batch from the training pipeline.
    sess.run(train_iter.initializer)
    # If we use training_handle, then input_batch tensor comes from training tfrecords
    training_batch = sess.run(input_batch, feed_dict={handle: training_handle})
    print("Training...")
    print(training_batch)

    # Initialize and pull one batch from the validation pipeline.
    sess.run(val_iter.initializer)
    # If we use validation_handle, then input_batch tensor comes from validation tfrecords
    print("Validation")
    validation_batch = sess.run(input_batch, feed_dict={handle: validation_handle})
    print(validation_batch)