How do I use tf.train.range_input_producer() to slice the output of tf.train.shuffle_batch()?

Asked: 2017-03-09 23:44:20

Tags: python tensorflow

I am using TensorFlow to train a recurrent neural network on spectrograms of audio recordings. The data is fairly large, so I want to store it as tfrecords and feed it through an input pipeline (similar to cifar10_input.py). However, since the spectrograms will be the input to an RNN, they must be further sliced into windows (or slices) of length num_steps, as in ptb_producer (see reader.py from TensorFlow's RNN tutorial).

My question is: how do I do this? How do I add another queue that slices the batched spectrograms?

Each spectrogram has length num_times = 15000 and should be cut into num_slices slices, each of length num_steps.
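For concreteness, the intended slicing can be sketched in plain NumPy, using the numbers from the question. This only shows the indexing arithmetic (window i covers time-bins [i * num_steps, (i + 1) * num_steps)), not the queue-based pipeline being asked about:

```python
import numpy as np

num_times = 15000       # time-bins per spectrogram
num_frequencies = 175   # frequency bins per spectrogram
num_steps = 10          # window length fed to the RNN

# max number of num_steps windows we can cut from one spectrogram
num_slices = num_times // num_steps

spectrogram = np.zeros((num_times, num_frequencies))
windows = [spectrogram[i * num_steps:(i + 1) * num_steps]
           for i in range(num_slices)]

print(num_slices)        # 1500
print(windows[0].shape)  # (10, 175)
```

The queue-based version below tries to reproduce exactly this indexing with tf.train.range_input_producer() driving the slice index i.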

Here is what I tried:

def data_producer(fname, batch_size, num_steps, num_epochs):

    fname_queue = tf.train.string_input_producer([fname],
                                                 num_epochs=num_epochs)
    # Read 1 spectrogram and its labels
    X_alltimes, y_alltimes = read_tfrecords(fname_queue)
    # number of time-bins in X_alltimes
    num_times = 15000
    # number of frequencies in X_alltimes
    num_frequencies = 175
    # queue sizes for shuffle_batch
    min_after_dequeue = 100
    capacity = min_after_dequeue + 3 * batch_size

    # Stack batch_size spectrograms to make a batch
    X_alltimes_batch, y_alltimes_batch = tf.train.shuffle_batch(
        [X_alltimes, y_alltimes],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)

    # max number of num_steps windows/slices we can get from X_alltimes
    num_slices = num_times // num_steps
    # slice index
    i = tf.train.range_input_producer(
        limit=num_slices,
        shuffle=False).dequeue()

    X = tf.strided_slice(
        X_alltimes_batch,
        [0, i * num_steps, 0],
        [batch_size, (i + 1) * num_steps, num_frequencies],
        strides=[1, 1, 1])
    X.set_shape([batch_size, num_steps, num_frequencies])

    y = tf.strided_slice(
        y_alltimes_batch,
        [0, i * num_steps + 1],
        [batch_size, (i + 1) * num_steps + 1],
        strides=[1, 1])
    y.set_shape([batch_size, num_steps])

    return X, y

# Try to get some slices
X, y = data_producer(fname, batch_size=25, num_steps=10, num_epochs=1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # Start input enqueue threads.
    threads = tf.train.start_queue_runners()

    for i in range(50):
        print(sess.run([X, y]))

However, each call to run() pulls out a new batch rather than a new slice. How can I step through all num_slices slices before a new batch is fetched?

1 answer:

Answer 0 (score: 0)

The correct way to do this seems to be tf.contrib.training.batch_sequences_with_states() together with tf.contrib.training.SequenceQueueingStateSaver(). From the TensorFlow documentation: use SequenceQueueingStateSaver, or its wrapper batch_sequences_with_states, if your input data has a dynamic primary time/frame-count axis that you would like to convert into fixed-size segments during minibatching, and you would like to store state in the forward direction across the segments of an example.

A simple example:

def load_sequences(fname):

    N = 10  # number of sequences
    fname_queue = tf.train.string_input_producer([fname])
    X, y, n_times, n_freqs = read_tfrecords(fname_queue)

    context = {}
    context['n_times'] = n_times
    context['n_freqs'] = n_freqs
    context['length'] = X.get_shape()[0].value

    sequences = {'inputs': X,
                 'labels': y}

    # i only provides a unique number for the key
    i = tf.train.range_input_producer(limit=N, shuffle=True).dequeue()
    key = tf.string_join(['key_', tf.as_string(i)])

    return key, context, sequences


def test_sequence_loader(fname):

    batch_sz = 2
    num_unroll = 15
    n_threads = 1
    initial_states = {"state": tf.zeros(batch_sz, dtype=tf.float32)}

    key, context, sequences = load_sequences(fname)

    batch = tf.contrib.training.batch_sequences_with_states(
        input_key=key,
        input_sequences=sequences,
        input_context=context,
        input_length=tf.cast(context['length'], tf.int32),
        initial_states=initial_states,
        num_unroll=num_unroll,
        batch_size=batch_sz,
        num_threads=n_threads,
        capacity=batch_sz * 50)

    inputs = batch.sequences["inputs"]
    labels = batch.sequences["labels"]
    state = batch.state("state")
    # saving the state is necessary to proceed through the slices
    state_update = batch.save_state("state", state + 1)

    input_slice = tf.split(value=inputs, num_or_size_splits=num_unroll, axis=1)
    label_slice = tf.split(value=labels, num_or_size_splits=num_unroll, axis=1)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        # Start input enqueue threads.
        threads = tf.train.start_queue_runners()

        for i in range(50):
            # the state has to be updated in order to move through the slices
            [X, y, _, _] = sess.run([input_slice, label_slice, state, state_update])
            print('len(X):', len(X), 'X[0].shape:', X[0].shape, 'len(y): ', len(y))

Output:

len(X): 15 X[0].shape: (2, 1, 175) len(y):  15
len(X): 15 X[0].shape: (2, 1, 175) len(y):  15
...