Tensorflow.js数据集到Tensor?

时间:2019-03-02 04:50:18

标签: tensorflow.js

当tf.data.Dataset中的底层“数据样本”是扁平数组时,是否存在一种推荐的/高效的方法将Dataset转换为Tensor?

我正在使用tf.data.csv读取和解析CSV,但是想使用Tensorflow.js Core API将数据作为tf.Tensors处理。

2 个答案:

答案 0 :(得分:0)

tf.data.Dataset.iterator()返回一个解析为迭代器的Promise,因此需要先await它,然后逐个调用next()读取元素。

// Dataset.iterator() returns a Promise of an iterator, so it must be awaited.
// Read only the first 5 elements: reading all of the data at once could
// consume a lot of memory.
const it = await flattenedDataset.iterator();
const t = [];
for (let i = 0; i < 5; i++) {
  const e = await it.next();
  // Stop early if the dataset has fewer than 5 elements; an exhausted
  // iterator reports done: true and carries no element.
  if (e.done) {
    break;
  }
  // Assumes flattenedDataset is batched (e.g. `.batch(1)`) so each element is
  // already a tf.Tensor -- push it whole; a Tensor cannot be spread with `...`.
  t.push(e.value);
}
// tf.concat needs at least one tensor, so guard against an empty dataset.
const data = t.length > 0 ? tf.concat(t, 0) : null;

也可以用for await...of循环来读取:

// Wraps the dataset iterator `it` in an async iterable that yields at most
// 5 elements, so it can be consumed with `for await...of`.
const asyncIterable = {
  [Symbol.asyncIterator]() {
    return {
      i: 0,
      async next() {
        if (this.i >= 5) {
          return { value: undefined, done: true };
        }
        this.i++;
        const e = await it.next();
        // Propagate exhaustion of the underlying iterator instead of
        // reporting done: false with an empty value.
        if (e.done) {
          return { value: undefined, done: true };
        }
        return { value: e.value, done: false };
      },
    };
  },
};

const t = [];
for await (const e of asyncIterable) {
  t.push(e);
}

const csvUrl =
  'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';

// Reads the Boston-housing CSV, flattens each row into one array (feature
// values followed by the label value), batches rows one at a time, and logs
// the shape of the first (up to) 5 batches concatenated along axis 0.
(async function run() {
  // We want to predict the column "medv", which represents a median value of
  // a home (in $1000s), so we mark it as a label.
  const csvDataset = tf.data.csv(csvUrl, {
    columnConfigs: {
      medv: {
        isLabel: true,
      },
    },
  });

  // Convert rows from object form (keyed by column name) to array form, then
  // group them into batches of one row each.
  const flattenedDataset = csvDataset
    .map(([rawFeatures, rawLabel]) => [
      ...Object.values(rawFeatures),
      ...Object.values(rawLabel),
    ])
    .batch(1);

  const it = await flattenedDataset.iterator();

  // Only read a few batches: draining the whole dataset would hold all of it
  // in memory at once.
  const maxBatches = 5;
  const asyncIterable = {
    [Symbol.asyncIterator]() {
      return {
        i: 0,
        async next() {
          if (this.i >= maxBatches) {
            return { value: undefined, done: true };
          }
          this.i++;
          const e = await it.next();
          // Stop as soon as the dataset is exhausted rather than yielding
          // empty values with done: false.
          if (e.done) {
            return { value: undefined, done: true };
          }
          return { value: e.value, done: false };
        },
      };
    },
  };

  const t = [];
  for await (const e of asyncIterable) {
    t.push(e);
  }

  // tf.concat needs at least one tensor, so bail out if nothing was read.
  if (t.length === 0) {
    console.log('Dataset is empty');
    return;
  }

  const combined = tf.concat(t, 0);
  console.log(combined.shape);

  // TensorFlow.js tensors (notably on the WebGL backend) must be released
  // explicitly once they are no longer needed.
  tf.dispose([...t, combined]);
})().catch((err) => {
  // Surface failures (network, CSV parsing, shape mismatch) instead of
  // leaving an unhandled promise rejection.
  console.error(err);
});
<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8">
    <!-- Load TensorFlow.js 0.14.1 from the jsDelivr CDN; the script uses the
         global `tf` that this bundle provides. -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.1"> </script>
  </head>

  <body>
    <!-- Intentionally empty: the example writes its output with console.log. -->
  </body>
</html>

答案 1 :(得分:0)

请注意,通常不建议使用这种工作流程:它会把所有数据一次性加载到JavaScript主内存中,对于大型CSV数据集可能行不通。

您可以使用tf.data.Dataset对象的toArray()方法。例如:

  // Boston-housing training data; "medv" is flagged as the label column.
  const csvUrl =
    'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';

  const labelConfig = { medv: { isLabel: true } };
  const csvDataset = tf.data
    .csv(csvUrl, { columnConfigs: labelConfig })
    .batch(4);

  // toArray() pulls every batch of 4 rows into memory at once; presumably
  // acceptable only because this demo CSV is small.
  const batches = await csvDataset.toArray();
  const [firstBatch] = batches;
  console.log(batches.length);
  console.log(firstBatch[0]);