如何获得每个批次中使用的数据的索引?

时间:2018-10-29 22:42:29

标签: python tensorflow

我需要保存每个迷你批处理中使用的数据的索引。

例如,如果我的数据是:

x = np.array([[1.1], [2.2], [3.3], [4.4]])

,第一个迷你批处理是[1.1][3.3],然后我想存储02(因为[1.1]是第0个观测值, [3.3]是第二个观察结果。

我正在通过keras.sequential API使用tensorflow来渴望执行。

据阅读源代码所知,该信息未存储在任何地方,因此我无法通过回调做到这一点。

我目前正在通过创建一个存储索引的对象来解决我的问题。

class IndexIterator(object):
    def __init__(self, n, n_epochs, batch_size, shuffle=True):
        data_ix = np.arange(n)
        if shuffle:
            np.random.shuffle(data_ix)

        self.ix_batches = np.array_split(data_ix, np.ceil(n / batch_size))
        self.batch_indices = []

    def generate_arrays(self, x, y):
        batch_ixs = np.arange(len(self.ix_batches))
        while 1: 
            np.random.shuffle(batch_ixs)
            for batch in batch_ixs:
                self.batch_indices.append(self.ix_batches[batch])
                yield (x[self.ix_batches[batch], :], y[self.ix_batches[batch], :])

data_gen = IndexIterator(n=32, n_epochs=100, batch_size=16)
dnn.fit_generator(data_gen.generate_arrays(x, y), 
                  steps_per_epoch=2, 
                  epochs=100)
# This is what I am looking for
print(data_gen.batch_indices)

无法使用张量流回调来做到这一点吗?

1 个答案:

答案 0 :(得分:0)

不确定这是否会比您的解决方案更有效,但肯定会更普遍。

如果您拥有带有n索引的训练数据,则可以创建仅包含这些索引的辅助Dataset并使用“真实”数据集进行压缩。

IE。

real_data = tf.data.Dataset ... 
indices = tf.data.Dataset.from_tensor_slices(tf.range(data_set_length)))
total_dataset = tf.data.Dataset.zip((real_data, indices))

# Perform optional pre-processing ops.

iterator = total_dataset.make_one_shot_iterator()

# Next line yields `(original_data_element, index)`
item_and_index_tuple = iterator.get_next() 

`