如何使用张量流或火花分割高度不平衡的数据?

时间:2018-12-14 22:13:56

标签: apache-spark tensorflow pyspark apache-spark-mllib tensorflow-transform

数据-

我拥有的训练和测试数据非常大〜150gb ,并且高度不平衡 99%neg 标签/ 1%pos 标签,并且我也无法将其作为非常重要的信息进行下采样,因此当前使用加权估算器。

问题-

如果我们使用spark()方法使用sample()函数拆分并保存到多个文件,则很有可能仅在多个文件中的一个文件(例如100个中的一个)中使用负样本,这会导致问题摄取数据时,只有正样本被馈送到估计器,导致零损失,模型无法学习。

此外,我在进行批处理时确实使用了shuffle,但是输入函数将多个文件作为输入,因此通过对每个文件中的数据进行混洗来创建批处理,这导致模型在非常长的时间内仅被馈送给小个案直到对带有否定样本的文件进行随机播放。

是否有更好的方法来确保在使用 pyspark 保存数据时,spark保存的每个文件都具有两者 类/标签中的示例(最好与总体数据pos / neg之比相同)?

在这些情况下,我尝试使用一个大文件进行馈送,并且随机播放可以正常工作,但是当我们馈送了许多文件时,由于只有一类的样本被馈送到模型中,因此它会造成零丢失的问题。

在tensorflow代码中使用以下输入功能-

def csv_input_fn(files_name_pattern, mode=tf.estimator.ModeKeys.EVAL,
             skip_header_lines=0,
             num_epochs=None,
             batch_size=1000):

shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

num_threads = multiprocessing.cpu_count() if MULTI_THREADING else 1

print("")
print("* data input_fn:")
print("================")
print("Input file(s): {}".format(files_name_pattern))
print("Batch size: {}".format(batch_size))
print("Epoch Count: {}".format(num_epochs))
print("Mode: {}".format(mode))
print("Thread Count: {}".format(num_threads))
print("Shuffle: {}".format(shuffle))
print("================")
print("")

file_names = tf.matching_files(files_name_pattern)
dataset = data.TextLineDataset(filenames=file_names)

dataset = dataset.skip(skip_header_lines)

if shuffle:
    dataset = dataset.shuffle(buffer_size=2 * batch_size + 1)

dataset = dataset.batch(batch_size)
dataset = dataset.map(lambda csv_row: parse_csv_row(csv_row),
                      num_parallel_calls=num_threads)

dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()

features, target = iterator.get_next()
return features, target

任何建议将不胜感激!谢谢

1 个答案:

答案 0 :(得分:0)

找到了我自己问题的答案,因此您可以将buffer_size更改为完整数据集中的元素/行数,这样我们就可以确保使用shuffle随机分配的索引是统一的,如下所示:现在,改组是使用整个数据集完成的。

代码已更改-

import { Directive, ElementRef, AfterViewInit, Input, HostListener } from '@angular/core';

@Directive({
  selector: '[pyb-button-group]'
})
export class ButtonGroupDirective implements AfterViewInit {
  @Input() className: string;

  @HostListener('window:resize', ['$event'])
  onResize() {
    let resize = this.resize;
    let element = this.element;
    let className = this.className;

    setTimeout(function () {
      resize(element, className);
    }, 500);
  }

  constructor(private element: ElementRef) {
  }

  ngAfterViewInit() {
    this.resize(this.element, this.className);
  }

  resize(nativeElement, className) {
    let elements = nativeElement.nativeElement.getElementsByClassName(className || 'btn-choice');
    let headerHeight = 0;

    for (var i = 0; i < elements.length; i++) {
      let element = elements[i];
      let header = element.getElementsByClassName('header');

      if (!header.length) return;

      header = header[0];
      header.style.height = 'auto'; // Reset when resizing the window

      let height = header.offsetHeight;
      if (height > headerHeight) headerHeight = height;
    }

    for (var i = 0; i < elements.length; i++) {
      let element = elements[i];
      let header = element.getElementsByClassName('header');

      if (!header.length) return;

      header = header[0];
      header.style.height = headerHeight + 'px';
    }
  }
}