How do I display labels and predictions - PySpark

Time: 2018-12-03 23:22:26

Tags: python pyspark cross-validation databricks azure-databricks

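I created an algorithm to classify store products, but I can't get it to return the predicted label. I tried several commands, and all of them fail with errors. How can I return the label and the prediction probability as a percentage (I am using cross-validation)?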

Example:

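I want to pass in the product "7 Shakra Bracelet 7 chakra bracelet, in blue or black." and get back its label and the confidence of the prediction (for this product the returned label should be "Bracelet").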

Training data

data = spark.createDataFrame([
("Bracelet"," 7 Shakra Bracelet 7 chakra bracelet, in blue or black."),
("Bracelet"," Anchor Bracelet Mens Black leather bracelet with gold or silver anchor for men."),
("Bracelet"," Bangle Bracelet Gold bangle bracelet with studded jewels."),
("Bracelet"," Boho Bangle Bracelet Gold boho bangle bracelet with multicolor tassels."),
("Earrings"," Boho Earrings Turquoise globe earrings on 14k gold hooks."),
("Necklace"," Choker with Bead Black choker necklace with 14k gold bead."),
("Necklace"," Choker with Triangle Black choker with silver triangle pendant."),
("Necklace"," Dainty Gold Necklace Dainty gold necklace with two pendants."),
("Necklace"," Dreamcatcher Pendant Necklace Turquoise beaded dream catcher necklace. Silver feathers adorn this beautiful dream catcher, which move and twinkle as you walk."),
("Earrings"," Galaxy Earrings One set of galaxy earrings, with sterling silver clasps."),
("Necklace"," Gold Bird Necklace 14k Gold delicate necklace, with bird between two chains."),
("Earrings"," Gold Elephant Earrings Small 14k gold elephant earrings, with opal ear detail."),
("Earrings"," Guardian Angel Earrings Sterling silver guardian angel earrings with diamond gemstones."),
("Bracelet"," Moon Charm Bracelet Moon 14k gold chain friendship bracelet."),
("Necklace"," Origami Crane Necklace Sterling silver origami crane necklace."),
("Necklace"," Pretty Gold Necklace 14k gold and turquoise necklace. Stunning beaded turquoise on gold and pendant filled double chain design."),
("Necklace"," Silver Threader Necklace Sterling silver chain thread through circle necklace."),
("Necklace"," Stylish Summer Necklace Double chained gold boho necklace with turquoise pendant.")

], ["id", "description"])

Tokenizing, text processing and the count vectorizer

from pyspark.ml.feature import RegexTokenizer, StopWordsRemover, CountVectorizer
from pyspark.ml.classification import LogisticRegression
# regular expression tokenizer
regexTokenizer = RegexTokenizer(inputCol="description", outputCol="words", pattern="\\W")
# stop words
add_stopwords = ["http","https","amp","rt","t","c","the"] 
stopwordsRemover = StopWordsRemover(inputCol="words", outputCol="filtered").setStopWords(add_stopwords)
# bag of words count
countVectors = CountVectorizer(inputCol="filtered", outputCol="features", vocabSize=10000, minDF=5)

Creating the labels and the dataset

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer
label_stringIdx = StringIndexer(inputCol = "id", outputCol = "label")
pipeline = Pipeline(stages=[regexTokenizer, stopwordsRemover, countVectors, label_stringIdx])
# Fit the pipeline to training documents.
pipelineFit = pipeline.fit(data)
dataset = pipelineFit.transform(data)
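By default StringIndexer orders labels by descending frequency (`stringOrderType="frequencyDesc"`), so the most frequent label gets index 0.0. As a rough pure-Python sketch (an approximation, not Spark's own code), counting the labels of the training data above (9 Necklace, 5 Bracelet, 4 Earrings) reproduces that mapping, which is why Bracelet rows end up with label 1.0. The helper name `frequency_desc_index` is hypothetical.

```python
from collections import Counter

# Labels of the 18 training rows above, in their original order.
train_labels = (["Bracelet"] * 4 + ["Earrings"] + ["Necklace"] * 4 + ["Earrings", "Necklace"]
                + ["Earrings"] * 2 + ["Bracelet"] + ["Necklace"] * 4)


def frequency_desc_index(labels):
    """Mimic StringIndexer's default frequencyDesc ordering: most frequent label -> 0.0.
    Ties are broken alphabetically here (an assumption; this data has no ties)."""
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}


label_to_index = frequency_desc_index(train_labels)
# The inverse lookup is what IndexToString does with a fitted model's labels.
index_to_label = {index: label for label, index in label_to_index.items()}
print(label_to_index)  # {'Necklace': 0.0, 'Bracelet': 1.0, 'Earrings': 2.0}
```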

At this point, my dataset has the columns id, description, words, filtered, features and label.

Cross-validation

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")

lr = LogisticRegression(maxIter=20, regParam=0.3, elasticNetParam=0)
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
# Create ParamGrid for Cross Validation
paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.1, 0.3, 0.5]) # regularization parameter
             .addGrid(lr.elasticNetParam, [0.0, 0.1, 0.2]) # Elastic Net Parameter (Ridge = 0)
#            .addGrid(lr.maxIter, [10, 20, 50]) # Number of iterations
#            .addGrid(idf.numFeatures, [10, 100, 1000]) # Number of features
             .build())
# Create 5-fold CrossValidator
cv = CrossValidator(estimator=lr, \
                    estimatorParamMaps=paramGrid, \
                    evaluator=evaluator, \
                    numFolds=5)
cvModel = cv.fit(dataset)
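The grid above has 3 × 3 = 9 parameter combinations. CrossValidator trains one LogisticRegression per combination and fold, then refits the best combination on the whole dataset; the regParam=0.3 and elasticNetParam=0 passed to the constructor are overridden by each grid entry. A small pure-Python sketch of that bookkeeping (the dictionary keys only mirror the Param names; nothing here calls Spark):

```python
from itertools import product

# The values passed to addGrid above.
reg_params = [0.1, 0.3, 0.5]
elastic_net_params = [0.0, 0.1, 0.2]
num_folds = 5

# One parameter map per combination, like ParamGridBuilder().build().
param_maps = [{"regParam": r, "elasticNetParam": e}
              for r, e in product(reg_params, elastic_net_params)]

# One fit per (parameter map, fold) pair, plus a final refit of the best map on all data.
fits_during_search = len(param_maps) * num_folds
total_fits = fits_during_search + 1
print(len(param_maps), fits_during_search, total_fits)  # 9 45 46
```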

Creating the data to classify

testData = spark.createDataFrame([
(10," 7 Shakra Bracelet 7 chakra bracelet, in blue or black."),
(11," Anchor Bracelet Mens Black leather bracelet with gold or silver anchor for men."),
(12," Bangle Bracelet Gold bangle bracelet with studded jewels."), 
(13," 7 Shakra Bracelet 7 chakra bracelet, in blue or black."),
(14," Anchor Bracelet Mens Black leather bracelet with gold or silver anchor for men."),
(15," Bangle Bracelet Gold bangle bracelet with studded jewels."), 
(100," 7 Shakra Bracelet 7 chakra bracelet, in blue or black."),
(16," Anchor Bracelet Mens Black leather bracelet with gold or silver anchor for men."),
(17," Bangle Bracelet Gold bangle bracelet with studded jewels."),
(101," 7 Shakra Bracelet 7 chakra bracelet, in blue or black."),
(18," Anchor Bracelet Mens Black leather bracelet with gold or silver anchor for men."),
(19," Bangle Bracelet Gold bangle bracelet with studded jewels."),
(104," 7 Shakra Bracelet 7 chakra bracelet, in blue or black."),
(20," Anchor Bracelet Mens Black leather bracelet with gold or silver anchor for men."),
(21," Bangle Bracelet Gold bangle bracelet with studded jewels.")
], ["rowid", "description"])

I built the new dataset to be classified with the same pipeline, only without the label-indexing stage (label_stringIdx):

pipeline = Pipeline(stages=[regexTokenizer, stopwordsRemover, countVectors])
# Fit the pipeline to the documents to classify.
pipelineFit = pipeline.fit(testData)
datasetTest = pipelineFit.transform(testData)

Then I use datasetTest to compute the new predictions.

Up to this point everything runs without errors.

Now the problem: I can't see any information in the variable holding the predictions.

I tried several commands to display it, but every one of them failed with an error.
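The label and percentage the question asks for come from two columns that LogisticRegression adds: `prediction`, the index of the most probable class, and `probability`, one probability per class index. A minimal pure-Python sketch of turning them back into a label and a percentage, assuming a hypothetical probability vector and the index order StringIndexer produces for the training labels (Necklace, Bracelet, Earrings). In PySpark the index-to-label step is what `pyspark.ml.feature.IndexToString` does when its `labels` parameter is set to the fitted StringIndexerModel's `labels`.

```python
# Order StringIndexer assigns to the training labels (indices 0.0, 1.0, 2.0).
labels = ["Necklace", "Bracelet", "Earrings"]


def label_and_percent(probability, labels):
    """Return (label, percent) for the most probable class in a probability vector."""
    best = max(range(len(probability)), key=lambda i: probability[i])
    return labels[best], round(100.0 * probability[best], 1)


# Hypothetical probability vector for one test row (illustrative values only).
example_probability = [0.20, 0.70, 0.10]
print(label_and_percent(example_probability, labels))  # ('Bracelet', 70.0)
```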

1 Answer:

Answer 0 (score: 2)

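If you look further down the stack trace, you will find:

java.lang.IllegalArgumentException: requirement failed: The columns of A don't match the number of elements of x. A: 6, x: 19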

This means the number of features differs between the training and test data (6 features in training, 19 in test).
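The two numbers in the error are the vocabulary sizes CountVectorizer learned on each dataset. A pure-Python approximation of `RegexTokenizer(pattern="\\W")`, the question's stop-word list and `CountVectorizer(minDF=5)` reproduces them: only six terms occur in at least five of the 18 training descriptions, while testData repeats three descriptions five times each, so all 19 of their distinct terms pass the threshold. The helpers `tokenize` and `learn_vocabulary` are hypothetical stand-ins, not Spark's implementation, and they sort the vocabulary alphabetically where Spark orders it by term frequency.

```python
import re
from collections import Counter

# Stop words from the question's StopWordsRemover.
STOP_WORDS = {"http", "https", "amp", "rt", "t", "c", "the"}


def tokenize(text):
    """Approximate RegexTokenizer(pattern="\\W") + StopWordsRemover:
    lowercase, split on non-word characters, drop empty tokens and stop words."""
    return [w for w in re.split(r"\W", text.lower()) if w and w not in STOP_WORDS]


def learn_vocabulary(docs, min_df):
    """Approximate CountVectorizer(minDF=min_df): keep terms found in >= min_df documents.
    Sorted alphabetically here; Spark orders its vocabulary by term frequency."""
    doc_freq = Counter(term for doc in docs for term in set(tokenize(doc)))
    return sorted(term for term, df in doc_freq.items() if df >= min_df)


train_docs = [
    " 7 Shakra Bracelet 7 chakra bracelet, in blue or black.",
    " Anchor Bracelet Mens Black leather bracelet with gold or silver anchor for men.",
    " Bangle Bracelet Gold bangle bracelet with studded jewels.",
    " Boho Bangle Bracelet Gold boho bangle bracelet with multicolor tassels.",
    " Boho Earrings Turquoise globe earrings on 14k gold hooks.",
    " Choker with Bead Black choker necklace with 14k gold bead.",
    " Choker with Triangle Black choker with silver triangle pendant.",
    " Dainty Gold Necklace Dainty gold necklace with two pendants.",
    " Dreamcatcher Pendant Necklace Turquoise beaded dream catcher necklace. Silver feathers adorn this beautiful dream catcher, which move and twinkle as you walk.",
    " Galaxy Earrings One set of galaxy earrings, with sterling silver clasps.",
    " Gold Bird Necklace 14k Gold delicate necklace, with bird between two chains.",
    " Gold Elephant Earrings Small 14k gold elephant earrings, with opal ear detail.",
    " Guardian Angel Earrings Sterling silver guardian angel earrings with diamond gemstones.",
    " Moon Charm Bracelet Moon 14k gold chain friendship bracelet.",
    " Origami Crane Necklace Sterling silver origami crane necklace.",
    " Pretty Gold Necklace 14k gold and turquoise necklace. Stunning beaded turquoise on gold and pendant filled double chain design.",
    " Silver Threader Necklace Sterling silver chain thread through circle necklace.",
    " Stylish Summer Necklace Double chained gold boho necklace with turquoise pendant.",
]

# testData repeats these three descriptions five times each (15 rows).
test_docs = [
    " 7 Shakra Bracelet 7 chakra bracelet, in blue or black.",
    " Anchor Bracelet Mens Black leather bracelet with gold or silver anchor for men.",
    " Bangle Bracelet Gold bangle bracelet with studded jewels.",
] * 5

train_vocab = learn_vocabulary(train_docs, min_df=5)
test_vocab = learn_vocabulary(test_docs, min_df=5)
print(len(train_vocab), len(test_vocab))  # 6 19
print(train_vocab)  # ['14k', 'bracelet', 'gold', 'necklace', 'silver', 'with']
```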

Training data

+--------+--------------------+--------------------+--------------------+--------------------+-----+
|      id|         description|               words|            filtered|            features|label|
+--------+--------------------+--------------------+--------------------+--------------------+-----+
|Bracelet| 7 Shakra Bracele...|[7, shakra, brace...|[7, shakra, brace...|       (6,[3],[2.0])|  1.0|
|Bracelet| Anchor Bracelet ...|[anchor, bracelet...|[anchor, bracelet...|(6,[0,2,3,4],[1.0...|  1.0|

Test data

+---+--------------------+--------------------+--------------------+--------------------+-----+
| id|         description|               words|            filtered|            features|label|
+---+--------------------+--------------------+--------------------+--------------------+-----+
| 10| 7 Shakra Bracele...|[7, shakra, brace...|[7, shakra, brace...|(19,[0,1,2,3,10,1...|  8.0|
| 11| Anchor Bracelet ...|[anchor, bracelet...|[anchor, bracelet...|(19,[0,2,3,4,5,7,...|  4.0|

You are encoding the training and test data separately, which produces incompatible encodings.

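You need to start from a single combined dataset (trainData.union(testData)) in which the testData rows have no labels. Encode this dataset by transforming it with the pipeline, then split the data back into training and test sets, train the model on the training part and make predictions on the test part.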