标记TFRecords数据集

时间:2019-10-04 20:26:57

标签: python numpy tensorflow keras

我以前将5GB的json数据保存为多个tfrecords文件,并使用以下内容将其读回

import glob
train_files = [file for file in glob.glob("train/*.tfrecords")
test_files = [file for file in glob.glob("test/*.tfrecords")

raw_train = tf.data.TFRecordDataset(train_files)
raw_test = tf.data.TFRecordDataset(test_files)

在解析后打印(为了清楚起见,省略了此内容)

for record in parsed_train.take(1):
  print(record)
{
  'sentiment': <tf.Tensor: id=52, shape=(), dtype=int64, numpy=1>,
  'text': <tf.Tensor: id=53, shape=(), dtype=string, numpy=b'Great food!'>
}

我正在关注一个教程,在该教程中,我需要向Keras Tokenizer提供来自测试和训练数据集的字符串列表。我试图做的是创建一个空列表并附加每条记录。

from tensorflow.python.keras.preprocessing.text import Tokenizer

num_words = 10000
tokenizer = Tokenizer(num_words=num_words)

l = []
for train_record in parsed_train.take(-1):
  l.append(train_record['text'].numpy().decode('utf-8'))

for test_record in parsed_test.take(-1):
  l.append(test_record['text'].numpy().decode('utf-8'))

但是,这花费了可笑的时间,并且经常使我的4GB ram笔记本电脑崩溃。有没有更好的方法来执行此任务-例如在不使用列表格式的情况下标记TF张量?我应该将字符串附加在numpy数组中吗?

0 个答案:

没有答案