I put a set of fixed-length and variable-length features into a single tf.train.SequenceExample:
context_features:
    length, scalar, tf.int64
    site_code_raw, scalar, tf.string
    Date_Local_raw, scalar, tf.string
    Time_Local_raw, scalar, tf.string
sequence_features:
    Orig_RefPts, [#batch, #RefPoints, 4], tf.float32
    tgt_location, [#batch, 3], tf.float32
    tgt_val, [#batch, 1], tf.float32
The value of #RefPoints varies across Sequence Examples. I store its value in the length feature of context_features. The remaining features have fixed sizes.
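For reference, the writing side looks roughly like this (a minimal sketch, not my exact code; the helper names are illustrative, and site_code, date_local, time_local are byte strings):

import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _float_feature(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))

def make_sequence_example(ref_pts, tgt_location, tgt_val,
                          site_code, date_local, time_local):
    # ref_pts: a list of #RefPoints rows, each a list of 4 floats;
    # its length varies from record to record
    context = tf.train.Features(feature={
        'length': _int64_feature(len(ref_pts)),
        'site_code': _bytes_feature(site_code),
        'Date_Local': _bytes_feature(date_local),
        'Time_Local': _bytes_feature(time_local),
    })
    feature_lists = tf.train.FeatureLists(feature_list={
        'input_features': tf.train.FeatureList(
            feature=[_float_feature(row) for row in ref_pts]),
        'tgt_location_features': tf.train.FeatureList(
            feature=[_float_feature(tgt_location)]),  # one step of 3 floats
        'tgt_val_feature': tf.train.FeatureList(
            feature=[_float_feature(tgt_val)]),       # one step of 1 float
    })
    return tf.train.SequenceExample(context=context, feature_lists=feature_lists)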
Here is the code I use to read & parse the data:

def read_batch_DatasetAPI(
        filenames,
        batch_size=20,
        num_epochs=None,
        buffer_size=5000):
    dataset = tf.contrib.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_SeqExample1)
    if buffer_size is not None:
        dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    # next_element contains a tuple of the following tensors:
    # length, scalar, tf.int64
    # site_code_raw, scalar, tf.string
    # Date_Local_raw, scalar, tf.string
    # Time_Local_raw, scalar, tf.string
    # Orig_RefPts, [#batch, #RefPoints, 4], tf.float32
    # tgt_location, [#batch, 3], tf.float32
    # tgt_val, [#batch, 1], tf.float32
    return iterator, next_element
def _parse_SeqExample1(in_SeqEx_proto):
    # Define how to parse the example
    context_features = {
        'length': tf.FixedLenFeature([], dtype=tf.int64),
        'site_code': tf.FixedLenFeature([], dtype=tf.string),
        'Date_Local': tf.FixedLenFeature([], dtype=tf.string),
        'Time_Local': tf.FixedLenFeature([], dtype=tf.string)
    }
    sequence_features = {
        'input_features': tf.VarLenFeature(dtype=tf.float32),
        'tgt_location_features': tf.FixedLenSequenceFeature([3], dtype=tf.float32),
        'tgt_val_feature': tf.FixedLenSequenceFeature([1], dtype=tf.float32)
    }
    context, sequence = tf.parse_single_sequence_example(
        in_SeqEx_proto,
        context_features=context_features,
        sequence_features=sequence_features)

    # distribute the fetched context and sequence features into tensors
    length = context['length']
    site_code_raw = context['site_code']
    Date_Local_raw = context['Date_Local']
    Time_Local_raw = context['Time_Local']

    # reshape the tensors according to the dimension definitions above
    Orig_RefPts = sequence['input_features'].values  # sparse -> flat 1-D values
    Orig_RefPts = tf.reshape(Orig_RefPts, [-1, 4])   # [#RefPoints, 4]
    tgt_location = sequence['tgt_location_features']
    tgt_location = tf.reshape(tgt_location, [-1])    # [3]
    tgt_val = sequence['tgt_val_feature']
    tgt_val = tf.reshape(tgt_val, [-1])              # [1]

    return length, site_code_raw, Date_Local_raw, Time_Local_raw, \
           Orig_RefPts, tgt_location, tgt_val
When I call read_batch_DatasetAPI (the driver loop is sketched below), it can process all (about 200,000) Sequence Examples one by one without any issue. But if I change batch_size to any number greater than 1, it simply stops after fetching 320 to 700 Sequence Examples, with no error message at all. I have no idea how to solve this problem. Any help is appreciated!
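For context, the driver loop looks roughly like this (a simplified sketch rather than my exact code; filenames is the list of TFRecord files):

iterator, next_element = read_batch_DatasetAPI(filenames, batch_size=1)

with tf.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        length, site_code_raw, Date_Local_raw, Time_Local_raw, \
            Orig_RefPts, tgt_location, tgt_val = sess.run(next_element)
        # ... feed the fetched numpy arrays into the model ...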
Answer 0 (score: 0)
I have seen some posts (Example 1 and Example 2) mentioning the new Dataset function from_generator (https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/data/Dataset#from_generator). I am not yet sure how to use it to solve my problem. If anyone knows how, please post it as a new answer. Thanks!
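From the linked docs, my rough understanding is that it might look something like this (untested for my case; gen_parsed_records is a hypothetical generator that yields already-decoded numpy arrays):

import numpy as np

def gen_parsed_records():
    # hypothetical: decode each SequenceExample outside the graph and
    # yield (ref_pts, tgt_location, tgt_val) as numpy arrays
    for n_ref_pts in (5, 8, 3):  # dummy variable-length records
        yield (np.zeros((n_ref_pts, 4), np.float32),
               np.zeros(3, np.float32),
               np.zeros(1, np.float32))

dataset = tf.contrib.data.Dataset.from_generator(
    gen_parsed_records,
    output_types=(tf.float32, tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([None, 4]),  # variable #RefPoints
                   tf.TensorShape([3]),
                   tf.TensorShape([1])))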
Here are my current diagnosis of and solution to the problem:
The variation in the sequence length (#RefPoints) caused the problem: batching the output of dataset.map(_parse_SeqExample1) only works when #RefPoints happens to be the same for every example in the batch. That is why it always works when batch_size is 1, but fails at some point when batch_size is greater than 1.
I found that the Dataset API has a padded_batch function, which pads the variable-length tensors to the maximum length within each batch. I made a few modifications to solve my problem for now (I guess from_generator would be the real solution for my case):
In the _parse_SeqExample1 function, the return statement was changed to

    return tf.tuple([length, site_code_raw, Date_Local_raw, Time_Local_raw,
                     Orig_RefPts, tgt_location, tgt_val])
In the read_batch_DatasetAPI function, the statement

    dataset = dataset.batch(batch_size)

was changed to
    dataset = dataset.padded_batch(batch_size, padded_shapes=(
        tf.TensorShape([]),         # length
        tf.TensorShape([]),         # site_code_raw
        tf.TensorShape([]),         # Date_Local_raw
        tf.TensorShape([]),         # Time_Local_raw
        tf.TensorShape([None, 4]),  # Orig_RefPts, padded along #RefPoints
        tf.TensorShape([3]),        # tgt_location
        tf.TensorShape([1])         # tgt_val
    ))
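Each entry in padded_shapes lines up with one tensor in the tuple returned by _parse_SeqExample1; the None dimension of Orig_RefPts is the one padded up to the longest #RefPoints in the batch, while the scalar and fixed-size entries are batched unchanged.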
Finally, the fetch statement was changed from

    length, site_code_raw, Date_Local_raw, Time_Local_raw, \
        Orig_RefPts, tgt_location, tgt_val = sess.run(next_element)

to

    [length, site_code_raw, Date_Local_raw, Time_Local_raw, \
        Orig_RefPts, tgt_location, tgt_val] = sess.run(next_element)
Note: I don't know why, but this only works with the current tf-nightly-gpu build, not with tensorflow-gpu v1.3.