如何随机排列熊猫GroupBy对象?

时间:2020-04-30 17:52:49

标签: python-3.x pandas

我有一个pandas DataFrame,其中包含图像名称和几列包含要素,图像可以包含具有相同图像名称但列值不同的几行。

以下是DataFrame的外观:

        image  val1  val2  val3
0  image1.png    12    14    15
1  image1.png    10    15    10
2  image2.png    12    -3     7
3  image2.png    17    21     1
4  image6.png    12    12     2
5  image6.png   112    12    10

然后我需要按图像名称对图像进行分组,因此我使用groupby()

groups = df.groupby('image')

然后,我需要将数据分为训练集和验证集,所以我要执行以下操作:

groups = groups.apply(np.array)
training_set = groups[:separation_index]
valid_set = groups[separation_index:]

问题是我需要在拆分之前先对数据(组)进行洗牌。

我尝试了np.random.shuffle(groups),但是它不起作用,不会产生任何错误,但是不起作用,数据保持相同的顺序。

2 个答案:

答案 0 :(得分:2)

我认为您可以不进行分组,而是将唯一的组名(图像)作为列表,从该列表中随机选择火车图像,然后为数据帧编制索引。

df = pd.DataFrame.from_records(
    [
        {"image": "image1.png", "val1": 12, "val2": 14, "val3": 15},
        {"image": "image1.png", "val1": 10, "val2": 15, "val3": 10},
        {"image": "image2.png", "val1": 12, "val2": -3, "val3": 7},
        {"image": "image2.png", "val1": 17, "val2": 21, "val3": 1},
        {"image": "image6.png", "val1": 12, "val2": 12, "val3": 2},
        {"image": "image6.png", "val1": 112, "val2": 12, "val3": 10},
    ]
)

images = df["image"].unique()
train_images = np.random.choice(images, size=2, replace=False)

train_idxs = df["image"].isin(train_images)
train_df = df[train_idxs]
test_df = df[~train_idxs]

print(train_df)
print()
print(test_df)

        image  val1  val2  val3
0  image1.png    12    14    15
1  image1.png    10    15    10
4  image6.png    12    12     2
5  image6.png   112    12    10

        image  val1  val2  val3
2  image2.png    12    -3     7
3  image2.png    17    21     1

答案 1 :(得分:1)

您可以随机整理大熊猫中的数据:

groups = df.groupby('image')
grouped_df = groups.aggregate(np.sum)
# random order for all rows 
grouped_df = grouped_df.sample(frac=1)

结果:

In [103]: grouped_df
Out[103]:
            val1  val2  val3
image                       
image2.png    29    18     8
image6.png   124    24    12
image1.png    22    29    25

然后您可以将其编入索引:

grouped_df[:separation_index]
grouped_df[separation_index:]
相关问题