当 `tf.data.Dataset` 中的底层“数据样本”是扁平数组时,是否有推荐的/高效的方法将该 `Dataset` 转换为 `Tensor`?
我正在使用 `tf.data.csv` 读取和解析 CSV,但希望使用 TensorFlow.js Core API 将数据作为 `tf.Tensor` 来处理。
答案 0 :(得分:0)
`tf.data.Dataset.iterator()` 返回一个 Promise,该 Promise 解析为一个迭代器。
// Pull the first few elements out of `flattenedDataset` and concatenate them
// into a single tensor along axis 0.
// Assumes `flattenedDataset` is batched (e.g. `.batch(1)`), so each element is
// a tensor that tf.concat can accept — TODO confirm for your pipeline.
const it = await flattenedDataset.iterator();
const t = [];
// Read only the first 5 elements: there is no need to read all of the data at
// once, since materializing the whole dataset would consume a lot of memory.
for (let i = 0; i < 5; i++) {
  const e = await it.next();
  // Stop early if the dataset has fewer than 5 elements.
  if (e.done) {
    break;
  }
  // Push the element itself (not spread): tf.concat expects an array of tensors.
  t.push(e.value);
}
const firstRows = tf.concat(t, 0);
// Wrap the dataset iterator `it` in a standard async iterable so it can be
// consumed with `for await...of`. At most 5 elements are yielded, so the whole
// dataset is never read into memory at once.
const asyncIterable = {
  [Symbol.asyncIterator]() {
    return {
      i: 0,
      async next() {
        if (this.i < 5) {
          this.i++;
          const e = await it.next();
          // Propagate exhaustion of the underlying iterator instead of
          // yielding `undefined` values past its end.
          if (e.done) {
            return { value: undefined, done: true };
          }
          return { value: e.value, done: false };
        }
        return { value: undefined, done: true };
      },
    };
  },
};
const t = [];
// Every yielded value is a real dataset element, so no falsy-filter is needed.
for await (const e of asyncIterable) {
  t.push(e);
}
const csvUrl =
  'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';

// Read the first 5 rows of the CSV as batched tensors and concatenate them
// into one tensor, logging its shape. Relies on the global `tf` provided by
// the TensorFlow.js <script> tag.
(async function run() {
  // We want to predict the column "medv", which represents a median value of
  // a home (in $1000s), so we mark it as a label.
  const csvDataset = tf.data.csv(csvUrl, {
    columnConfigs: {
      medv: {
        isLabel: true,
      },
    },
  });

  // Number of features is the number of column names minus one for the label
  // column. (Not used below; kept to show how to derive the feature count.)
  const numOfFeatures = (await csvDataset.columnNames()).length - 1;

  // Prepare the Dataset: flatten each [features, label] pair into one array,
  // then batch with size 1 so elements can be concatenated along axis 0.
  const flattenedDataset = csvDataset
    .map(([rawFeatures, rawLabel]) =>
      // Convert rows from object form (keyed by column name) to array form.
      [...Object.values(rawFeatures), ...Object.values(rawLabel)])
    .batch(1);

  const it = await flattenedDataset.iterator();

  // Async iterable over at most the first 5 dataset elements.
  const asyncIterable = {
    [Symbol.asyncIterator]() {
      return {
        i: 0,
        async next() {
          if (this.i < 5) {
            this.i++;
            const e = await it.next();
            // Stop when the underlying dataset runs out of elements.
            if (e.done) {
              return { value: undefined, done: true };
            }
            return { value: e.value, done: false };
          }
          return { value: undefined, done: true };
        },
      };
    },
  };

  const t = [];
  for await (const e of asyncIterable) {
    t.push(e);
  }

  const combined = tf.concat(t, 0);
  console.log(combined.shape);
  // Tensors are not garbage-collected automatically; release them explicitly.
  tf.dispose([...t, combined]);
})().catch((err) => {
  // Surface network/parse failures instead of leaving an unhandled rejection.
  console.error(err);
});
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<!-- Load TensorFlow.js. This defines the global `tf` that the example code
     uses, so the example must run after this script has loaded.
     Version is pinned; the examples' `[features, label]` destructuring of
     tf.data.csv output presumably depends on it — verify before upgrading. -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.1"> </script>
</head>
<body>
</body>
</html>
答案 1 :(得分:0)
请注意,通常不建议使用此工作流程:它会把所有数据都加载到 JavaScript 主内存中,对于大型 CSV 数据集可能行不通。
您可以使用 `tf.data.Dataset` 对象的 `toArray()` 方法。例如:
const csvUrl =
  'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';

// Load the whole CSV as an array of batches (4 rows per batch) with toArray().
// Wrapped in an async function because top-level `await` is a syntax error in
// a classic (non-module) <script>, which is how TensorFlow.js is loaded here.
(async function run() {
  const csvDataset = tf.data.csv(csvUrl, {
    columnConfigs: {
      medv: {
        isLabel: true,
      },
    },
  }).batch(4);

  // toArray() reads the entire dataset into memory; only suitable for small CSVs.
  const tensors = await csvDataset.toArray();
  // Number of batches read.
  console.log(tensors.length);
  // First part of the first batch — presumably the features (per-column
  // values keyed by column name), with the "medv" label second; verify
  // against the tfjs version in use.
  console.log(tensors[0][0]);
})().catch((err) => {
  // Surface network/parse failures instead of leaving an unhandled rejection.
  console.error(err);
});