可视化tf.contrib.learn.LinearClassifier权重

时间:2017-04-03 13:14:39

标签: python tensorflow

我使用着名的泰坦数据集玩过tensorflow的LinearClassifier数据。

(我的问题本身在底部 - 这是模型本身的所有代码)

所以我有我的专栏:

CONTINUOUS_COLS = ['Age', 'Fare']
CATEGORICAL_COLS = ['Sex', 'Pclass', 'Title']
LABELS_COL = 'Survived'

sex_col = sparse_column_with_keys('Sex', keys=['male', 'female'])
title_col = sparse_column_with_hash_bucket('Title', 10)
fare_class_col = sparse_column_with_keys('Pclass', keys=['1','2','3'])
age_col = real_valued_column('Age')
fare_col = real_valued_column('Fare')

我的输入功能:

def create_input_fn(df):
    continous_features = {k: tf.constant(df[k].values) for k in CONTINUOUS_COLS}
    categorical_features = {k : tf.SparseTensor(
        indices=[[0,i] for i in range(df[k].size)],
        values=df[k].values,
        dense_shape=[df[k].size, 1]
    ) for k in CATEGORICAL_COLS}
    feature_cols = {**continous_features, **categorical_features}
    labels = tf.constant(df[LABELS_COL].values)
    return feature_cols, labels

和我的模特:

clf = LinearClassifier(feature_columns=[sex_col, fare_class_col, age_col, fare_col, title_col],
    optimizer=tf.train.FtrlOptimizer(
        learning_rate=0.5,
        l1_regularization_strength=1.0,
        l2_regularization_strength=1.0),
    model_dir=tempfile.TemporaryDirectory().name)

现在当我运行模型时,它确实是okaish,我想查看模型的权重以更好地可视化它们。

所以clf.weights_存在(虽然它被列为已弃用),所以我只是手动将它们拉出来:

for var in clf.get_variable_names():
    if var.endswith('weights'):
        print(f'{var} -> {clf.get_variable_value(var)}')

我得到了一些不错的结果:

linear/Pclass/weights -> [[ 0.        ]
 [ 0.        ]
 [-0.01772301]]
linear/Sex/weights -> [[-0.07285357]
 [ 0.        ]]
linear/Title/weights -> [[ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [-0.03760524]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]]

现在我的问题是 - 如何取出最初使用的? 所以我可以更好地匹配数字,例如与性别相关 - 键最初映射到男性/女性。

谢谢!

1 个答案:

答案 0 :(得分:0)

sparse_column_with_keys
sex_col.lookup_config.keys # ('male', 'female')

类似于:

matched = {}
weights = clf.get_variable_value('linear/Sex/weights')  # np array
for index, key in enumerate(sex_col.lookup_config.keys):
    matched[key] = weights[index]

并且在您dir(sex_col.lookup_config)时还有一些其他有趣的属性,并检查方法文档字符串:Source for SparseColumn Feature classes https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/feature_column.py

我没有找出sparse_column_with_hash_bucket

的地图

如果教程中有tf.contrib.layers.bucketized_column个age_buckets: age_buckets.boundaries