从Tensorflow中的张量中提取随机切片

时间:2017-08-24 17:38:53

标签: python tensorflow

我有张量流队列读取器中生成的输入张量形状(1,512,512,32)

viewScope.forkNum

我想在此输出张量的第四维上选择一个随机切片,这样每次调用一个新批次时,也会采用一个新的随机切片。我已使用batch_input, batch_output = tf.train.shuffle_batch([image, label], batch_size=BATCH_SIZE, capacity=3 * BATCH_SIZE + min_queue_examples, enqueue_many=True, min_after_dequeue=min_queue_examples, num_threads=16) #BATCH_SIZE = 1

尝试了以下操作
numpy

但是每次都会返回rand_slice_ind = np.random.randint(0, 32) slice_begin = tf.constant([0, 0, 0, rand_slice_ind]) slice_input = tf.slice(batch_input, begin = slice_begin, size = [BATCH_SIZE, height, width, 1]) 的相同值。我认为这与使用在图形外部生成的非张量流对象有关。

我还尝试了rand_slice_ind的内容:

tf.random_uniform

但这会导致梯度计算出现问题。有什么提示吗?

1 个答案:

答案 0 :(得分:0)

您可以尝试使用gather来实现:

slice_length = 128
data_length = data.shape[your_axis]
max_offset = data_length - slice_length
random_offset = tf.random.uniform((), minval=0, maxval=max_offset, dtype=tf.dtypes.int64)
slice_indices = tf.range(0, slice_length, dtype=tf.dtypes.int64)
random_slice = tf.gather(data, slice_indices + random_offset, axis=your_axis)