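Running tf.train.shuffle_batch

Time: 2018-08-16 15:28:57

Tags: python tensorflow jupyter-notebook

I built my own dataset with TFRecord, wrote functions to encode and decode the data, and then used tf.train.shuffle_batch to produce mini-batches.

The code is as follows: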

def encode_data(train_set, train_label):
    # filepath is assumed to be defined earlier in the notebook.
    recordfilenum = 0
    # No trailing dot, so the name matches the one passed to read_and_decode.
    ftrecordfname = "p20_s80_tfrecord_%03d" % recordfilenum
    writer = tf.python_io.TFRecordWriter(filepath + ftrecordfname)
    for i in range(5000):
        label = train_label[i]
        # Two 128-value slices per row, stored as raw bytes.
        D1 = train_set[i, 2:130]
        D2 = train_set[i, 132:260]
        data1 = D1.tobytes()
        data2 = D2.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(float_list=tf.train.FloatList(value=[label])),
            "data1": tf.train.Feature(bytes_list=tf.train.BytesList(value=[data1])),
            "data2": tf.train.Feature(bytes_list=tf.train.BytesList(value=[data2]))
        }))
        # Write inside the loop so every example is stored, not only the last one.
        writer.write(example.SerializeToString())
    writer.close()

def read_and_decode(fname):
    fname_queue = tf.train.string_input_producer([fname])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(fname_queue)
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.float32),
        'data1': tf.FixedLenFeature([], tf.string),
        'data2': tf.FixedLenFeature([], tf.string)
    })
    # decode_raw already returns float32, so no extra cast is needed.
    # The dtype here must match the dtype of the arrays passed to tobytes().
    data1 = tf.decode_raw(features['data1'], tf.float32)
    print(data1.shape)
    data1 = tf.reshape(data1, [128])
    print(data1.shape)

    data2 = tf.decode_raw(features['data2'], tf.float32)
    print(data2.shape)
    data2 = tf.reshape(data2, [128])  # reshape data2, not data1

    # The parsed label is already a float32 scalar; give it shape [1].
    label = tf.reshape(features['label'], [1])

    return data1, data2, label


sess = tf.InteractiveSession()
encode_data(train_set, train_label)
data1, data2, label = read_and_decode('p20_s80_tfrecord_000')
threads = tf.train.start_queue_runners(sess=sess)
xdata_1, xdata_2, ydata = tf.train.shuffle_batch([data1, data2, label],batch_size=32,capacity=1000,min_after_dequeue=100)

Everything works up to this point. But when I run:

sess.run([xdata_1, xdata_2, ydata])

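the program hangs and does nothing. I am using a Jupyter notebook, and a "*" stays on the left side of the cell. I think something must be wrong with my code, but I can't find it.

Can anyone help me?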
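
One likely cause of the hang is the call order in the listing above: tf.train.start_queue_runners starts only the queue runners that are already registered in the graph when it is called, and tf.train.shuffle_batch registers its own queue runner afterwards, so nothing ever fills the batch queue and the dequeue blocks. The sketch below is a plain-Python analogy of that behaviour, not TensorFlow code; `make_producer`, `start_runners`, `try_dequeue` and `simulate` are hypothetical names introduced only for illustration, and a timeout stands in for the indefinite block so the example terminates.

```python
import queue
import threading


def make_producer(q, items):
    """Return a callable that fills q, playing the role of a queue runner."""
    def produce():
        for item in items:
            q.put(item)
    return produce


def start_runners(runners):
    """Start one daemon thread per runner registered so far."""
    threads = [threading.Thread(target=r, daemon=True) for r in list(runners)]
    for t in threads:
        t.start()
    return threads


def try_dequeue(q, timeout=1.0):
    """Return the next item, or None if nothing arrives before the timeout."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None


def simulate(start_before_batch):
    """Mimic the two possible orderings of starting runners and batching."""
    runners = []

    # Stand-in for the reader queue created in read_and_decode.
    reader_q = queue.Queue()
    runners.append(make_producer(reader_q, range(10)))

    if start_before_batch:
        # Order used in the question: runners are started first,
        # so the batch queue's runner is registered too late.
        start_runners(runners)

    # Stand-in for the queue that shuffle_batch creates.
    batch_q = queue.Queue()
    runners.append(make_producer(batch_q, ["batch"]))

    if not start_before_batch:
        # Reordered: every runner is registered before any is started.
        start_runners(runners)

    return try_dequeue(batch_q)
```

Here `simulate(start_before_batch=True)` returns None because the batch queue is never filled, while `simulate(start_before_batch=False)` returns the queued item. If this is indeed the cause, the corresponding change in the notebook would be to call tf.train.shuffle_batch before tf.train.start_queue_runners.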

0 answers:

No answers