I'm running into a barrier-related error when using tf.contrib.training.batch_sequences_with_states.

I wrote hand-crafted data to TFRecords format: 100 examples, each a variable-length sequence of 50-dimensional feature frames:
```python
import random
import tensorflow as tf

# Build 100 variable-length sequences; each frame is a 50-dim feature vector
# and each step has an integer label.
test_list = list()
for i in range(100):
    l = [[], []]
    llen = int(random.uniform(100, 500))
    for j in range(llen):
        l[0].append([j * 0.1] * 50)
        l[1].append(j)
    test_list.append(l)

writer = tf.python_io.TFRecordWriter("my_test.sequence.tfrecords")
try:
    for (idx, i) in enumerate(test_list):
        example = tf.train.SequenceExample(
            context=tf.train.Features(
                feature={
                    # BytesList needs bytes, so encode the id string.
                    'id': tf.train.Feature(bytes_list=tf.train.BytesList(
                        value=[('list' + str(idx) + '_').encode()])),
                    'length': tf.train.Feature(int64_list=tf.train.Int64List(
                        value=[len(i[0])]))
                }
            ),
            feature_lists=tf.train.FeatureLists(
                feature_list={
                    'feat': tf.train.FeatureList(
                        feature=[tf.train.Feature(float_list=tf.train.FloatList(value=frame))
                                 for frame in i[0]]),
                    'label': tf.train.FeatureList(
                        feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=[frame]))
                                 for frame in i[1]])
                }
            )
        )
        writer.write(example.SerializeToString())
finally:
    writer.close()
```
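For reference, the written file can be read back directly to confirm the serialized layout (a quick sanity-check sketch, not part of the pipeline):

```python
# Sanity check: decode the first serialized record back into a SequenceExample.
it = tf.python_io.tf_record_iterator("my_test.sequence.tfrecords")
ex = tf.train.SequenceExample.FromString(next(it))
print(ex.context.feature['id'].bytes_list.value)            # e.g. [b'list0_']
print(ex.context.feature['length'].int64_list.value)        # e.g. [237]
print(len(ex.feature_lists.feature_list['feat'].feature))   # number of frames
```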
Then I read and decode the TFRecords dataset and batch the sequence input with tf.contrib.training.batch_sequences_with_states:
```python
def read_and_decode(filename, num_epochs=None):
    context_features = {
        'id': tf.FixedLenFeature([], dtype=tf.string),
        'length': tf.FixedLenFeature([], dtype=tf.int64)
    }
    sequence_features = {
        "feat": tf.FixedLenSequenceFeature([50], dtype=tf.float32, allow_missing=False),
        "label": tf.FixedLenSequenceFeature([], dtype=tf.int64, allow_missing=False)
    }
    filename_queue = tf.train.string_input_producer([filename], num_epochs)
    reader = tf.TFRecordReader()
    key, serialized_example = reader.read(filename_queue)
    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features)
    return (key, context_parsed, sequence_parsed)
```
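As an aside, the parsing step can be sanity-checked in isolation before any batching is involved (a minimal sketch, separate from the pipeline below; variable names here are my own):

```python
# Minimal sketch: run the parser alone and look at the parsed values/shapes.
key_t, ctx_t, seq_t = read_and_decode("my_test.sequence.tfrecords")
with tf.Session() as s:
    s.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=s, coord=coord)
    print(s.run([ctx_t['id'], ctx_t['length'], tf.shape(seq_t['feat'])]))
    coord.request_stop()
    coord.join(threads)
```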
```python
key, context_parsed, sequence_parsed = read_and_decode("my_test.sequence.tfrecords")

batch_size = 16
num_unroll = 8
num_enqueue_threads = 2
hidden_size = 256

cell = tf.contrib.rnn.GRUCell(hidden_size)
# For GRUCell, state_size is just hidden_size, so this is a [256] zero vector.
initial_state_values = tf.zeros(cell.state_size, dtype=tf.float32)
initial_states = {"gru": initial_state_values}

batch = tf.contrib.training.batch_sequences_with_states(
    input_key=context_parsed['id'],
    input_sequences=sequence_parsed,
    input_context=context_parsed,
    initial_states=initial_states,
    num_unroll=num_unroll,
    batch_size=batch_size,
    input_length=tf.cast(context_parsed["length"], tf.int32),
    pad=True,
    num_threads=num_enqueue_threads,
    capacity=1000 + batch_size * num_enqueue_threads * 2,
    make_keys_unique=False,
    allow_small_batch=True,
    name='batch_sequence')

k = batch.key
nk = batch.next_key
inputs = batch.sequences['feat']
labels = batch.sequences['label']
wid = batch.context['id']

# Split the [batch_size, num_unroll, 50] tensor into num_unroll time slices.
inputs_by_time = tf.split(inputs, num_unroll, 1)
inputs_by_time = [tf.squeeze(elem, axis=[1]) for elem in inputs_by_time]
assert len(inputs_by_time) == num_unroll

outputs, state = tf.contrib.rnn.static_state_saving_rnn(
    cell,
    inputs_by_time,
    sequence_length=batch.context["length"],
    state_saver=batch,
    state_name='gru')
```
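For context on how the state saver is used: as far as I understand, static_state_saving_rnn is roughly equivalent to reading the saved state with batch.state(), running a plain static_rnn, and writing the final state back with batch.save_state() (a sketch of my understanding, not the library's exact code):

```python
# Sketch of the state-saver wiring as I understand it (not the library's exact code):
init_state = batch.state('gru')                 # per-key saved state, [batch_size, 256]
outs, final_state = tf.contrib.rnn.static_rnn(
    cell, inputs_by_time,
    initial_state=init_state,
    sequence_length=batch.context["length"])
save_op = batch.save_state('gru', final_state)  # must run on every session step
```

As I read the source, static_state_saving_rnn attaches the save op as a control dependency on its outputs, so fetching the outputs also saves the state for the next unroll window.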
```python
with tf.device('/cpu:0'):
    sess = tf.Session()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    fetches = [k, nk, wid]
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord, daemon=False)
    try:
        while True:
            k, nk, wid = sess.run(fetches)
            print('key: %s' % k)
            print('next key: %s' % nk)
            print('id: %s' % wid)
    except tf.errors.OutOfRangeError:
        print("done training")
    finally:
        print('request stop ............')
        coord.request_stop()
        coord.join(threads)
    sess.close()
```
But I get an error (screenshot attached). The error message says the barrier was asked to insert an already-existing key. It looks like the graph fails to prefetch the second time-step window of the sequences. Maybe the error comes from the prefetch_op?

If I set make_keys_unique=True, the batch still fails to reach the second step (second screenshot attached).