使用batch_sequences_with_states时出现障碍和prefecth错误

时间:2017-11-16 09:16:42

标签: tensorflow

使用tf.contrib.training.batch_sequences_with_states时遇到与障碍相关的错误。

我将手工制作的数据写入tfrecords格式,它包含100个示例,每个示例都是一个具有50维特征的可变长度序列。

test_list = list()
for i in range(100):
    l = [[], []]
    llen = int(random.uniform(100, 500))
    for j in range(llen):
        l[0].append([j*0.1]*50)
        l[1].append(j)
    test_list.append(l)

writer = tf.python_io.TFRecordWriter("my_test.sequence.tfrecords")
try:
    for (idx, i) in enumerate(test_list):
        example = tf.train.SequenceExample(
            context=tf.train.Features(
                feature={
                    'id' : tf.train.Feature(bytes_list=tf.train.BytesList(value=['list'+str(idx)+'_'])),
                    'length' : tf.train.Feature(int64_list=tf.train.Int64List(value=[len(i[0])]))
                }
            ),
            feature_lists=tf.train.FeatureLists(
                feature_list={
                    'feat' : tf.train.FeatureList(feature=[tf.train.Feature(float_list=tf.train.FloatList(value=frame)) for frame in i[0]]),
                    'label' : tf.train.FeatureList(feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=[frame])) for frame in i[1]])
                }
            )
        )
        writer.write(example.SerializeToString())
finally:
    writer.close()

然后我读取并解码tfrecods数据集,并使用tf.contrib.training.batch_sequences_with_states批处理序列输入。

def read_and_decode(filename, num_epochs=None):
    context_features={
        'id' : tf.FixedLenFeature([], dtype=tf.string),
        'length' : tf.FixedLenFeature([], dtype=tf.int64)
    }
    sequence_features = {
        "feat": tf.FixedLenSequenceFeature([50], dtype=tf.float32, allow_missing=False),
        "label": tf.FixedLenSequenceFeature([], dtype=tf.int64, allow_missing=False)
    }
    filename_queue = tf.train.string_input_producer([filename], num_epochs)

    reader = tf.TFRecordReader()
    key, serialized_example = reader.read(filename_queue)
    context_parsed, sequence_parsed  = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features)

    return (key, context_parsed, sequence_parsed)

key, context_parsed, sequence_parsed = read_and_decode("my_test.sequence.tfrecords")

batch_size = 16
num_unroll = 8
num_enqueue_threads = 2
hidden_size = 256
cell = tf.contrib.rnn.GRUCell(hidden_size)

initial_state_values = tf.zeros(cell.state_size, dtype=tf.float32)
initial_states = {"gru": initial_state_values}

batch = tf.contrib.training.batch_sequences_with_states(
    input_key=context_parsed['id'],
    input_sequences=sequence_parsed,
    input_context=context_parsed,
    initial_states=initial_states,
    num_unroll=num_unroll,
    batch_size=batch_size,
    input_length = tf.cast(context_parsed["length"], tf.int32),
    pad = True,
    num_threads=num_enqueue_threads,
    capacity= 1000 + batch_size * num_enqueue_threads * 2,
    make_keys_unique=False,
    allow_small_batch=True,
    name='batch_sequence')

k = batch.key
nk = batch.next_key
inputs = batch.sequences['feat']
labels = batch.sequences['label']
wid = batch.context['id']

inputs_by_time = tf.split(inputs, num_unroll, 1)
inputs_by_time = [tf.squeeze(elem, squeeze_dims=1) for elem in inputs_by_time]
assert len(inputs_by_time) == num_unroll

outputs, state = tf.contrib.rnn.static_state_saving_rnn(
      cell,
      inputs_by_time,
      sequence_length=batch.context["length"],
      state_saver=batch,
      state_name='gru')

with tf.device('/cpu:0'):
    sess = tf.Session()

    init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
    sess.run(init_op)

    fetches = [k, nk, wid]
    feed_dict={}

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess,coord=coord,daemon=False)
    try:
        while True:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            k, nk, wid\
            = sess.run(fetches)
            print('key: %s' % k)
            print('next key: %s' % nk)
            print('id: %s' % wid)
    except tf.errors.OutOfRangeError:
        print("done training")
    finally:
        print('request stop ............')
        coord.request_stop()
    coord.join(threads)
    sess.close()

但我收到错误enter image description here

错误消息显示barrie插入了现有密钥。看起来图表无法预取序列的第二个时间步。也许错误来自prefetch_op?

如果我设置make_keys_unique = True,批次也无法达到第二次步骤。 enter image description here

0 个答案:

没有答案