带有可迭代迭代器的tf.contrib.data.prefetch_to_device()的正确用法

时间:2018-06-29 01:05:01

标签: tensorflow neural-network deep-learning

我想知道是否有可能(如果可以,如何?)使用已将tf.data.Dataset应用于可馈送迭代器的tf.contrib.data.prefetch_to_device()对象。我最初未优化的数据集定义是:

training_dataset_shuffle_batch = tf.data.Dataset.from_tensor_slices(training_data).shuffle(dataset_size).repeat().batch(minibatch_size).prefetch(minibatch_size)
training_shuffle_batch_iterator = training_dataset_shuffle_batch.make_initializable_iterator()

我已经将其重写为

training_dataset_shuffle_batch = tf.data.Dataset.from_tensor_slices(training_data)
training_dataset_shuffle_batch = training_dataset_shuffle_batch.apply(tf.contrib.data.shuffle_and_repeat(buffer_size = dataset_size))
training_dataset_shuffle_batch = training_dataset_shuffle_batch.batch(minibatch_size)
training_dataset_shuffle_batch = training_dataset_shuffle_batch.apply(tf.contrib.data.prefetch_to_device(gpu_names[0]))
training_shuffle_batch_iterator = training_dataset_shuffle_batch.make_initializable_iterator()

因为我正在带有GPU的AWS p3.2xlarge实例上运行此程序,所以我想将数据完善到GPU中以提高效率。 在任何一个版本中,我的迭代器和数据获取程序都定义为

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset_shuffle_batch.output_types, training_dataset_shuffle_batch.output_shapes) 
next_input_data_element = iterator.get_next()

稍后,我尝试通过training_shuffle_batch_iterator获取training_shuffle_batch_handle = sess.run(training_shuffle_batch_iterator.string_handle())的句柄 与iterator配合使用(使用next_input_data_element时还有其他Dataset对象迭代器要从中提取数据)。当我使用第一个数据集定义时一切正常,但第二个失败并出现AttributeError: '_PrefetchToDeviceIterator' object has no attribute 'string_handle'错误。

我这样做正确吗?这是不被支持的,还是我的方法有误?这是使用TF V1.8.0。谢谢。

0 个答案:

没有答案