假设我有以下Tensorflow数据集,其中标签为[0,1] int值。数据集高度不平衡,我已经计算出要使用样本权重{1:5.1,0:0.8}作为映射。
权重不是原始TFRecords文件的一部分。如何修改代码以合并此样本权重映射,以便返回“ sample_weight”功能,以后可以在自定义Estimator中使用?
def train_input_fn(self):
feature_map = _get_features()
def _parse_line(line):
parsed_features = tf.parse_example(line, feature_map)
labels = parsed_features.pop('target_open')
return parsed_features, tf.reshape(labels, (-1,1))
dataset = tf.data.TFRecordDataset('train.tfrecords')\
.shuffle(buffer_size=10000)\
.batch(self.batch_size)\
.map(_parse_line, num_parallel_calls=6)\
.repeat()\
.prefetch(2)
return dataset