我正在从较旧的基于队列的数据管道迁移到较新的 tf.data
API。假设我有如下代码，我如何为我的训练迭代器和验证迭代器显式设置不同的批量大小？
# Question's original snippet: a single reinitializable iterator is re-fed with
# either training or validation filenames; `sess` is assumed to be an existing
# tf.Session created elsewhere.
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...) # Parse the record into tensors.
dataset = dataset.repeat() # Repeat the input indefinitely.
dataset = dataset.batch(32)  # batch size is hard-coded here for BOTH splits -- the question asks how to make it differ per split
iterator = dataset.make_initializable_iterator()
# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord",
"/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames:
training_filenames})
# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames:
validation_filenames})
编辑：
谢谢。根据回复，我的实现如下，但我无法弄清楚为什么会收到下面这个错误：
import tensorflow as tf
def _parse(filename, label):
    """Read the JPEG at `filename`, resize it to a fixed 224x224x3 tensor,
    and pass `label` through unchanged."""
    raw = tf.read_file(filename)
    decoded = tf.image.decode_jpeg(raw)
    resized = tf.image.resize_images(decoded, [224, 224])
    # resize_images loses the static shape; pin it so downstream ops see it.
    resized.set_shape([224, 224, 3])
    return resized, label
def input_pipeline(imglist, labellist, batch_size):
    """Build an endlessly repeating, batched dataset of (image, label) pairs."""
    pairs = tf.data.Dataset.from_tensor_slices((imglist, labellist))
    # Parse each record into tensors, repeat indefinitely, then batch.
    return pairs.map(_parse).repeat().batch(batch_size)
import glob  # BUG FIX: `glob` was used below but never imported in the snippet

imglist = glob.glob('/var/temp/*.jpg')
train_imgs = imglist[0:100]
train_labels = [i for i in range(100)]
val_imgs = imglist[200:250]
val_labels = [i for i in range(50)]

# Different batch sizes per split -- the whole point of the two pipelines.
training_batch_size = 4
validation_batch_size = 1
training_ds = input_pipeline(train_imgs, train_labels, training_batch_size)
validation_ds = input_pipeline(val_imgs, val_labels, validation_batch_size)

# Feedable iterator: the string fed into `handle` selects, per session.run
# call, which concrete iterator supplies `input_batch`.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_ds.output_types, training_ds.output_shapes)
input_batch = iterator.get_next()

train_iter = training_ds.make_initializable_iterator()
val_iter = validation_ds.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Handles identifying each concrete iterator to the feedable iterator.
    training_handle = sess.run(train_iter.string_handle())
    validation_handle = sess.run(val_iter.string_handle())
    # BUG FIX: run the initializer *operation*, not the Iterator object.
    # `sess.run(train_iter)` raises "Can not convert a Iterator into a Tensor
    # or Operation" -- exactly the TypeError reported in the traceback.
    sess.run(train_iter.initializer)
    sess.run(val_iter.initializer)
    # Feeding training_handle makes input_batch come from the training set...
    training_batch = sess.run(input_batch, feed_dict={handle: training_handle})
    # ...and feeding validation_handle makes it come from the validation set.
    validation_batch = sess.run(input_batch, feed_dict={handle: validation_handle})
但我最终得到以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
281 self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 282 fetch, allow_tensor=True, allow_operation=True))
283 except TypeError as e:
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
3589 with self._lock:
-> 3590 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3591
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
3678 raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
-> 3679 types_str))
3680
TypeError: Can not convert a Iterator into a Tensor or Operation.
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
<ipython-input-31-50c4f3464d03> in <module>()
47
48 # Initialize training and validation dataset
---> 49 sess.run(train_iter)
50 sess.run(val_iter)
51
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
898 try:
899 result = self._run(None, fetches, feed_dict, options_ptr,
--> 900 run_metadata_ptr)
901 if run_metadata:
902 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1118 # Create a fetch handler to take care of the structure of fetches.
1119 fetch_handler = _FetchHandler(
-> 1120 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1121
1122 # Run request and get response.
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
425 """
426 with graph.as_default():
--> 427 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
428 self._fetches = []
429 self._targets = []
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
251 if isinstance(fetch, tensor_type):
252 fetches, contraction_fn = fetch_fn(fetch)
--> 253 return _ElementFetchMapper(fetches, contraction_fn)
254 # Did not find anything.
255 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
284 raise TypeError('Fetch argument %r has invalid type %r, '
285 'must be a string or Tensor. (%s)' %
--> 286 (fetch, type(fetch), str(e)))
287 except ValueError as e:
288 raise ValueError('Fetch argument %r cannot be interpreted as a '
TypeError: Fetch argument <tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7fa2c0697c88> has invalid type <class 'tensorflow.python.data.ops.iterator_ops.Iterator'>, must be a string or Tensor. (Can not convert a Iterator into a Tensor or Operation.)
答案 0 :(得分:6)
我会创建两个 tf.data.Dataset，一个用于训练子集，一个用于验证子集。一旦定义好这两条数据集管道（可以各自使用不同的批量大小），就可以创建一个由句柄驱动的 tf.data.Iterator（在我的例子中，句柄是一个 tf.placeholder，名为 handle），从而把两条管道接入同一张计算图。
import tensorflow as tf
def input_pipeline(filenames, batch_size):
    """TFRecord pipeline: parse records, repeat forever, batch at `batch_size`."""
    records = tf.data.TFRecordDataset(filenames)
    records = records.map(...)  # Parse the record into tensors.
    records = records.repeat()  # Repeat the input indefinitely.
    return records.batch(batch_size)
# Two independent pipelines, each with its own batch size.
training_filenames = [
    "/var/data/file1.tfrecord",
    "/var/data/file2.tfrecord",
]
training_batch_size = 32

validation_filenames = [
    "/var/data/validation1.tfrecord",
    "/var/data/validation2.tfrecord",
]
validation_batch_size = 16

training_ds = input_pipeline(training_filenames, training_batch_size)
validation_ds = input_pipeline(validation_filenames, validation_batch_size)

# Feedable iterator: the string fed into `handle` chooses, per session.run
# call, which pipeline actually supplies `input_batch`.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_ds.output_types, training_ds.output_shapes)
input_batch = iterator.get_next()
在向任一数据集请求批次之前，先用 string_handle() 从每个迭代器获取对应的句柄。之后，每次运行 input_batch 时，通过往 handle 占位符里喂入哪个句柄，来决定这一批数据来自训练集还是验证集。
train_iter = training_ds.make_initializable_iterator()
val_iter = validation_ds.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Define training and validation handles
    training_handle = sess.run(train_iter.string_handle())
    validation_handle = sess.run(val_iter.string_handle())
    # Initialize training and validation datasets. Note `.initializer` (a
    # tf.Operation): running the Iterator object itself raises a TypeError.
    sess.run(train_iter.initializer)
    sess.run(val_iter.initializer)
    # If we use training_handle, then input_batch tensor comes from training tfrecords
    # BUG FIX: variable was misspelled `trainaing_batch` in the original.
    training_batch = sess.run(input_batch, feed_dict={handle: training_handle})
    # If we use validation_handle, then input_batch tensor comes from validation tfrecords
    validation_batch = sess.run(input_batch, feed_dict={handle: validation_handle})
希望它有所帮助!
编辑:
您当前的错误似乎是由于对 tf.data.Iterator 对象直接执行 sess.run() 导致的。请把 sess.run(train_iter) 替换为 sess.run(train_iter.initializer)（验证迭代器同理）。train_iter.initializer 是用于初始化 train_iter 迭代器的 tf.Operation。这样一切就应该能正常工作了。
答案 1 :(得分:0)
需要进行轻微修改才能得到正确答案:
import tensorflow as tf
# NOTE(review): `glob` must be imported for this line to run -- the snippet
# above only imports tensorflow.
imglist = glob.glob('/var/temp/*.jpg')
train_imgs = imglist[:100]
train_labels = list(range(100))
val_imgs = imglist[200:250]
val_labels = list(range(50))

# Batch size 4 for training, 1 for validation.
training_ds = tf.data.Dataset.from_tensor_slices((train_imgs, train_labels)).batch(4)
validation_ds = tf.data.Dataset.from_tensor_slices((val_imgs, val_labels)).batch(1)

# Feedable iterator selected at run time via the `handle` placeholder.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_ds.output_types, training_ds.output_shapes)
input_batch = iterator.get_next()

train_iter = training_ds.make_initializable_iterator()
val_iter = validation_ds.make_initializable_iterator()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Handles that let the feedable iterator switch between the two datasets.
    training_handle = sess.run(train_iter.string_handle())
    validation_handle = sess.run(val_iter.string_handle())

    sess.run(train_iter.initializer)
    # With training_handle fed, input_batch is drawn from the training set.
    training_batch = sess.run(input_batch, feed_dict={handle: training_handle})
    print("Training...")
    print(training_batch)

    sess.run(val_iter.initializer)
    # With validation_handle fed, input_batch is drawn from the validation set.
    print("Validation")
    validation_batch = sess.run(input_batch, feed_dict={handle: validation_handle})
    print(validation_batch)