TFLearn:从tfrecords文件

时间:2017-10-03 14:47:04

标签: tflearn

我已经设法使用这篇文章从一个非常大的csv创建了一个.tfrecords文件:https://indico.io/blog/tensorflow-data-inputs-part1-placeholders-protobufs-queues/。我想将其分批提供给tflearn.DNN模型。

模型示例:

net = tflearn.input_data(shape=[None, 510])
net = tflearn.fully_connected(net, 1020)
net = tflearn.fully_connected(net, 1020)
net = tflearn.fully_connected(net, 2, activation='softmax')
net = tflearn.regression(net)

model = tflearn.DNN(net)

从文件中检索数据:

def read_and_decode_single_example(filename):
    filename_queue = tf.train.string_input_producer([filename],
                                                    num_epochs=None)

    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'features': tf.FixedLenFeature([510], tf.float32)
        })

    _label = features['label']
    _features = features['features']
    return _label, _features

有没有办法将此文件提供给model.fit函数?

labels, features = read_and_decode_single_example('data.tfrecords')
model.fit(features, labels)

该示例显示标签和功能应首先在会话中初始化,并将其与TensorFlow一起使用,但我不确定将其与TFLearn放在何处。

sess = tf.Session()

init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)

0 个答案:

没有答案