从数据中生成随机样本

时间:2020-05-16 10:43:46

标签: python-3.x error-handling sampling

我在数据集中有506点。我必须从这些数据中生成随机样本,例如我必须选择303个点而不进行替换,而剩下的203个点我需要从这303个点中选择。

我写了以下代码。

def generating_samples(input_data, target_data):

    selected_rows = np.random.choice(len(input_data), 303)
    replacing_rows = np.random.choice(selected_rows,203)
    selected_columns = np.random.choice(3,13,1)
    sample_data = input_data[selected_rows[:,None],selected_columns]
    target_of_sample_data = target_data[selected_rows]

    #replicating data
    replicated_sample_data = sample_data[replacing_rows]
    target_of_replicated_sample_data = target_data[replacing_rows]

    #concatenating data
    sampled_input_data = np.vstack(sample_data, replicated_sample_data)
    target_of_sample_data = target_of_sample_data.reshape(-1,1)
    target_of_replicated_sample_data = target_of_replicated_sample_data.reshape(-1,1)
    sampled_target_data = np.vstack(target_of_sample_data,target_of_replicated_sample_data)

    return sampled_input_data , sampled_target_data, selected_rows,selected_columns



def grader_samples(a,b,c,d):
        length = (len(a)==506  and len(b)==506)
        sampled = (len(a)-len(set([str(i) for i in a]))==203)
        rows_length = (len(c)==303)
        column_length= (len(d)>=3)
        assert(length and sampled and rows_length and column_length)
        return True

a,b,c,d = generating_samples(x, y)
grader_samples(a,b,c,d)

但是在此方面出现以下错误。

IndexError                                Traceback (most recent call last)
<ipython-input-14-ca772632e834> in <module>
      7     return True
      8 
----> 9 a,b,c,d = generating_samples(x, y)
     10 grader_samples(a,b,c,d)

<ipython-input-13-bcf904f160e5> in generating_samples(input_data, target_data)
     13 
     14     #replicating data
---> 15     replicated_sample_data = sample_data[replacing_rows]
     16     target_of_replicated_sample_data = target_data[replacing_rows]
     17 

IndexError: index 391 is out of bounds for axis 0 with size 303

2 个答案:

答案 0 :(得分:0)

使用:replicated_sample_data = input_data[replacing_rows],因为复制的样本数据来自原始数据集。 而且样本数据已经从原始数据集中进行了采样,因此它是我们原始数据集的一个子集,并导致索引错误

答案 1 :(得分:0)

由于索引和非唯一样本而发生错误。 使用:

<块引用>

selected_rows = np.random.choice(len(input_data), 303, replace=False)

因为 replace=False 会得到 303 唯一的样本索引。索引值可用于从 input_data 中提取行。 对于复制样本,我们可以选择

<块引用>

replacing_rows = np.random.choice(len(selected_rows),203,replace=False)

在replace_rows中,我们将得到一个唯一的样本索引。 现在我们可以从样本数据集中选择替换样本。

<块引用>

replicated_sample_data = sample_data[replacing_rows]