Tensorflow:使用Queue for CSV文件和自定义Estimator以及“input_fn”函数

时间:2017-08-28 09:08:48

标签: python-3.x csv machine-learning tensorflow

我在很长一段时间(很多小时)搜索了我的问题的正确答案而没有结果,所以我在这里。我想我错过了一些明显的东西,但我不知道是什么......

问题:使用队列读取CSV文件并使用input_fn训练Estimator,而不是每次都重新加载图表(这非常慢)。

我创建了一个自定义模型,它给了我一个model_fn函数来创建我自己的估算器:

tf.estimator.Estimator(model_fn=model_fn, params=model_params)

之后,我需要读取一个非常大的CSV文件(无法在内存中加载),所以我决定使用Queue(似乎是最好的解决方案):

nb_features = 10
queue = tf.train.string_input_producer(["test.csv"],
                                       shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(queue)

record_defaults = [[0] for _ in range(nb_features+1)]
cols = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.stack(cols[0:len(cols)-1]) # Take all columns without the last
label = tf.stack(cols[len(cols)-1]) # Take last column

我认为这段代码还可以。

然后,主要代码:

with tf.Session() as sess:
    tf.logging.set_verbosity(tf.logging.INFO)
    sess.run(tf.global_variables_initializer())

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    # Return a Tensor of 1000 features/labels
    def get_inputs():
        print("input call !")
        xs = []
        ys = []
        for i in range(1000):
            x, y = sess.run([features, label])
            xs.append(x)
            ys.append(y)
        return tf.constant(np.asarray(xs), dtype=tf.float32), tf.constant(np.asarray(ys))

    estimator.train(input_fn=get_inputs,
                   steps=100)

    coord.request_stop()
    coord.join(threads)

正如你所看到的,这里有很多丑陋的东西......

我想要的是什么:我希望列车功能在每个步骤中使用一批新功能。但是在这里,它在100个步骤中使用相同批次的1000个功能,因为get_inputs函数只是在我们开始训练时调用。有一个简单的方法吗?

我尝试使用 step = 1 循环 estimator.train ,但这会每次重新加载图形并变得非常慢。

我现在不知道该怎么办,也不知道是否可能......

感谢您的帮助!

3 个答案:

答案 0 :(得分:1)

简短版本:将您的CSV文件转换为tfrecords,然后使用tf.contrib.data.TFRecordDataset。长版:请参阅代码查看问题/接受的答案here(为方便起见,在下面复制)。

查看tf.contrib.data.Dataset API。我怀疑你最好将CSV转换为TfRecord文件并使用TfRecordDataset。这里有一个完整的教程。

步骤1:将csv数据转换为tfrecords数据。下面的示例代码。

import tensorflow as tf


def read_csv(filename):
    with open(filename, 'r') as f:
        out = [line.rstrip().split(',') for line in f.readlines()]
    return out


csv = read_csv('data.csv')
with tf.python_io.TFRecordWriter("data.tfrecords") as writer:
    for row in csv:
        features, label = row[:-1], row[-1]
        features = [float(f) for f in features]
        label = int(label)
        example = tf.train.Example()
        example.features.feature[
            "features"].float_list.value.extend(features)
        example.features.feature[
            "label"].int64_list.value.append(label)
        writer.write(example.SerializeToString())

这假定标签是最后一列中的整数,前面的列中包含浮点要素。这只需要运行一次。

步骤2:编写解码这些记录文件的数据集。

def parse_function(example_proto):
    features = {
        'features': tf.FixedLenFeature((n_features,), tf.float32),
        'label': tf.FixedLenFeature((), tf.int64)
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features['features'], parsed_features['label']


def input_fn():
    dataset = tf.contrib.data.TFRecordDataset(['data.tfrecords'])
    dataset = dataset.map(parse_function)
    dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.repeat()  # repeat indefinitely
    dataset = dataset.batch(batch_size)
    print(dataset.output_shapes)
    features, label = dataset.make_one_shot_iterator().get_next()
    return features, label

进行测试(独立于估算器):

batch_size = 4
shuffle_size = 10000
features, labels = input_fn()
with tf.Session() as sess:
    f_data, l_data = sess.run([features, labels])
print(f_data, l_data)

与tf.estimator.Estimator一起使用:

estimator.train(input_fn, max_steps=1e7)

答案 1 :(得分:0)

因此,在测试了很多可能性之后,我找到了纯CSV的解决方案,但是在某些条件下工作,这就是我再次需要你帮助的原因!

我们来看看代码:

filename = "test.csv"

queue = tf.train.string_input_producer([filename],
                                   num_epochs=1,
                                   shuffle=False)
reader = tf.TextLineReader()
_, csv_row = reader.read(queue)
record_defaults = [[0] for _ in range(341)]
cols = tf.decode_csv(csv_row, record_defaults=record_defaults)
features = tf.stack(cols[0:340])
label = tf.stack(cols[340])

# WORKS ------
min_after_dequeue = 1000
capacity = min_after_dequeue + 3 * 4
example_batch, label_batch = tf.train.batch(
  [features, label], batch_size=4, capacity=capacity)
# ------------

# DOESN'T WORK
#def input_fn():
#    min_after_dequeue = 1000
#    capacity = min_after_dequeue + 3 * 4
#    example_batch, label_batch = tf.train.batch(
#      [features, label], batch_size=4, capacity=capacity)
#    return example_batch, label_batch
# ------------

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  tf.local_variables_initializer().run()

  # start populating filename queue
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord, sess=sess)

  cpt = 0
  while True:
    try:
      cpt += 1
      print(cpt, end=' ')
      #_, l = input_fn()
      #print(l.eval())
      print(label_batch.eval())
    except tf.errors.OutOfRangeError:
      break;

  coord.request_stop()
  coord.join(threads)

(我使用的CSV文件包含341列,340个功能和1个标签)

正如您所看到的,代码非常难看,但我们现在可以直接读取CSV。不幸的是,我们需要一个“input_fn”函数来与Estimator一起使用,这就是我尝试将批量创建放在“input_fn”中的原因。但是当运行这段代码时,当Tensorflow尝试这一行时,一切都会冻结(并阻止你漂亮的shell):

print(l.eval())

所以,如果有人知道为什么一切都停止了,请帮助我!

感谢。

答案 2 :(得分:0)

如果您担心tf.train.start_queue_runners未被调用,请尝试以下操作:

class ThreadStartHook(tf.train.SessionRunHook):
    def after_create_session(self, session, coord):
        self.coord = coord
        self.threads = tf.train.start_queue_runners(coord=coord, sess=session)

    def end(self, session):
        self.coord.request_stop()
        self.coord.join(self.threads)


estimator.train(input_fn, [ThreadStartHook()])

我开始时有类似的想法,但发现它没有必要。