我在很长一段时间(很多小时)搜索了我的问题的正确答案而没有结果,所以我在这里。我想我错过了一些明显的东西,但我不知道是什么......
问题:使用队列读取CSV文件并使用input_fn训练Estimator,而不是每次都重新加载图表(这非常慢)。
我创建了一个自定义模型,它给了我一个model_fn函数来创建我自己的估算器:
tf.estimator.Estimator(model_fn=model_fn, params=model_params)
之后,我需要读取一个非常大的CSV文件(无法在内存中加载),所以我决定使用Queue(似乎是最好的解决方案):
nb_features = 10
queue = tf.train.string_input_producer(["test.csv"],
shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(queue)
record_defaults = [[0] for _ in range(nb_features+1)]
cols = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.stack(cols[0:len(cols)-1]) # Take all columns without the last
label = tf.stack(cols[len(cols)-1]) # Take last column
我认为这段代码还可以。
然后,主要代码:
with tf.Session() as sess:
tf.logging.set_verbosity(tf.logging.INFO)
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# Return a Tensor of 1000 features/labels
def get_inputs():
print("input call !")
xs = []
ys = []
for i in range(1000):
x, y = sess.run([features, label])
xs.append(x)
ys.append(y)
return tf.constant(np.asarray(xs), dtype=tf.float32), tf.constant(np.asarray(ys))
estimator.train(input_fn=get_inputs,
steps=100)
coord.request_stop()
coord.join(threads)
正如你所看到的,这里有很多丑陋的东西......
我想要的是什么:我希望列车功能在每个步骤中使用一批新功能。但是在这里,它在100个步骤中使用相同批次的1000个功能,因为get_inputs函数只是在我们开始训练时调用。有一个简单的方法吗?
我尝试使用 step = 1 循环 estimator.train ,但这会每次重新加载图形并变得非常慢。
我现在不知道该怎么办,也不知道是否可能......
感谢您的帮助!
答案 0 :(得分:1)
简短版本:将您的CSV文件转换为tfrecords
,然后使用tf.contrib.data.TFRecordDataset
。长版:请参阅代码查看问题/接受的答案here(为方便起见,在下面复制)。
查看tf.contrib.data.Dataset API。我怀疑你最好将CSV转换为TfRecord文件并使用TfRecordDataset。这里有一个完整的教程。
步骤1:将csv数据转换为tfrecords数据。下面的示例代码。
import tensorflow as tf
def read_csv(filename):
with open(filename, 'r') as f:
out = [line.rstrip().split(',') for line in f.readlines()]
return out
csv = read_csv('data.csv')
with tf.python_io.TFRecordWriter("data.tfrecords") as writer:
for row in csv:
features, label = row[:-1], row[-1]
features = [float(f) for f in features]
label = int(label)
example = tf.train.Example()
example.features.feature[
"features"].float_list.value.extend(features)
example.features.feature[
"label"].int64_list.value.append(label)
writer.write(example.SerializeToString())
这假定标签是最后一列中的整数,前面的列中包含浮点要素。这只需要运行一次。
步骤2:编写解码这些记录文件的数据集。
def parse_function(example_proto):
features = {
'features': tf.FixedLenFeature((n_features,), tf.float32),
'label': tf.FixedLenFeature((), tf.int64)
}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features['features'], parsed_features['label']
def input_fn():
dataset = tf.contrib.data.TFRecordDataset(['data.tfrecords'])
dataset = dataset.map(parse_function)
dataset = dataset.shuffle(shuffle_size)
dataset = dataset.repeat() # repeat indefinitely
dataset = dataset.batch(batch_size)
print(dataset.output_shapes)
features, label = dataset.make_one_shot_iterator().get_next()
return features, label
进行测试(独立于估算器):
batch_size = 4
shuffle_size = 10000
features, labels = input_fn()
with tf.Session() as sess:
f_data, l_data = sess.run([features, labels])
print(f_data, l_data)
与tf.estimator.Estimator一起使用:
estimator.train(input_fn, max_steps=1e7)
答案 1 :(得分:0)
因此,在测试了很多可能性之后,我找到了纯CSV的解决方案,但是在某些条件下工作,这就是我再次需要你帮助的原因!
我们来看看代码:
filename = "test.csv"
queue = tf.train.string_input_producer([filename],
num_epochs=1,
shuffle=False)
reader = tf.TextLineReader()
_, csv_row = reader.read(queue)
record_defaults = [[0] for _ in range(341)]
cols = tf.decode_csv(csv_row, record_defaults=record_defaults)
features = tf.stack(cols[0:340])
label = tf.stack(cols[340])
# WORKS ------
min_after_dequeue = 1000
capacity = min_after_dequeue + 3 * 4
example_batch, label_batch = tf.train.batch(
[features, label], batch_size=4, capacity=capacity)
# ------------
# DOESN'T WORK
#def input_fn():
# min_after_dequeue = 1000
# capacity = min_after_dequeue + 3 * 4
# example_batch, label_batch = tf.train.batch(
# [features, label], batch_size=4, capacity=capacity)
# return example_batch, label_batch
# ------------
with tf.Session() as sess:
tf.global_variables_initializer().run()
tf.local_variables_initializer().run()
# start populating filename queue
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord, sess=sess)
cpt = 0
while True:
try:
cpt += 1
print(cpt, end=' ')
#_, l = input_fn()
#print(l.eval())
print(label_batch.eval())
except tf.errors.OutOfRangeError:
break;
coord.request_stop()
coord.join(threads)
(我使用的CSV文件包含341列,340个功能和1个标签)
正如您所看到的,代码非常难看,但我们现在可以直接读取CSV。不幸的是,我们需要一个“input_fn”函数来与Estimator一起使用,这就是我尝试将批量创建放在“input_fn”中的原因。但是当运行这段代码时,当Tensorflow尝试这一行时,一切都会冻结(并阻止你漂亮的shell):
print(l.eval())
所以,如果有人知道为什么一切都停止了,请帮助我!
感谢。
答案 2 :(得分:0)
如果您担心tf.train.start_queue_runners
未被调用,请尝试以下操作:
class ThreadStartHook(tf.train.SessionRunHook):
def after_create_session(self, session, coord):
self.coord = coord
self.threads = tf.train.start_queue_runners(coord=coord, sess=session)
def end(self, session):
self.coord.request_stop()
self.coord.join(self.threads)
estimator.train(input_fn, [ThreadStartHook()])
我开始时有类似的想法,但发现它没有必要。