对TfidfVectorizer.fit_transform的返回结果感到困惑

时间:2018-06-18 09:19:26

标签: scikit-learn nlp tf-idf tfidfvectorizer

我想了解更多有关NLP的信息。我遇到了这段代码。但是在打印结果时,我对TfidfVectorizer.fit_transform的结果感到困惑。我熟悉tfidf是什么,但我无法理解这些数字是什么意思。

import tensorflow as tf 
import numpy as np 
from sklearn.feature_extraction.text import TfidfVectorizer
import os 
import io
import string 
import requests 
import csv 
import nltk
from zipfile import ZipFile 

sess = tf.Session()

batch_size = 100
max_features = 1000

save_file_name = os.path.join('smsspamcollection','SMSSpamCollection.csv')
if os.path.isfile(save_file_name):
    text_data = []
    with open(save_file_name,'r') as temp_output_file:
        reader = csv.reader(temp_output_file)
        for row in reader:
            text_data.append(row)

else:
    zip_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'
    r = requests.get(zip_url)
    z = ZipFile(io.BytesIO(r.content))
    file = z.read('SMSSpamCollection')

    #Format data 
    text_data = file.decode()
    text_data = text_data.encode('ascii',errors='ignore')
    text_data = text_data.decode().split('\n')
    text_data = [x.split('\t') for x in text_data if len(x)>=1]

    #And write to csv 
    with open(save_file_name,'w') as temp_output_file:
        writer = csv.writer(temp_output_file)
        writer.writerows(text_data)

texts = [x[1] for x in text_data]
target = [x[0] for x in text_data]
target = [1 if x=='spam' else 0 for x in target]


#Normalize the text
texts = [x.lower() for x in texts] #lower
texts = [''.join(c for c in x if c not in string.punctuation) for x in texts] #remove punctuation
texts = [''.join(c for c in x if c not in '0123456789') for x in texts] #remove numbers
texts = [' '.join(x.split()) for x in texts] #trim extra whitespace

def tokenizer(text):
    words = nltk.word_tokenize(text)
    return words

tfidf = TfidfVectorizer(tokenizer=tokenizer, stop_words='english', max_features=max_features)
sparse_tfidf_texts = tfidf.fit_transform(texts)
print(sparse_tfidf_texts)

输出是:

  

(0,630)0.37172623140154337(0,160)0.36805562944957004(0,   38)0.3613966215413548(0,545)0.2561101665717327(0,   326)0.2645280991765623(0,967)0.3277447602873963(0,   421)0.3896274380321477(0,227)0.28102915589024796(0,   323)0.22032541100275282(0,922)0.2709848154866997(1,   577)0.4007895093299793(1,425)0.5970064521899725(1,   943)0.6310763941180291(1,878)0.29102173465492637(2,   282)0.1771481430848552(2,243)0.5517018054305785(2,   955)0.2920174942032025(2,138)0.30143666813167863(2,   946)0.2269933441326121(2,165)0.3051095293405041(2,   268)0.2820392223588522(2,780)0.24119626642264894(2,   823)0.1890454397278538(2,674)0.256251970757827(2,   874)0.19343834015314287 ::(5569,648)0.24171652492226922
  (5569,123)0.23011909339432202(5569,957)0.24817919217662862
  (5569,549)0.28583789844730134(5569,863)0.3026729783085827
  (5569,844)0.20228305447951195(5569,146)0.2514415602877767
  (5569,595)0.2463259875380789(5569,511)0.3091904754885042
  (5569,230)0.2872728684768659(5569,638)0.34151390143548765
  (5569,83)0.3464271621701711(5570,370)0.4199910200421362
  (5570,46)0.48234172093857797(5570,317)0.4171646676697801
  (5570,281)0.6456993475093024(5572,282)0.25540827228532487
  (5572,385)0.36945842040023935(5572,448)0.25540827228532487
  (5572,931)0.3031800542518209(5572,192)0.29866989620926737
  (5572,303)0.43990016711221736(5572,87)0.45211284173737176
  (5572,332)0.3924202767503492(5573,866)1.0

如果有人可以解释输出,我会非常高兴。

1 个答案:

答案 0 :(得分:5)

请注意,您正在打印稀疏矩阵,因此与打印标准密集矩阵相比,输出看起来不同。见下面的主要组成部分:

  • 元组代表:(document_id, token_id)
  • 元组后面的值表示给定文档中给定标记的tf-idf分数
  • 不存在的元组的tf-idf分数为0

如果您想查找token_id对应的令牌,请检查get_feature_names方法。