UCI dataset InvalidArgumentError: assertion failed: [Label IDs must < n_classes] [Condition x < y did not hold element-wise:]

Date: 2019-07-31 17:41:37

Tags: python tensorflow-estimator

import numpy as np
import pandas as pd
import tensorflow as tf

Download the UCI dataset from https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data

INPUT_PATH = "./breast-cancer-wisconsin.data"
OUTPUT_PATH = "./breast-cancer-wisconsin.csv"

Headers = ["CodeNumber", "ClumpThickness", "UniformityCellSize", 
            "UniformityCellShape", "MarginalAdhesion", 
            "SingleEpithelialCellSize", "BareNuclei", "BlandChromatin", 
            "NormalNucleoli", "Mitoses", "CancerType"]  

Handle missing values: replace "?" with NaN, drop the rows that contain it, and drop any unwanted columns

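def handle_missings(dataset, missing_vals, drop_cols=None):
    # Replace placeholder values (e.g. '?') with NaN, then drop those rows
    dataset.replace(missing_vals, np.nan, inplace=True)
    dataset.dropna(inplace=True)
    # Drop the requested columns, if any
    if drop_cols is not None:
        dataset.drop(drop_cols, axis=1, inplace=True)
    return dataset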

def data_file_to_csv():
    # Convert the raw .data file (no header row) to a CSV with named columns
    dataframe = pd.read_csv(INPUT_PATH, names=Headers)
    dataframe.to_csv(OUTPUT_PATH, index=False)  # save the DataFrame as a CSV file
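As an aside, `pd.read_csv` can treat `"?"` as missing at read time through its `na_values` parameter, which would fold the CSV conversion and the NaN replacement into a single step. The sketch below assumes the raw file has no header row (hence `names=`) and that `CodeNumber` is an ID rather than a feature; `load_clean` and the sample rows are hypothetical, made up only to exercise the code.

```python
import io

import numpy as np
import pandas as pd

# Column names as listed in the question.
HEADERS = ["CodeNumber", "ClumpThickness", "UniformityCellSize",
           "UniformityCellShape", "MarginalAdhesion",
           "SingleEpithelialCellSize", "BareNuclei", "BlandChromatin",
           "NormalNucleoli", "Mitoses", "CancerType"]


def load_clean(source):
    """Hypothetical helper: read the raw file, treating '?' as missing."""
    df = pd.read_csv(source, names=HEADERS, na_values="?")
    # Drop rows with a missing value, then the non-feature ID column
    df = df.dropna().drop(columns=["CodeNumber"])
    # BareNuclei is parsed as float because of the NaNs; cast it back to int
    return df.astype({"BareNuclei": np.int64})


# Hypothetical sample rows in the same layout as the UCI file (not real records).
sample = io.StringIO(
    "1,5,1,1,1,2,1,3,1,1,2\n"
    "2,8,10,10,8,7,?,9,7,1,4\n"
    "3,3,1,1,1,2,2,3,1,1,2\n"
)
clean = load_clean(sample)
print(clean.shape)  # (2, 10)
```

The row with `?` in `BareNuclei` is dropped, and `BareNuclei` ends up as an integer column without a separate `astype` call on strings.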

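def train_input_fn(df):
    # Estimator input function that reads features and labels from the pandas DataFrame
    feature_headers = Headers[1:-1]
    target_label = Headers[-1]
    # TF 1.x API (TF 2.x: tf.compat.v1.estimator.inputs.pandas_input_fn)
    return tf.estimator.inputs.pandas_input_fn(
        x=df[feature_headers], y=df[target_label], num_epochs=None, shuffle=True)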
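For reference, an Estimator input function is expected to yield a `(features, labels)` pair, where `features` is a dict keyed by the same names passed to `tf.feature_column.numeric_column`. A plain-pandas sketch of that split (the helper `to_features_and_labels` and the two-row frame are hypothetical) shows that the labels pass through untouched, so whatever values `CancerType` holds are exactly what the classifier receives:

```python
import pandas as pd


def to_features_and_labels(df, feature_headers, target_label):
    """Hypothetical helper: split a DataFrame into a dict of feature arrays
    keyed by column name, plus a label array."""
    features = {name: df[name].to_numpy() for name in feature_headers}
    labels = df[target_label].to_numpy()
    return features, labels


# Hypothetical two-row frame using a subset of the question's columns.
df = pd.DataFrame({"ClumpThickness": [5, 3],
                   "Mitoses": [1, 1],
                   "CancerType": [2, 4]})
features, labels = to_features_and_labels(df, ["ClumpThickness", "Mitoses"], "CancerType")
print(sorted(features))  # ['ClumpThickness', 'Mitoses']
print(labels)            # [2 4]
```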

def main():  # main function
    data_file_to_csv()   
    dataset = pd.read_csv(OUTPUT_PATH)   
    dataset = handle_missings(dataset, '?', Headers[0])   # replace '?' with NaN, drop those rows, and drop the CodeNumber ID column

    feature_headers = Headers[1:-1]  
    target_label = Headers[-1]  

    dataset['BareNuclei'] = dataset['BareNuclei'].astype(np.int64)  # column was read as strings because of the '?' values

    feat_columns = [tf.feature_column.numeric_column(k) for k in 
                                           feature_headers]    

    model = tf.estimator.LinearClassifier(feature_columns=feat_columns,  
                             model_dir = None, n_classes=2)  

    model.train(input_fn=train_input_fn(dataset), steps=100)  

Running the code fails with InvalidArgumentError (see traceback above): assertion failed: [Label IDs must
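A likely cause: with `n_classes=2`, `LinearClassifier` expects integer label IDs in the range 0 to 1, but the class column of this UCI file encodes benign as 2 and malignant as 4, so every label fails the `x < y` check. One fix is to re-encode the labels to contiguous IDs before training. A minimal pure-Python sketch of that mapping (`encode_labels` is a hypothetical helper, not a TensorFlow API):

```python
def encode_labels(values, n_classes):
    """Hypothetical helper: map arbitrary class values to IDs 0..n_classes-1."""
    classes = sorted(set(values))
    if len(classes) > n_classes:
        raise ValueError(f"found {len(classes)} classes, expected at most {n_classes}")
    # Smallest class value gets ID 0, the next gets ID 1, and so on
    mapping = {cls: idx for idx, cls in enumerate(classes)}
    return [mapping[v] for v in values], mapping


ids, mapping = encode_labels([2, 4, 2, 2, 4], n_classes=2)
print(mapping)  # {2: 0, 4: 1}
print(ids)      # [0, 1, 0, 0, 1]
```

On the DataFrame itself the same idea fits in one line, e.g. `dataset[target_label] = dataset[target_label].map({2: 0, 4: 1})`, placed before the `model.train(...)` call. Alternatively, `LinearClassifier` accepts a `label_vocabulary` argument, but it expects string labels, so the column would need converting first.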
