我想知道是否有可能(如果可以,如何?)使用已将tf.data.Dataset
应用于可馈送迭代器的tf.contrib.data.prefetch_to_device()
对象。我最初未优化的数据集定义是:
training_dataset_shuffle_batch = tf.data.Dataset.from_tensor_slices(training_data).shuffle(dataset_size).repeat().batch(minibatch_size).prefetch(minibatch_size)
training_shuffle_batch_iterator = training_dataset_shuffle_batch.make_initializable_iterator()
我已经将其重写为
training_dataset_shuffle_batch = tf.data.Dataset.from_tensor_slices(training_data)
training_dataset_shuffle_batch = training_dataset_shuffle_batch.apply(tf.contrib.data.shuffle_and_repeat(buffer_size = dataset_size))
training_dataset_shuffle_batch = training_dataset_shuffle_batch.batch(minibatch_size)
training_dataset_shuffle_batch = training_dataset_shuffle_batch.apply(tf.contrib.data.prefetch_to_device(gpu_names[0]))
training_shuffle_batch_iterator = training_dataset_shuffle_batch.make_initializable_iterator()
因为我正在带有GPU的AWS p3.2xlarge实例上运行此程序,所以我想将数据完善到GPU中以提高效率。 在任何一个版本中,我的迭代器和数据获取程序都定义为
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset_shuffle_batch.output_types, training_dataset_shuffle_batch.output_shapes)
next_input_data_element = iterator.get_next()
稍后,我尝试通过training_shuffle_batch_iterator
获取training_shuffle_batch_handle = sess.run(training_shuffle_batch_iterator.string_handle())
的句柄
与iterator
配合使用(使用next_input_data_element
时还有其他Dataset对象迭代器要从中提取数据)。当我使用第一个数据集定义时一切正常,但第二个失败并出现AttributeError: '_PrefetchToDeviceIterator' object has no attribute 'string_handle'
错误。
我这样做正确吗?这是不被支持的,还是我的方法有误?这是使用TF V1.8.0。谢谢。