使用tensorflow的dataset.map()方法对类进行存储桶化

时间:2019-05-10 16:22:49

标签: python tensorflow tensorflow-datasets

我需要对数据集的类进行存储桶。但是,我正在处理一个大数据问题,因此我认为使用tensorflow数据集类以避免将整个数据集加载到内存中将是一个好方法。

我正在尝试将类列中的离散数字划分为一组垃圾箱,以便将其输入模型中。是否可以避免使用tf.cond()?我当时在考虑使用tf.feature_columns(),但是我不确定是否可以在类而不是功能中执行此类操作。

我尝试使用通用的python if子句执行此方法,但是tensorflow不允许此操作。

我已经尝试过了,但是没有用,但是很快就可以解决问题。

def _parse_csv_row(*vals):

    (...)

    rg = vals[4]
    print(rg)

    if rg < -8:
        class_label = tf.constant(0)
    elif rg >= -8 and rg <0:
        class_label = tf.constant(1)
    elif rg == 0:
        class_label = tf.constant(2)
    elif rg > 0 and rg < 3:
        class_label = tf.constant(3)
    elif rg >= 3:
        class_label = tf.constant(4)

    return features, class_label

train_dataset = dataset.map(_parse_csv_row).shuffle(dataset_length, seed = 0).skip(10000).batch(batch_size, drop_remainder = True)

0 个答案:

没有答案