我使用TFrecord
制作了自己的数据集,对数据进行了编码和解码,然后使用tf.train.shuffle_batch
进行了小批量生产。
代码如下:
def encode_data(train_set, train_label):
recordfilenum = 0
ftrecordfname = ("p20_s80_tfrecord_%03d." % recordfilenum)
writer = tf.python_io.TFRecordWriter(filepath+ftrecordfname)
for i in range(5000):
label = train_label[i]
D1 = train_set[i,2:130]
D2 = train_set[i,132:260]
data1 = D1.tobytes()
data2 = D2.tobytes()
example = tf.train.Example(features = tf.train.Features(feature = {
"label": tf.train.Feature(float_list = tf.train.FloatList(value=[label])),
"data1": tf.train.Feature(bytes_list=tf.train.BytesList(value=[data1])),
"data2": tf.train.Feature(bytes_list=tf.train.BytesList(value=[data2]))
}))
writer.write(example.SerializeToString())
writer.close()
def read_and_decode(fname):
fname_queue = tf.train.string_input_producer([fname])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(fname_queue)
features = tf.parse_single_example(serialized_example, features={
'label': tf.FixedLenFeature([], tf.float32),
'data1': tf.FixedLenFeature([], tf.string),
'data2': tf.FixedLenFeature([], tf.string)
})
data1 = tf.decode_raw(features['data1'], tf.float32)
data1 = tf.cast(data1, tf.float32)
print(data1.shape)
data1 = tf.reshape(data1, [128])
print(data1.shape)
data2 = tf.decode_raw(features['data2'], tf.float32)
data2 = tf.cast(data2, tf.float32)
print(data2.shape)
data2 = tf.reshape(data1, [128])
label = features['label']
label = tf.cast(features['label'], tf.float32)
label = tf.reshape(label, [1])
return data1, data2, label
sess = tf.InteractiveSession()
encode_data(train_set, train_label)
data1, data2, label = read_and_decode('p20_s80_tfrecord_000')
threads = tf.train.start_queue_runners(sess=sess)
xdata_1, xdata_2, ydata = tf.train.shuffle_batch([data1, data2, label],batch_size=32,capacity=1000,min_after_dequeue=100)
直到现在还好。但是当我跑步时:
sess.run([xdata_1, xdata_2, ydata])
程序被卡住,什么也不做。当我使用Jupyter笔记本时,单元格的左侧有一个"*"
。
我认为我的代码肯定有问题,但是我找不到它。
有人会帮助我吗?