在Tensorflow数据集管道中计算样本权重

时间:2018-12-21 01:48:32

标签: python tensorflow

假设我有以下Tensorflow数据集,其中标签为[0,1] int值。数据集高度不平衡,我已经计算出要使用样本权重{1:5.1,0:0.8}作为映射。

权重不是原始TFRecords文件的一部分。如何修改代码以合并此样本权重映射,以便返回“ sample_weight”功能,以后可以在自定义Estimator中使用?

def train_input_fn(self):       
    feature_map = _get_features()

    def _parse_line(line):
        parsed_features = tf.parse_example(line, feature_map)
        labels = parsed_features.pop('target_open')

        return parsed_features, tf.reshape(labels, (-1,1))

    dataset = tf.data.TFRecordDataset('train.tfrecords')\
        .shuffle(buffer_size=10000)\
        .batch(self.batch_size)\
        .map(_parse_line, num_parallel_calls=6)\
        .repeat()\
        .prefetch(2)

    return dataset

0 个答案:

没有答案