Check failed: NDIMS == dims() (2 vs. 1) when I build an SVM model

Time: 2017-04-26 15:28:08

Tags: tensorflow deep-learning

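When I build an SVM model with tf.learn, it aborts with the "Check failed" error shown at the end of the output below. I have run LinearClassifier on the same data and it succeeded. I don't know how to fix this error. Can anyone help?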

Here is the code that reproduces the error:

import tensorflow as tf
import pandas as pd
from tensorflow.contrib.learn.python.learn.estimators import svm


detailed_occupation_recode = tf.contrib.layers.sparse_column_with_hash_bucket(
    column_name='detailed_occupation_recode', 
    hash_bucket_size = 1000
)
education = tf.contrib.layers.sparse_column_with_hash_bucket(
    column_name='education',
    hash_bucket_size=1000
)
# Continuous base columns
age = tf.contrib.layers.real_valued_column('age')
wage_per_hour = tf.contrib.layers.real_valued_column('wage_per_hour')



columns = ['age', 'detailed_occupation_recode', 'education', 'wage_per_hour','label']
FEATURE_COLUMNS = [
    # age, age_buckets, class_of_worker, detailed_industry_recode,
    age, detailed_occupation_recode, education, wage_per_hour

]


LABEL_COLUMN = 'label'

CONTINUOUS_COLUMNS = ['age', 'wage_per_hour']

CATEGORICAL_COLUMNS = ['detailed_occupation_recode','education']


df_train = pd.DataFrame([[12,'12','7th and 8th grade',40,'- 50000'],
                [40,'45','7th and 8th grade',40, '50000+'],
                [50,'50','10th grade',40,'50000+'],
                [60,'30','7th and 8th grade',40,'- 50000']],
                columns=['age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 'label'])


df_test = pd.DataFrame([[12,'12','7th and 8th grade',40,'- 50000'],
                [40,'45','7th and 8th grade',40, '50000+'],
                [50,'50','10th grade',40,'50000+'],
                [60,'30','7th and 8th grade',40,'- 50000']],
                columns=['age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 'label'])
df_train[LABEL_COLUMN] = (df_train[LABEL_COLUMN].apply(lambda x: '+' in x)).astype(int)
df_test[LABEL_COLUMN] = (df_test[LABEL_COLUMN].apply(lambda x: '+' in x)).astype(int)
dtypess = df_train.dtypes


def input_fn(df):
    continuous_cols = {k: tf.constant(df[k].values) for k in CONTINUOUS_COLUMNS}
    categorical_cols = {k: tf.SparseTensor(
        indices=[[i, 0] for i in range(df[k].size)],
        values=df[k].values,
        dense_shape=[df[k].size, 1]) for k in CATEGORICAL_COLUMNS}    
    feature_cols = dict(continuous_cols.items() + categorical_cols.items())
    feature_cols['example_id'] = tf.constant([str(i+1) for i in range(df['age'].size)])
    label = tf.constant(df[LABEL_COLUMN].values)
    return feature_cols, label

def train_input_fn():
    return input_fn(df_train)

def eval_input_fn():
    return input_fn(df_test)

model_dir = '../svm_model_dir'

model = svm.SVM(example_id_column='example_id', feature_columns=FEATURE_COLUMNS, model_dir=model_dir)
model.fit(input_fn=train_input_fn, steps=10)
results = model.evaluate(input_fn=eval_input_fn, steps=1)
for key in sorted(results):
    print("%s: %s" % (key, results[key]))

And the full text output:

WARNING:tensorflow:The default value of combiner will change from "sum" to "sqrtn" after 2016/11/01.
WARNING:tensorflow:The default value of combiner will change from "sum" to "sqrtn" after 2016/11/01.
WARNING:tensorflow:tf.variable_op_scope(values, name, default_name) is deprecated, use tf.variable_scope(name, default_name, values)
WARNING:tensorflow:Rank of input Tensor (1) should be the same as output_rank (2) for column. Will attempt to expand dims. It is highly recommended that you resize your input, as this behavior may change.
WARNING:tensorflow:Rank of input Tensor (1) should be the same as output_rank (2) for column. Will attempt to expand dims. It is highly recommended that you resize your input, as this behavior may change.
WARNING:tensorflow:From /Users/burness/anaconda/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/head.py:882: hinge_loss (from tensorflow.contrib.losses.python.losses.loss_ops) is deprecated and will be removed after 2016-12-30.
Instructions for updating:
Use tf.losses.hinge_loss instead.
WARNING:tensorflow:From /Users/burness/anaconda/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/head.py:1362: scalar_summary (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2016-11-30.
Instructions for updating:
Please switch to tf.summary.scalar. Note that tf.summary.scalar uses the node name instead of the tag. This means that TensorFlow will automatically de-duplicate summary names based on the scope they are created in. Also, passing a tensor or list of tags to a scalar summary op is no longer supported.
WARNING:tensorflow:From /Users/burness/anaconda/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/head.py:882: hinge_loss (from tensorflow.contrib.losses.python.losses.loss_ops) is deprecated and will be removed after 2016-12-30.
Instructions for updating:
Use tf.losses.hinge_loss instead.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
F tensorflow/core/framework/tensor_shape.cc:36] Check failed: NDIMS == dims() (2 vs. 1)Asking for tensor of 2 dimensions from a tensor of 1 dimensions
[1]    66225 abort      python simple-tf-svm.py

1 answer:

Answer 0 (score: 0)

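Following up in case anyone else runs into this issue: real_valued_columns currently must be vectors with a batch dimension, that is, tensors of shape [batch_size, 1], while in your example each continuous feature is fed as a rank-1 tensor of scalars. This matches the fatal message, which asks for a tensor of 2 dimensions from a tensor of 1 dimension. There is a GitHub issue tracking the creation of a better error message.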
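
A minimal sketch of the shape change described above, written in plain Python so it runs without TensorFlow. The helper names add_feature_dim and rank are hypothetical and exist only for illustration. Applying the same idea inside the question's input_fn (for example by reshaping df[k].values to [N, 1] before passing it to tf.constant) is an assumption and has not been verified against this version of tf.contrib.learn.

```python
def add_feature_dim(values):
    """Wrap each scalar in its own row: N values -> N rows x 1 column."""
    return [[v] for v in values]


def rank(nested):
    """Nesting depth of a regular, non-empty nested list (its tensor rank)."""
    depth = 0
    while isinstance(nested, list):
        depth += 1
        nested = nested[0]
    return depth


# The 'age' column from the question's df_train.
ages = [12, 40, 50, 60]          # rank 1, shape [4]: what input_fn feeds today
batched = add_feature_dim(ages)  # rank 2, shape [4, 1]: batch x feature

print(rank(ages), rank(batched))
print(batched)

# Assumed (unverified) equivalent inside the question's input_fn:
#   continuous_cols = {k: tf.constant(df[k].values.reshape(-1, 1))
#                      for k in CONTINUOUS_COLUMNS}
```

The categorical columns in the question are already built as SparseTensor values with dense_shape [N, 1], so only the continuous columns would need this extra dimension.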