有效地在对象中找到最高值,并识别对象

时间:2017-02-17 15:04:34

标签: python python-3.x sorting object

假设我有5个对象:test_pandas_df = pd.read_csv( '/home/piotrek/ml/adults/adult.test', names=header, skipinitialspace=True) train_pandas_df = pd.read_csv( '/home/piotrek/ml/adults/adult.data', names=header, skipinitialspace=True) train_df = sqlContext.createDataFrame(train_pandas_df) test_df = sqlContext.createDataFrame(test_pandas_df) joined = train_df.union(test_df) assembler = VectorAssembler().setInputCols(features).setOutputCol("features") label_indexer = StringIndexer().setInputCol( "label").setOutputCol("label_index") label_indexer_fit = [label_indexer.fit(joined)] string_indexers = [StringIndexer().setInputCol( name).setOutputCol(name + "_index").fit(joined) for name in categorical_feats] one_hot_pipeline = Pipeline().setStages([OneHotEncoder().setInputCol( name + '_index').setOutputCol(name + '_one_hot') for name in categorical_feats]) mlp = MultilayerPerceptronClassifier().setLabelCol(label_indexer.getOutputCol()).setFeaturesCol( assembler.getOutputCol()).setLayers([len(features), 20, 10, 2]).setSeed(42L).setBlockSize(1000).setMaxIter(500) pipeline = Pipeline().setStages(label_indexer_fit + string_indexers + [one_hot_pipeline] + [assembler, mlp]) model = pipeline.fit(train_df) # compute accuracy on the test set result = model.transform(test_df) ## FAILS ON RESULT predictionAndLabels = result.select("prediction", "label_index") evaluator = MulticlassClassificationEvaluator(labelCol="label_index") print "-------------------------------" print("Test set accuracy = " + str(evaluator.evaluate(predictionAndLabels))) print "-------------------------------" obj1

obj5

如何找到(有效)obj1.x = 2.7 obj2.x = 0.9 obj3.x = 3.8 obj4.x = 1.2 obj5.x = 0.4 的最高值,并确定相应的x?这里的预期答案是:

obj

顺便说一句,在实际情况中,我有x = 3.8, it belongs to obj3 个对象。

2 个答案:

答案 0 :(得分:2)

最好将它们放在一个集合中(例如ConfigurationManager.AppSettings.Item("MetropolisBold").ToSt‌​ring() list,..)并使用tupleoperator.attrgetter来抓取最大值为{的对象{1}}:

max

现在返回的值x对应于属性from operator import attrgetter l = obj1, obj2, obj3, obj4, obj5 o = max(l, key=attrgetter('x')) 的最大值,即o

x

理想情况下,您不应该通过名称“识别”对象,这可能很容易更改,如果您有一个未按名称排序的列表,您将得到错误的结果。

相反,你应该从另一个属性(例如obj3)给它创建的类,并为它定义一个o == obj3 # True / "name"来打印出名称和值。< / p>

答案 1 :(得分:2)

如果您有一个列表并且需要最大索引,则可以将maxenumerate合并为纯Python中的numpy.argmax

l = [obj1, obj2, obj3, ...]
i = max(enumerate(l), key=lambda x: x[1].x)[0]

i将是包含最大值的索引,因此您可以将其打印为

print('x = {}, belongs to obj{}'.format(l[i].x, i + 1))

在您提供的示例中,i == 2打印x = 3.8, belongs to obj3