我的问题是,如果我想为我的数据创建一个tfrecords文件,完成它需要大约15天,它有500000对模板,每个模板是32帧(图像)。为了节省时间,我有3个GPU,所以我想我可以在一个GPU上创建三个tfrocords文件,然后我可以在5天内完成创建tfrecords。但后来我搜索了一种方法将这三个文件合并到一个文件中,但找不到合适的解决方案。
那么有没有办法将这三个文件合并到一个文件中,或者有什么方法可以通过提供从三个tfrecords文件中提取的批量示例来训练我的网络,知道我正在使用数据集API。
答案 0 :(得分:5)
两个月前问了这个问题,我想您已经找到了解决方案。对于以下情况,答案是否定的,您无需创建单个HUGE tfrecord文件。只需使用新的DataSet API:
dataset = tf.data.TFRecordDataset(filenames_to_read,
compression_type=None, # or 'GZIP', 'ZLIB' if compress you data.
buffer_size=10240, # any buffer size you want or 0 means no buffering
num_parallel_reads=os.cpu_count() # or 0 means sequentially reading
)
# Maybe you want to prefetch some data first.
dataset = dataset.prefetch(buffer_size=batch_size)
# Decode the example
dataset = dataset.map(single_example_parser, num_parallel_calls=os.cpu_count())
dataset = dataset.shuffle(buffer_size=number_larger_than_batch_size)
dataset = dataset.batch(batch_size).repeat(num_epochs)
...
有关详细信息,请检查document。
答案 1 :(得分:0)
为希望合并多个heroku
.post("/apps",{name:'example'})
.then(app => {
let name = app.name;
heroku.post("/apps/" + app.name + "/slugs", {
body: {
process_types: {
web: "node-v0.10.20-linux-x64/bin/node web.js"
}
}
})
.then( async app => {
let id = app.id;
await new Promise((x) => {
require('fs').readFile(__dirname + "/slug.tgz", function (err, data) {
if (err) return console.error(err);
request({
url : app.blob.url,
body : data,
method: "PUT"
}, function (err, message,data) {
if (err) return console.error(err);
console.log(data)
x()
});
})
});
heroku
.post("/apps/" + name + "/releases", {body:{ "slug": app.id }})
.then(app => {
console.log(app);
}).catch(console.log)
})
})
文件的任何人直接解决问题标题:
最方便的方法是使用tf.Data API: (adapting an example from the docs)
.tfrecord
但是,正如holmescn所指出的,最好将.tfrecord文件保留为单独的文件,然后将它们作为单个tensorflow数据集一起读取。
您也可以参考a longer discussion regarding multiple .tfrecord
files on Data Science Stackexchange
答案 2 :(得分:0)
MoltenMuffins的答案适用于更高版本的tensorflow。但是,如果您使用的是较低版本,则必须遍历三个tfrecords并将它们保存到新的记录文件中,如下所示。这适用于tf版本1.0及更高版本。
def comb_tfrecord(tfrecords_path, save_path, batch_size=128):
with tf.Graph().as_default(), tf.Session() as sess:
ds = tf.data.TFRecordDataset(tfrecords_path).batch(batch_size)
batch = ds.make_one_shot_iterator().get_next()
writer = tf.python_io.TFRecordWriter(save_path)
while True:
try:
records = sess.run(batch)
for record in records:
writer.write(record)
except tf.errors.OutOfRangeError:
break
答案 3 :(得分:0)
自定义上面的脚本以获得更好的tfrecords列表
import os
import glob
import tensorflow as tf
save_path = 'data/tf_serving_warmup_requests'
tfrecords_path = glob.glob('data/*.tfrecords')
dataset = tf.data.TFRecordDataset(tfrecords_path)
writer = tf.data.experimental.TFRecordWriter(save_path)
writer.write(dataset)