Error when using the TensorFlow function tf.slice_input_producer

Time: 2017-02-12 07:29:00

Tags: python tensorflow

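I'm using TFRecord to handle a sequence dataset. When an example is read into memory, I get two context features and one sequence feature. The context features are the length of the sequence and the label of the sequence. The sequence feature is a list of bytes representing the feature at each time step. So for each example I have three tensors: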

length  TensorShape([])
label   TensorShape([])
frames  TensorShape([Dimension(None)])

I want to use each element of the sequence feature to predict the label, so I have to make the label the same length as frames:

length=tf.expand_dims(length, 0, name='expand_length')
label=tf.expand_dims(label, 0, name='expand_label')
labels=tf.tile(label, length, name='multi_label')
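In plain Python terms, the three lines above turn a scalar label into a vector with one copy per frame. The sketch below mimics that with lists; `tile_label` is a hypothetical helper for illustration, not a TensorFlow function, and the frame values are made up.

```python
def tile_label(label, length):
    """Mimic tf.expand_dims(label, 0) followed by tf.tile(label, expand_dims(length, 0))."""
    expanded = [label]              # like tf.expand_dims(label, 0): shape [] -> [1]
    multiples = [length]            # like tf.expand_dims(length, 0): a 1-D multiples vector
    return expanded * multiples[0]  # like tf.tile: repeat along axis 0


# Hypothetical parsed example: 3 frames (JPEG bytes) with label 7.
frames = [b"frame0", b"frame1", b"frame2"]
labels = tile_label(7, len(frames))
print(labels)  # [7, 7, 7]
```

After tiling, `labels[i]` is the target for `frames[i]`, which is what the per-frame slicing below relies on.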

Now I have the following tensors:

labels   TensorShape([Dimension(None)])
frames   TensorShape([Dimension(None)])

I have to push them into a queue so that I can get one frame and its label at a time:

frame, label=tf.train.slice_input_producer([frames, labels])
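Roughly, tf.train.slice_input_producer builds a queue of indices over `range(tf.shape(tensor_list[0])[0])` and, for each dequeued index `i`, gathers element `i` from every tensor in the list. The pure-Python model below sketches that for one fixed pair of lists; `slice_producer_model` is a hypothetical name, and unlike the real op it runs a single epoch instead of cycling forever.

```python
import random


def slice_producer_model(frames, labels, shuffle=True, seed=0):
    """Rough single-epoch model of tf.train.slice_input_producer([frames, labels])."""
    # Like range_input_producer(tf.shape(frames)[0]): one index per row.
    indices = list(range(len(frames)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for i in indices:
        # Like tf.gather(frames, i) and tf.gather(labels, i) for the dequeued index.
        yield frames[i], labels[i]


pairs = list(slice_producer_model([b"a", b"b", b"c"], [7, 7, 7]))
print(len(pairs))  # 3
```

The model hides one important detail: it assumes `frames` is the same list when the indices are built and when they are used.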

And then batch them and train the network:

frames, labels = tf.train.shuffle_batch([frame, label], 4, 16, 8)
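The three positional numbers passed to tf.train.shuffle_batch are batch_size=4, capacity=16 and min_after_dequeue=8. The toy model below, a hypothetical `shuffle_batches` generator and not TensorFlow code, sketches what those parameters mean: the buffer holds at most `capacity` elements, and while input is still arriving a batch is only dequeued if at least `min_after_dequeue` elements remain afterwards. Once the input is closed, leftover elements are drained, and a final partial batch is dropped.

```python
import random


def shuffle_batches(stream, batch_size=4, capacity=16, min_after_dequeue=8, seed=0):
    """Toy model of tf.train.shuffle_batch(tensors, 4, 16, 8) over a Python iterable."""
    rng = random.Random(seed)
    buffer = []
    it = iter(stream)
    exhausted = False
    while True:
        # Enqueue side: fill the buffer up to capacity.
        while not exhausted and len(buffer) < capacity:
            try:
                buffer.append(next(it))
            except StopIteration:
                exhausted = True  # like the input queue being closed
        # Dequeue side: min_after_dequeue only applies while input is still open.
        threshold = batch_size if exhausted else min_after_dequeue + batch_size
        if len(buffer) < threshold:
            return  # like "closed and has insufficient elements"
        yield [buffer.pop(rng.randrange(len(buffer))) for _ in range(batch_size)]


batches = list(shuffle_batches(range(20)))
print(len(batches))  # 5
```

In this model, as in the log later on, a closed input with fewer than `batch_size` elements left simply ends the stream.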

It should work, but an error occurs in tf.train.slice_input_producer. Here is the error message:

W d:\build\tensorflow\tensorflow_gpu-r0.12\tensorflow\core\framework\op_kernel.cc:975] Invalid argument: indices = 119 is not in [0, 117)
         [[Node: slice_timestep/Gather = Gather[Tindices=DT_INT32, Tparams=DT_STRING, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](parse_one_ex/parse_one_ex:2, slice_timestep/fraction_of_32_full_Dequeue)]]
W d:\build\tensorflow\tensorflow_gpu-r0.12\tensorflow\core\framework\op_kernel.cc:975] Invalid argument: indices = 119 is not in [0, 117)
         [[Node: slice_timestep/Gather = Gather[Tindices=DT_INT32, Tparams=DT_STRING, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](parse_one_ex/parse_one_ex:2, slice_timestep/fraction_of_32_full_Dequeue)]]
I d:\build\tensorflow\tensorflow_gpu-r0.12\tensorflow\stream_executor\dso_loader.cc:128] successfully opened CUDA library cupti64_80.dll locally
W d:\build\tensorflow\tensorflow_gpu-r0.12\tensorflow\core\framework\op_kernel.cc:975] Out of range: RandomShuffleQueue '_3_batch_ex/random_shuffle_queue' is closed and has insufficient elements (requested 4, current size 0)
         [[Node: batch_ex = QueueDequeueMany[_class=["loc:@batch_ex/random_shuffle_queue"], component_types=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](batch_ex/random_shuffle_queue, batch_ex/n)]]
W d:\build\tensorflow\tensorflow_gpu-r0.12\tensorflow\core\framework\op_kernel.cc:975] Out of range: RandomShuffleQueue '_3_batch_ex/random_shuffle_queue' is closed and has insufficient elements (requested 4, current size 0)
         [[Node: batch_ex = QueueDequeueMany[_class=["loc:@batch_ex/random_shuffle_queue"], component_types=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](batch_ex/random_shuffle_queue, batch_ex/n)]]
W d:\build\tensorflow\tensorflow_gpu-r0.12\tensorflow\core\framework\op_kernel.cc:975] Out of range: RandomShuffleQueue '_3_batch_ex/random_shuffle_queue' is closed and has insufficient elements (requested 4, current size 0)
         [[Node: batch_ex = QueueDequeueMany[_class=["loc:@batch_ex/random_shuffle_queue"], component_types=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](batch_ex/random_shuffle_queue, batch_ex/n)]]

The slice_input_producer op is named slice_timestep.

The shuffle_batch op is named batch_ex.
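One reading of this log, offered here as an interpretation and not something the answer below confirms: every evaluation of `frames` re-runs `reader.read()`. The index queue inside slice_input_producer is filled by its own queue runner, which evaluates `tf.shape(frames)[0]` on one record. The Gather op, run by the shuffle_batch queue runner, evaluates `frames` again and so sees a different record. An index valid for a longer record (here 119) is then applied to a shorter one (117 frames), and once that enqueue fails the batch queue is closed, which would explain the later "closed and has insufficient elements" errors. The simulation below uses hypothetical records of 120 and 117 frames; `gather` and `next_record` are illustrative helpers, not TensorFlow APIs.

```python
def gather(frames, index):
    """Like tf.gather(frames, index) with validate_indices=True on the CPU."""
    if not 0 <= index < len(frames):
        raise IndexError(f"indices = {index} is not in [0, {len(frames)})")
    return frames[index]


def next_record(reader):
    """Each evaluation of `frames` pulls a fresh record, as reader.read() does."""
    return next(reader)


# Hypothetical records with different numbers of frames (at least 120, then 117).
reader = iter([[b"f"] * 120, [b"f"] * 117])

# Index-queue runner: computes the range from record A and enqueues 0..119.
record_a = next_record(reader)
index_queue = list(range(len(record_a)))

# Gather op: dequeues index 119 but re-evaluates `frames`, getting record B.
record_b = next_record(reader)
message = None
try:
    gather(record_b, index_queue[119])
except IndexError as err:
    message = str(err)
print(message)  # indices = 119 is not in [0, 117)
```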

Here is the graph as shown in my TensorBoard:

[Image: the whole graph]

[Image: zoomed-in view of part of the graph]

Here is simplified code that reproduces the error:

import matplotlib.pyplot as plt
import tensorflow as tf

context_features = {
    "length": tf.FixedLenFeature([], dtype=tf.int64),
    "label": tf.FixedLenFeature([], dtype=tf.int64)
}
sequence_features = {
    "imgs_list": tf.FixedLenSequenceFeature([], dtype=tf.string),
}


file=tf.train.string_input_producer(['./train.tfrecord'])
reader=tf.TFRecordReader()
_, ex=reader.read(file)

context_parsed, sequence_parsed = tf.parse_single_sequence_example(
    serialized=ex,
    context_features=context_features,
    sequence_features=sequence_features
)
length=tf.cast(context_parsed['length'], tf.int32)
label=tf.cast(context_parsed['label'], tf.int32)
# Repeat the scalar label once per frame: shape [] -> [1] -> [length]
length=tf.expand_dims(length, 0, name='expand_length')
label=tf.expand_dims(label, 0, name='expand_label')
label=tf.tile(label, length)
# Produce one (JPEG bytes, label) pair per dequeue
imcontent, label=tf.train.slice_input_producer([sequence_parsed['imgs_list'], label])
im=tf.image.decode_jpeg(imcontent, 3)
im=tf.image.resize_images(im, [224, 224])
# batch_size=4, capacity=16, min_after_dequeue=8
im, label = tf.train.shuffle_batch([im, label], 4, 16, 8, name='batch_ex')
with tf.Session() as sess:
    tf.train.start_queue_runners(sess)
    fig=plt.figure()
    while True:
        [res, res2]=sess.run([im, label])
        print(res2)

1 answer:

Answer 0 (score: 0)

I have solved it. slice_input_producer seems to be static. I used