我有一张正在处理的FITS表。我想将表随机分为训练和测试数据,以创建两个新的FITS表。
我首先想到使用scikit-learn
函数test_train_split
,但是随后我不得不将数据来回转换为numpy.array
。
到目前为止,我已经从FITS文件中读取了astropy.table.Table data
并尝试了以下操作
training_fraction = 0.5
n = len(data)
indexes = random.sample(range(n), k=int(n*training_fraction))
testing_sample = data[indexes]
training_sample = ?
但是,我不知道如何获取索引不在indexes
中的所有行。也许有更好的方法可以做到这一点?如何获得表格的随机分区?
表中的样本恰好都有唯一的ID,该ID是1到len(data)之间的整数。所以我想,我可以做
indexes = random.sample(range(1, n+1), k=int(n*training_fraction))
testing_sample = data[data['ID'] in indexes]
training_sample = data[data['ID'] not in indexes]
但第一行引发ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
答案 0 :(得分:0)
我是如何做到的
training_indexes = sorted(random.sample(range(n), k=int(n*training_fraction)))
testing_indexes = [i for i in range(n) if i not in training_indexes]
testing_sample = data[testing_indexes]
training_sample = data[training_indexes]
但是我不知道这是最有效的方法还是最pythonic的方法。
答案 1 :(得分:0)
您提到使用来自scikit-learn的现有train_test_split
路由。如果这是您正在使用scikit-learn的 only 东西,那就太过分了。但是,如果您已经将它用于任务的其他部分,则也可能会使用。首先,Numpy数组已经支持Astropy Tables,因此您无需“来回转换数据”。
由于表的'ID'
列对表中的行进行了索引,因此将其正式设置为表的index很有用,以便可以使用ID值对表中的行进行索引表格(独立于其实际位置索引)。例如:
>>> from astropy.table import Table
>>> import numpy as np
>>> t = Table({
... 'ID': [1, 3, 5, 6, 7, 9],
... 'a': np.random.random(6),
... 'b': np.random.random(6)
... })
>>> t
<Table length=6>
ID a b
int64 float64 float64
----- ------------------- -------------------
1 0.7285295918917892 0.6180944983953155
3 0.9273855839237182 0.28085439237508925
5 0.8677312765220222 0.5996267567496841
6 0.06182255608446752 0.6604620336092745
7 0.21450048405835265 0.5351066893214822
9 0.928930682667869 0.8178640424254757
然后将'ID'
设置为表的索引:
>>> t.add_index('ID')
根据需要使用train_test_split
对ID进行分区:
>>> train_ids, test_ids = train_test_split(t['ID'], test_size=0.2)
>>> train_ids
<Column name='ID' dtype='int64' length=4>
7
9
5
1
>>> test_ids
<Column name='ID' dtype='int64' length=2>
6
3
>>> train_set = t.loc[train_ids]
>>> test_set = t.loc[test_ids]
>>> train_set
<Table length=4>
ID a b
int64 float64 float64
----- ------------------- ------------------
7 0.21450048405835265 0.5351066893214822
9 0.928930682667869 0.8178640424254757
5 0.8677312765220222 0.5996267567496841
1 0.7285295918917892 0.6180944983953155
>>> test_set
<Table length=2>
ID a b
int64 float64 float64
----- ------------------- -------------------
6 0.06182255608446752 0.6604620336092745
3 0.9273855839237182 0.28085439237508925
(注意:
>>> isinstance(t['ID'], np.ndarray)
True
>>> type(t['ID']).__mro__
(astropy.table.column.Column,
astropy.table.column.BaseColumn,
astropy.table._column_mixins._ColumnGetitemShim,
numpy.ndarray,
object)
)
对于它的价值,因为它可能会帮助您将来更轻松地找到类似问题的答案,因此有助于您更抽象地考虑要做的事情(看来您已经了这样做,但是对您的问题的表述则相反):表中的列只是Numpy数组-一旦以这种形式出现,从FITS文件中读取它们就无关紧要。此时,您正在做的事情与Astropy也没有直接关系。问题就变成了如何随机划分一个Numpy数组。
您可以找到此问题的一般答案,例如in this question。但是,如果有的话,可以使用现有的专用工具,例如train_test_split
。