从 M 个桶中按权重选取 S 个样本的算法

时间:2015-09-18 21:07:02

标签: algorithm sampling

我想从 M 个桶中总共抽取 S 个样本。每个桶有一个权重 W，表示最终样本中来自该桶的样本应占的比例。例如，桶 A、B、C 的权重分别为 0.5、0.2 和 0.3，且每个桶中的样本都足够多；那么当最终样本大小 S = 10 时，我希望样本中有 5 个来自桶 A，2 个来自桶 B，3 个来自桶 C。如果某些桶中的样本数少于按权重和总样本大小计算出的所需数量，问题就更复杂了：这时需要调整其余桶的权重，使最终样本尽可能接近期望的加权比例。有谁知道解决这个问题的算法吗？

2 个答案:

答案 0（得分：0）

我用 Java 写了一个解决方案。由于舍入误差，它返回的样本可能比请求的多一两个，但这对我的应用来说没有问题。如果你发现任何改进该算法的方法，欢迎发布你的方案。

SampleNode.java

/**
 * Base type for a node in a weighted sampling tree.
 *
 * <p>Every node carries a {@link #weight} that is interpreted relative to its siblings.
 * Concrete subclasses assign the weight in their constructors and implement the operations
 * below for taking samples out of the node.
 */
public abstract class SampleNode {
    /** Weight of this node relative to its siblings; assigned by subclass constructors. */
    protected double weight;

    /**
     * Takes a single sample from this node.
     *
     * @return {@code true} if a sample was taken, {@code false} if none was left
     */
    public abstract boolean takeOneSample();

    /** Returns how many samples can still be taken from this node. */
    protected abstract int getNumSamplesAvailable();

    /** Returns {@code true} while at least one sample can still be taken. */
    protected abstract boolean hasSamples();

    /**
     * Takes {@code target} samples from this node; how a shortfall is handled is up to the
     * implementation.
     */
    protected abstract void sample(int target);

    /**
     * Takes every remaining sample.
     *
     * @return the number of samples taken
     */
    protected abstract int takeAllSamples();
}

LeafSampleNode.java

/**
 * Leaf of the sampling tree: a bucket holding a fixed number of samples.
 *
 * <p>Only counts are tracked. Sampling moves samples from "available" to "selected".
 */
public class LeafSampleNode extends SampleNode {
    private int numselected;
    private int numsamplesavailable;

    /**
     * Creates a bucket.
     *
     * @param weight weight of this bucket relative to its siblings
     * @param numsamplesavailable number of samples in the bucket
     * @throws IllegalArgumentException if {@code numsamplesavailable} is negative
     */
    public LeafSampleNode(double weight, int numsamplesavailable) {
        if (numsamplesavailable < 0) {
            throw new IllegalArgumentException(
                    "numsamplesavailable must be >= 0, got " + numsamplesavailable);
        }
        this.weight = weight;
        this.numsamplesavailable = numsamplesavailable;
        this.numselected = 0;
    }

    /**
     * Selects {@code target} samples, or every remaining sample if fewer are available.
     *
     * @param target number of samples to select
     * @throws IllegalArgumentException if {@code target} is negative; it would otherwise
     *     decrease the selected count and increase the available count
     */
    @Override
    protected void sample(int target) {
        if (target < 0) {
            throw new IllegalArgumentException("target must be >= 0, got " + target);
        }
        if (target >= numsamplesavailable) {
            takeAllSamples();
        } else {
            numselected += target;
            numsamplesavailable -= target;
        }
    }

    @Override
    protected int getNumSamplesAvailable() {
        return numsamplesavailable;
    }

    @Override
    protected boolean hasSamples() {
        return numsamplesavailable > 0;
    }

    /** Returns how many samples have been selected from this bucket so far. */
    protected int getNumselected() {
        return numselected;
    }

    /**
     * Selects every remaining sample.
     *
     * @return the number of samples taken by this call
     */
    @Override
    protected int takeAllSamples() {
        int samplestaken = numsamplesavailable;
        numselected += samplestaken;
        numsamplesavailable = 0;
        return samplestaken;
    }

    /**
     * Selects a single sample if one is available.
     *
     * @return {@code true} if a sample was taken
     */
    @Override
    public boolean takeOneSample() {
        if (!hasSamples()) {
            return false;
        }
        numsamplesavailable--;
        numselected++;
        return true;
    }
}

RootSampleNode.java:

import java.util.ArrayList;
import java.util.List;

/**
 * Inner node of the sampling tree. Splits a sample budget among its children in proportion
 * to their weights, and hands the shortfall of exhausted children to the remaining ones.
 *
 * <p>In each round a child's quota is {@code target * child.weight / W}, rounded half-up.
 * Here {@code W} is the total weight of the children that still have samples. Because of
 * this rounding, the number of samples actually taken can slightly exceed the target.
 */
public class RootSampleNode extends SampleNode {
    private final List<SampleNode> children;

    /**
     * Creates an inner node with no children.
     *
     * @param weight weight of this node relative to its siblings when used as a branch
     */
    public RootSampleNode(double weight) {
        this.children = new ArrayList<SampleNode>();
        this.weight = weight;
    }

    /**
     * Selects {@code target} samples from the subtree rooted at this node.
     *
     * <p>If fewer than {@code target} samples are available, every child is drained. A
     * non-positive {@code target} selects nothing. Previously a target of 0 still took one
     * sample through the no-progress fallback.
     *
     * @param target number of samples to select
     */
    public void selectSample(int target) {
        if (target <= 0) {
            return;
        }
        if (getNumSamplesAvailable() < target) {
            // Not enough samples to meet the target: simply take everything.
            takeAllSamples();
        } else {
            // Enough samples: distribute to meet the weighted quotas as closely as possible.
            sample(target);
        }
    }

    /**
     * Takes at least {@code target} samples, possibly a few more because of rounding. It
     * repeats weighted rounds until the budget is met or the subtree has nothing left.
     * Iterative rather than recursive, so a round that takes nothing ends the loop instead
     * of overflowing the stack.
     */
    @Override
    protected void sample(int target) {
        int remaining = target;
        while (remaining > 0) {
            int taken = sampleOneRound(remaining);
            if (taken == 0) {
                break; // no child had anything left
            }
            remaining -= taken;
        }
    }

    /**
     * Runs one proportional allocation round over the children that still have samples.
     *
     * @param target budget for this round
     * @return the number of samples actually removed from the subtree in this round
     */
    private int sampleOneRound(int target) {
        double totalweight = getTotalWeight(children);
        int samplestaken = 0;
        for (SampleNode child : children) {
            if (!child.hasSamples()) {
                continue;
            }
            int desired = (int) (target * (child.weight / totalweight) + 0.5);
            if (desired <= 0) {
                // A zero quota must take nothing. A branch asked for 0 would still take one
                // sample via its own fallback, and that sample would go uncounted here.
                continue;
            }
            int before = child.getNumSamplesAvailable();
            if (desired >= before) {
                samplestaken += child.takeAllSamples();
            } else {
                child.sample(desired);
                // Count what was really removed: a branch child may overshoot `desired`
                // through its own rounding, and counting only `desired` over-samples here.
                samplestaken += before - child.getNumSamplesAvailable();
            }
        }
        if (samplestaken == 0) {
            // Every rounded quota was zero: take one sample from the first child that has
            // one, so the caller's loop always makes progress.
            for (SampleNode child : children) {
                if (child.takeOneSample()) {
                    samplestaken++;
                    break;
                }
            }
        }
        return samplestaken;
    }

    /**
     * Takes a single sample from the first child that still has one.
     *
     * @return {@code true} if a sample was taken
     */
    @Override
    public boolean takeOneSample() {
        for (SampleNode child : children) {
            if (child.takeOneSample()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Sums the weights of the given nodes that still have samples. Exhausted nodes are
     * ignored so that their share is redistributed among the others.
     */
    protected double getTotalWeight(List<SampleNode> children) {
        double totalweight = 0;
        for (SampleNode child : children) {
            if (child.hasSamples()) {
                totalweight += child.weight;
            }
        }
        return totalweight;
    }

    @Override
    protected boolean hasSamples() {
        for (SampleNode child : children) {
            if (child.hasSamples()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Drains every child.
     *
     * @return the total number of samples taken
     */
    @Override
    protected int takeAllSamples() {
        int samplestaken = 0;
        for (SampleNode child : children) {
            samplestaken += child.takeAllSamples();
        }
        return samplestaken;
    }

    @Override
    protected int getNumSamplesAvailable() {
        int numsamplesavailable = 0;
        for (SampleNode child : children) {
            numsamplesavailable += child.getNumSamplesAvailable();
        }
        return numsamplesavailable;
    }

    /** Appends a child (leaf or branch); child order decides who gets fallback samples. */
    public void addChild(SampleNode sn) {
        this.children.add(sn);
    }
}

一些单元测试:

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import org.junit.Test;

/**
 * Unit tests for weighted sampling over trees of {@link LeafSampleNode} buckets.
 *
 * <p>Uses {@code assertEquals(expected, actual)} so a failure reports both counts, which
 * {@code assertTrue(a == b)} does not.
 */
public class SampleNodeTest {

    /** Creates a leaf bucket, appends it to {@code parent} and returns it. */
    private static LeafSampleNode addLeaf(RootSampleNode parent, double weight, int numSamples) {
        LeafSampleNode leaf = new LeafSampleNode(weight, numSamples);
        parent.addChild(leaf);
        return leaf;
    }

    /** Creates an inner node, appends it to {@code parent} and returns it. */
    private static RootSampleNode addBranch(RootSampleNode parent, double weight) {
        RootSampleNode branch = new RootSampleNode(weight);
        parent.addChild(branch);
        return branch;
    }

    @Test
    public void test1() {
        // Quotas 4.5 / 1.8 / 2.7 round half-up to 5 / 2 / 3, so a target of 9 yields 10.
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = addLeaf(root, 0.5, 5);
        LeafSampleNode bucketB = addLeaf(root, 0.2, 2);
        LeafSampleNode bucketC = addLeaf(root, 0.3, 3);
        root.selectSample(9);
        assertEquals(5, bucketA.getNumselected());
        assertEquals(2, bucketB.getNumselected());
        assertEquals(3, bucketC.getNumselected());
    }

    @Test
    public void test2() {
        // The target (13) exceeds the 10 available samples, so everything is taken.
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = addLeaf(root, 0.5, 5);
        LeafSampleNode bucketB = addLeaf(root, 0.2, 2);
        LeafSampleNode bucketC = addLeaf(root, 0.3, 3);
        root.selectSample(13);
        assertEquals(5, bucketA.getNumselected());
        assertEquals(2, bucketB.getNumselected());
        assertEquals(3, bucketC.getNumselected());
    }

    @Test
    public void test3() {
        // A branch of weight 0.5 shares the budget with leaves A, B and C.
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = addLeaf(root, 0.5, 5);
        LeafSampleNode bucketB = addLeaf(root, 0.2, 2);
        LeafSampleNode bucketC = addLeaf(root, 0.3, 3);
        RootSampleNode branch1 = addBranch(root, 0.5);
        LeafSampleNode bucketD = addLeaf(branch1, 0.5, 10);
        LeafSampleNode bucketE = addLeaf(branch1, 0.2, 12);
        root.selectSample(13);
        assertEquals(4, bucketA.getNumselected());
        assertEquals(2, bucketB.getNumselected());
        assertEquals(3, bucketC.getNumselected());
        assertEquals(3, bucketD.getNumselected());
        assertEquals(1, bucketE.getNumselected());
    }

    @Test
    public void test4() {
        // A branch of weight 1 receives half of the budget.
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = addLeaf(root, 0.5, 5);
        LeafSampleNode bucketB = addLeaf(root, 0.2, 2);
        LeafSampleNode bucketC = addLeaf(root, 0.3, 3);
        RootSampleNode branch1 = addBranch(root, 1);
        LeafSampleNode bucketD = addLeaf(branch1, 0.5, 10);
        LeafSampleNode bucketE = addLeaf(branch1, 0.2, 12);
        root.selectSample(13);
        assertEquals(3, bucketA.getNumselected());
        assertEquals(1, bucketB.getNumselected());
        assertEquals(2, bucketC.getNumselected());
        assertEquals(5, bucketD.getNumselected());
        assertEquals(2, bucketE.getNumselected());
    }

    @Test
    public void test5() {
        // Small target: B's quota (0.4) rounds down to zero.
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = addLeaf(root, 0.5, 5);
        LeafSampleNode bucketB = addLeaf(root, 0.2, 2);
        LeafSampleNode bucketC = addLeaf(root, 0.3, 3);
        RootSampleNode branch1 = addBranch(root, 1);
        LeafSampleNode bucketD = addLeaf(branch1, 0.5, 10);
        LeafSampleNode bucketE = addLeaf(branch1, 0.2, 12);
        root.selectSample(4);
        assertEquals(1, bucketA.getNumselected());
        assertEquals(0, bucketB.getNumselected());
        assertEquals(1, bucketC.getNumselected());
        assertEquals(1, bucketD.getNumselected());
        assertEquals(1, bucketE.getNumselected());
    }

    @Test
    public void test6() {
        // Tiny target: only the heaviest leaf in each level gets a sample.
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = addLeaf(root, 0.5, 5);
        LeafSampleNode bucketB = addLeaf(root, 0.2, 2);
        LeafSampleNode bucketC = addLeaf(root, 0.3, 3);
        RootSampleNode branch1 = addBranch(root, 1);
        LeafSampleNode bucketD = addLeaf(branch1, 0.5, 10);
        LeafSampleNode bucketE = addLeaf(branch1, 0.2, 12);
        root.selectSample(2);
        assertEquals(1, bucketA.getNumselected());
        assertEquals(0, bucketB.getNumselected());
        assertEquals(0, bucketC.getNumselected());
        assertEquals(1, bucketD.getNumselected());
        assertEquals(0, bucketE.getNumselected());
    }

    @Test
    public void test7() {
        // C's quota (30) fits within its 33 samples, so C is not drained.
        RootSampleNode root = new RootSampleNode(1);
        LeafSampleNode bucketA = addLeaf(root, 0.5, 50);
        LeafSampleNode bucketB = addLeaf(root, 0.2, 20);
        LeafSampleNode bucketC = addLeaf(root, 0.3, 33);
        RootSampleNode branch1 = addBranch(root, 1);
        LeafSampleNode bucketD = addLeaf(branch1, 0.5, 100);
        LeafSampleNode bucketE = addLeaf(branch1, 0.2, 120);
        root.selectSample(200);
        assertEquals(50, bucketA.getNumselected());
        assertEquals(20, bucketB.getNumselected());
        assertEquals(30, bucketC.getNumselected());
        assertEquals(71, bucketD.getNumselected());
        assertEquals(29, bucketE.getNumselected());
    }

    @Test
    public void test8() {
        // branch1's quota (167) exceeds its 103 samples; the shortfall goes to branch2.
        RootSampleNode root = new RootSampleNode(1);
        RootSampleNode branch1 = addBranch(root, 5);
        LeafSampleNode bucketA = addLeaf(branch1, 0.5, 50);
        LeafSampleNode bucketB = addLeaf(branch1, 0.2, 20);
        LeafSampleNode bucketC = addLeaf(branch1, 0.3, 33);
        RootSampleNode branch2 = addBranch(root, 1);
        LeafSampleNode bucketD = addLeaf(branch2, 0.5, 100);
        LeafSampleNode bucketE = addLeaf(branch2, 0.2, 120);
        root.selectSample(200);
        assertEquals(50, bucketA.getNumselected());
        assertEquals(20, bucketB.getNumselected());
        assertEquals(33, bucketC.getNumselected());
        assertEquals(70, bucketD.getNumselected());
        assertEquals(27, bucketE.getNumselected());
    }
}

希望有人觉得这很有用。

答案 1（得分：0）

我的建议是写一个循环：每次从当前占比与期望权重相差最远的非空桶中取一个样本。下面是一些伪代码。显然，你可能需要把它推广到更多的桶，但这应该足以说明思路。

set buckets[] = { /* original items */ };
double weights[] = { 0.5, 0.2, 0.3 }; // the desired weights (should sum to 1)
int counts[] = { 0, 0, 0 };           // number of items sampled so far from each bucket

for (i = 0; i < n; i++) {
  // Choose the non-empty bucket whose share is furthest BELOW its desired weight.
  // Use a signed deficit, not abs(): an over-represented bucket must not look "far away".
  // Empty buckets are skipped outright, not given error 0, because a non-empty
  // bucket can legitimately have a deficit of 0 or less and would tie with them.
  k = -1;
  best = -infinity;
  for (j = 0; j < 3; j++) {
    if (empty(buckets[j]))
      continue;
    deficit = weights[j] - (double) counts[j] / n;  // floating-point, not integer, division
    if (deficit > best) {
      best = deficit;
      k = j;
    }
  }
  if (k == -1)
    break;              // every bucket is exhausted: fewer than n items in total
  sample(buckets[k]);   // take one item out of that bucket
  counts[k]++;          // increment its count
}

如果你需要把它改写成可运行的 Java 代码，我或许可以抽空帮忙 :)。它总是生成 n 个样本（前提是总共至少有 n 项，否则会取走所有项），并且样本分布会尽可能接近期望的权重。