Tensorflow shuffle iterator

Time: 2018-05-17 15:52:45

Tags: tensorflow tensorflow-datasets

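I want to use a TensorFlow iterator to retrieve 2 items that belong to different classes (for BC learning)...

The approach I have been working on uses tf.while_loop, but I haven't found a proper solution. Has anyone found another way besides the one I proposed?

Here is an example with a naive dataset of random numbers belonging to 5 classes: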
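
As a hypothetical, plain-Python illustration of this setup (no TensorFlow; `make_naive_dataset` and `sample_pair_distinct` are made-up names for this sketch), a naive dataset of random numbers tagged with one of 5 classes can be paired by rejection sampling: draw two items and redraw while their labels match. Reading the tf.while_loop idea as this kind of retry loop is an assumption.

```python
import random

NUM_CLASSES = 5  # the question's naive dataset has 5 classes


def make_naive_dataset(num_items, num_classes=NUM_CLASSES, seed=0):
    """Hypothetical naive dataset: a list of (random value, class label) pairs."""
    rng = random.Random(seed)
    return [(rng.random(), rng.randrange(num_classes)) for _ in range(num_items)]


def sample_pair_distinct(dataset, rng):
    """Rejection sampling: redraw two items until their labels differ.

    Assumes the dataset contains at least two distinct labels; otherwise
    this loop would never terminate.
    """
    while True:
        a, b = rng.sample(dataset, 2)
        if a[1] != b[1]:
            return a, b


data = make_naive_dataset(100)
rng = random.Random(42)
first, second = sample_pair_distinct(data, rng)
print(first, second)
```

Rejection sampling wastes draws when classes are unbalanced; the weighted approach in the answer below avoids retries by choosing distinct classes up front.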

Thx:)

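1 answer:

Answer 0 (score: 0)

It's not clear to me what you're trying to do, but if you have two tf.data.Dataset objects and you want to sample from them randomly, you can do something like the following (note that this requires upgrading to the tf-nightly package or waiting for the TensorFlow 1.9 release):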

import tensorflow as tf

# Define two datasets with the same structure but different values, to represent
# the different inputs. Using dummy data (a dataset of '1's and a dataset of '2's)
# to make the example clearer.
dataset_1 = tf.data.Dataset.from_tensors(1).repeat(None)
dataset_2 = tf.data.Dataset.from_tensors(2).repeat(None)

# With no weights given, each element is drawn uniformly at random from the inputs.
merged_dataset = tf.contrib.data.sample_from_datasets([dataset_1, dataset_2])
merged_dataset = merged_dataset.batch(2)  # Get two elements at a time.

iterator = merged_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  for _ in range(10):
    values = sess.run([next_element])
    print(values)

# Prints something like the following (the order is random):
# [array([1, 2], dtype=int32)]
# [array([1, 2], dtype=int32)]
# [array([2, 1], dtype=int32)]
# [array([1, 2], dtype=int32)]
# [array([1, 1], dtype=int32)]
# [array([2, 2], dtype=int32)]
# [array([2, 2], dtype=int32)]
# [array([1, 1], dtype=int32)]
# [array([2, 2], dtype=int32)]
# [array([2, 2], dtype=int32)]
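
To make the semantics of the snippet above concrete, here is a rough pure-Python model (an assumption-laden analogue, not TensorFlow's implementation; `sample_from_iterables` and `batch` are hypothetical helper names). At each step an input is chosen at random according to the weights (uniform when none are given) and its next element is emitted; consecutive elements are then grouped like `batch(2)`. Because each draw is independent, a batch can contain two elements from the same input, which is what the `[1, 1]` and `[2, 2]` rows show.

```python
import random
from itertools import islice, repeat


def sample_from_iterables(iterables, rng, weights=None):
    """Rough analogue of sample_from_datasets for infinite inputs.

    At every step, pick an input index with probability proportional to
    `weights` (uniform if None) and yield that input's next element.
    Assumes every input is infinite, like Dataset.repeat(None).
    """
    iterators = [iter(it) for it in iterables]
    indices = range(len(iterators))
    while True:
        (i,) = rng.choices(indices, weights=weights)
        yield next(iterators[i])


def batch(iterable, size):
    """Group consecutive elements into lists of `size`, like Dataset.batch."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk


rng = random.Random(0)
stream = sample_from_iterables([repeat(1), repeat(2)], rng)
for pair in islice(batch(stream, 2), 10):
    print(pair)  # e.g. [1, 2] or [2, 2]; the pattern is random
```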

If you want to make sure the two elements never come from the same class, you can use the ability of tf.contrib.data.sample_from_datasets() to accept its weights argument as a Dataset of (in this case, one-hot) distributions, as follows:

import tensorflow as tf

NUM_CLASSES = 5
NUM_DISTINCT = 2

datasets = [tf.data.Dataset.from_tensors(i).repeat(None) for i in range(NUM_CLASSES)]

# Define a dataset with NUM_DISTINCT distinct class IDs per element,
# then unbatch it into one class per element.
weight_dataset = tf.contrib.data.Counter().map(
    lambda _: tf.random_shuffle(tf.range(NUM_CLASSES))[:NUM_DISTINCT])
weight_dataset = weight_dataset.apply(tf.contrib.data.unbatch())
# Turn each class ID into a one-hot weight vector over the NUM_CLASSES datasets.
weight_dataset = weight_dataset.map(lambda x: tf.one_hot(x, NUM_CLASSES))

merged_dataset = tf.contrib.data.sample_from_datasets(datasets, weight_dataset)
merged_dataset = merged_dataset.batch(NUM_DISTINCT)

iterator = merged_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  for _ in range(1000):
    values = sess.run(next_element)
    assert values[0] != values[1]
    print(values)
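
Why the weighted version never repeats a class can be modelled in plain Python (again a hedged sketch, not TensorFlow code; `weight_stream` and `sample_with_weight_stream` are hypothetical names). Each group of NUM_DISTINCT one-hot weights comes from one shuffle of the class IDs, so the weights within a group point at different inputs, and taking NUM_DISTINCT elements at a time lines each batch up with exactly one group.

```python
import random
from itertools import islice, repeat

NUM_CLASSES = 5
NUM_DISTINCT = 2


def weight_stream(rng, num_classes=NUM_CLASSES, num_distinct=NUM_DISTINCT):
    """Mimics Counter().map(shuffle + slice) -> unbatch() -> one_hot()."""
    while True:
        ids = list(range(num_classes))
        rng.shuffle(ids)
        # "unbatch": emit the first num_distinct shuffled IDs one at a time,
        # each as a one-hot weight vector over the per-class inputs.
        for class_id in ids[:num_distinct]:
            yield [1.0 if c == class_id else 0.0 for c in range(num_classes)]


def sample_with_weight_stream(inputs, weights_iter, rng):
    """Draw one element per weight vector; a one-hot vector forces the choice."""
    iterators = [iter(it) for it in inputs]
    indices = range(len(iterators))
    for weights in weights_iter:
        (i,) = rng.choices(indices, weights=weights)
        yield next(iterators[i])


rng = random.Random(0)
per_class = [repeat(c) for c in range(NUM_CLASSES)]  # stand-ins for the datasets
merged = sample_with_weight_stream(per_class, weight_stream(rng), rng)

for _ in range(1000):
    pair = list(islice(merged, NUM_DISTINCT))  # like batch(NUM_DISTINCT)
    assert pair[0] != pair[1]
print(list(islice(merged, NUM_DISTINCT)))
```

In this model the guarantee depends on the batch size equalling NUM_DISTINCT; with any other batch size the batches would no longer line up with the shuffled groups, and the same reasoning appears to apply to the TensorFlow pipeline.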
