The TensorFlow Dataset API does not work stably when the batch size is greater than 1

Asked: 2017-09-30 20:14:22

Tags: tensorflow dataset protocol-buffers sequence

I put a set of fixed-length and variable-length features into one tf.train.SequenceExample:

context_features
    length,            scalar,                    tf.int64
    site_code_raw,     scalar,                    tf.string
    Date_Local_raw,    scalar,                    tf.string
    Time_Local_raw,    scalar,                    tf.string
sequence_features
    Orig_RefPts,       [#batch, #RefPoints, 4]    tf.float32
    tgt_location,      [#batch, 3]                tf.float32
    tgt_val            [#batch, 1]                tf.float32

The value of #RefPoints varies across SequenceExamples. I store its value in the `length` context feature. All the remaining features have fixed sizes.
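For illustration, the logical contents of one such record can be sketched in plain Python (the dict layout and all values here are made up for the sketch; the real data lives in a serialized tf.train.SequenceExample proto):

```python
# Hypothetical contents of one record (illustrative values only).
example = {
    "context": {
        "length": 5,                     # = #RefPoints for this example
        "site_code": b"A1",
        "Date_Local": b"2017-09-30",
        "Time_Local": b"20:14",
    },
    "sequence": {
        "Orig_RefPts": [[0.0] * 4] * 5,  # [#RefPoints, 4]; #RefPoints varies
        "tgt_location": [[0.0] * 3],     # [1, 3], fixed size
        "tgt_val": [[0.0]],              # [1, 1], fixed size
    },
}

# The 'length' context feature tracks the variable sequence length.
assert example["context"]["length"] == len(example["sequence"]["Orig_RefPts"])
```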

Here is the code I use to read & parse the data:

    def read_batch_DatasetAPI(
            filenames,
            batch_size=20,
            num_epochs=None,
            buffer_size=5000):
        dataset = tf.contrib.data.TFRecordDataset(filenames)
        dataset = dataset.map(_parse_SeqExample1)
        if buffer_size is not None:
            dataset = dataset.shuffle(buffer_size=buffer_size)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()
        # next_element contains a tuple of the following tensors:
        #     length,            scalar,                    tf.int64
        #     site_code_raw,     scalar,                    tf.string
        #     Date_Local_raw,    scalar,                    tf.string
        #     Time_Local_raw,    scalar,                    tf.string
        #     Orig_RefPts,       [#batch, #RefPoints, 4]    tf.float32
        #     tgt_location,      [#batch, 3]                tf.float32
        #     tgt_val,           [#batch, 1]                tf.float32
        return iterator, next_element

    def _parse_SeqExample1(in_SeqEx_proto):
        # Define how to parse the example
        context_features = {
            'length': tf.FixedLenFeature([], dtype=tf.int64),
            'site_code': tf.FixedLenFeature([], dtype=tf.string),
            'Date_Local': tf.FixedLenFeature([], dtype=tf.string),
            'Time_Local': tf.FixedLenFeature([], dtype=tf.string)
        }
        sequence_features = {
            'input_features': tf.VarLenFeature(dtype=tf.float32),
            'tgt_location_features': tf.FixedLenSequenceFeature([3], dtype=tf.float32),
            'tgt_val_feature': tf.FixedLenSequenceFeature([1], dtype=tf.float32)
        }
        context, sequence = tf.parse_single_sequence_example(
            in_SeqEx_proto,
            context_features=context_features,
            sequence_features=sequence_features)
        # distribute the fetched context and sequence features into tensors
        length = context['length']
        site_code_raw = context['site_code']
        Date_Local_raw = context['Date_Local']
        Time_Local_raw = context['Time_Local']
        # reshape the tensors according to the dimension definitions above
        Orig_RefPts = sequence['input_features'].values
        Orig_RefPts = tf.reshape(Orig_RefPts, [-1, 4])
        tgt_location = sequence['tgt_location_features']
        tgt_location = tf.reshape(tgt_location, [-1])
        tgt_val = sequence['tgt_val_feature']
        tgt_val = tf.reshape(tgt_val, [-1])
        return length, site_code_raw, Date_Local_raw, Time_Local_raw, \
               Orig_RefPts, tgt_location, tgt_val

When I call read_batch_DatasetAPI (see the code above), it processes all (roughly 200,000) SequenceExamples one by one without any problem. However, if I change batch_size from 1 to any number greater than 1, it simply stops after fetching between 320 and 700 SequenceExamples, without showing any error message. I don't know how to solve this. Any help is appreciated!

1 Answer:

Answer 0 (score: 0)

I saw some posts (Example 1, Example 2) that mention the new Dataset function from_generator (https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/data/Dataset#from_generator). I am not yet sure how to use it to solve my problem. If anyone knows how, please post it as a new answer. Thanks!

Here is my current diagnosis of, and workaround for, the problem:

The variation in sequence length (#RefPoints) caused the problem. dataset.map(_parse_SeqExample1) followed by dataset.batch(batch_size) works only when the #RefPoints happen to be the same for every example in a batch. That is why it always works when batch_size is 1, but fails at some point when batch_size is greater than 1.
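The failure mode can be reproduced outside TensorFlow: batching stacks one array per tuple component, and stacking requires identical shapes. In this minimal NumPy sketch, np.stack stands in for what dataset.batch effectively does to the Orig_RefPts component:

```python
import numpy as np

# Two parsed examples whose Orig_RefPts differ in #RefPoints (5 vs. 3).
ex1 = np.zeros((5, 4), dtype=np.float32)
ex2 = np.zeros((3, 4), dtype=np.float32)

try:
    np.stack([ex1, ex2])  # batching = stacking equal-shaped arrays
except ValueError as err:
    print("cannot batch ragged examples:", err)
```

With equal lengths the stack succeeds and yields a [2, n, 4] batch, which is exactly why batch_size = 1 (a "batch" of one shape) never hits the mismatch.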

I found that Dataset has a padded_batch function, which pads the variable-length components to the maximum length found in the batch. I made the following modifications to work around my problem for now (I guess from_generator will be the real solution in my case):

  1. In the _parse_SeqExample1 function, the return statement was changed to

    return tf.tuple([length, site_code_raw, Date_Local_raw, Time_Local_raw,
                     Orig_RefPts, tgt_location, tgt_val])

  2. In the read_batch_DatasetAPI function, the statement

    dataset = dataset.batch(batch_size)

    was changed to

    dataset = dataset.padded_batch(batch_size, padded_shapes=(
        tf.TensorShape([]),         # length
        tf.TensorShape([]),         # site_code_raw
        tf.TensorShape([]),         # Date_Local_raw
        tf.TensorShape([]),         # Time_Local_raw
        tf.TensorShape([None, 4]),  # Orig_RefPts, padded along #RefPoints
        tf.TensorShape([3]),        # tgt_location
        tf.TensorShape([1])         # tgt_val
    ))

  3. Finally, the fetch statement was changed from

    length, site_code_raw, Date_Local_raw, Time_Local_raw, \
    Orig_RefPts, tgt_location, tgt_val = sess.run(next_element)

    to

    [length, site_code_raw, Date_Local_raw, Time_Local_raw,
     Orig_RefPts_val, tgt_location, tgt_val] = sess.run(next_element)

  4. Note: I don't know why, but this only works with the current tf-nightly-gpu build, not with tensorflow-gpu v1.3.
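What padded_batch does to the variable-length component in step 2 can be sketched in plain NumPy (pad_to_batch_max is a made-up helper name for this sketch, not a TensorFlow API): each [n_i, 4] array is zero-padded up to the largest n in the batch, after which stacking succeeds.

```python
import numpy as np

def pad_to_batch_max(arrays, pad_value=0.0):
    """Pad variable-length [n_i, 4] arrays to the batch max length and stack.

    Rough stand-in for what padded_batch with padded_shapes
    tf.TensorShape([None, 4]) does to the Orig_RefPts component.
    """
    max_len = max(a.shape[0] for a in arrays)
    padded = [np.pad(a, ((0, max_len - a.shape[0]), (0, 0)),
                     mode="constant", constant_values=pad_value)
              for a in arrays]
    return np.stack(padded)

batch = pad_to_batch_max([np.ones((5, 4)), np.ones((3, 4))])
print(batch.shape)  # (2, 5, 4); the shorter example gained two zero rows
```

The fixed-size components (the scalars, tgt_location, tgt_val) need no padding, which is why their padded_shapes entries simply restate their static shapes.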