How do I access/process the contents of a TensorFlow dataset?

Time: 2019-05-26 14:12:51

Tags: python tensorflow tensorflow-datasets

I am using the cnn_dailymail dataset from TensorFlow Datasets. I access it as follows:

import tensorflow_datasets as tfds
data, info = tfds.load('cnn_dailymail', with_info=True)
train_data, test_data = data['train'], data['test']

To extract one example from the dataset, I use:

cnn_ex, = train_data.take(1)
cnn_ex['highlights'].numpy()

This returns a string like the following: "emma monaghan, 27, from glasgow, used to weigh 18st 5lbs ."

I want to apply some preprocessing steps to this dataset so that it can be used as input to a deep learning algorithm. After preprocessing, the example above should look like this: "<start> emma monaghan, 27, from glasgow, used to weigh 18st 5lbs . <end>"

Is there a way to access and preprocess all of the text (in train_data) at once, without applying the take() function multiple times? For example, converting the TensorFlow dataset to a numpy array would already help. Thanks!

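2 answers:

Answer 0: (score: 1)

It depends on your specific goal. Perhaps tfds.as_numpy() is what you want. You can apply it to train_data to get a generator object. You can iterate over that generator directly, or first apply any map function to the dataset: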

train_data = train_data.map(map_func)  # map_func is your own preprocessing function
for i in tfds.as_numpy(train_data):
    print(i)
    ...
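
If the preprocessing is easier to write in plain Python than as a TensorFlow op, it can also be applied to the elements that the generator yields. The following is only a minimal sketch: the helper names add_markers and preprocess_all are hypothetical, and it assumes each yielded element is a dict whose 'highlights' value is a bytes object, as the question's output suggests. A small in-memory list stands in for tfds.as_numpy(train_data) so that the snippet runs without downloading the dataset.

```python
def add_markers(highlight, start="<start>", end="<end>"):
    """Wrap one highlight in start/end tokens (hypothetical helper)."""
    # Assumption: string features arrive as bytes; decode them to str first.
    if isinstance(highlight, bytes):
        highlight = highlight.decode("utf-8")
    return f"{start} {highlight} {end}"


def preprocess_all(examples):
    """Apply add_markers to the 'highlights' field of every example."""
    return [add_markers(example["highlights"]) for example in examples]


# Stand-in for tfds.as_numpy(train_data): any iterable of dicts works here.
sample = [
    {"highlights": b"emma monaghan, 27, from glasgow, used to weigh 18st 5lbs ."},
]
processed = preprocess_all(sample)
print(processed[0])
```

With the real dataset, the call would presumably be preprocess_all(tfds.as_numpy(train_data)). Note that this collects every processed string in memory, which may be impractical for the full training split.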

Answer 1: (score: 1)

You can use dataset.map() to apply a transformation to the data. For example:

import tensorflow as tf
import tensorflow_datasets as tfds

data, info = tfds.load('cnn_dailymail', with_info=True)
dataset_train, dataset_test = data['train'], data['test']

def map_fn(x, start=tf.constant('<start>'), end=tf.constant('<end>')):
    strings = [start, x['highlights'], end]
    x['highlights'] = tf.strings.join(strings, separator=' ')
    return x

dataset_train = dataset_train.map(map_fn)  # <-- apply the transformation to the whole dataset
elem, = dataset_train.take(1)
print(elem['highlights'].numpy())
# b'<start> arthur potts dawson: british ... <end>'