使用张量流数据集训练表格数据

时间:2021-03-09 19:00:53

标签: python keras out-of-memory tensorflow-datasets

我有来自三个来源的表格数据,并想使用它们训练多输入神经网络(我使用 tensorflow 和 keras 包)。我能够将整个数据加载到内存中并执行预处理步骤,但是,当我开始训练网络时,出现内存错误。

您有什么建议?我看到有些人推荐使用数据加载器,但是,所有示例都是针对图像数据的,我找不到任何表格数据示例。如果您能帮助我使用表格数据的数据加载器实现训练循环,那就太好了。

到目前为止,我已经学会了如何将表格数据转换为 tf.data.Dataset 为:

import numpy as np
import pandas as pd
import tensorflow as tf

training_df: pd.DataFrame = pd.DataFrame(
    data={
        'feature1': np.random.rand(10),
        'feature2': np.random.rand(10),
        'feature3': np.random.rand(10),
        'target': np.random.randint(0, 3, 10)
    }
)
features = ['feature1', 'feature2', 'feature3']
print(training_df)

training_dataset = (
    tf.data.Dataset.from_tensor_slices(
        (
            tf.cast(training_df[features].values, tf.float32),
            tf.cast(training_df['target'].values, tf.int32)
        )
    )
)

for features_tensor, target_tensor in training_dataset:
    print(f'features:{features_tensor} target:{target_tensor}')

感谢任何帮助。

0 个答案:

没有答案