在线性权重索引中查找string_to_hash_bucket的索引Tensorflow

时间:2016-09-08 00:22:02

标签: python tensorflow

我正在使用LinearClassifier类在二元分类模型上使用张量流。我基于分类的一个功能是名为hat的列:

hat = tf.contrib.layers.sparse_column_with_hash_bucket("hat", hash_bucket_size=1000)

在我初始化模型并在tf.Session()中完成拟合之后:

with tf.Session() as sess:    
    m = tf.contrib.learn.LinearClassifier(feature_columns=hat...)
    m.fit(...)

我想在训练模型后检查每个帽子类别的重量。

帽子标签只是由不同的字符串给出。在训练模型后,我想找到与每个帽子标签相关的重量。然而,为了将权重与特定的帽子进行比较,我需要知道帽子标签被扔进哪个哈希桶。我的帽子标签之一是“tb”。我可以使用函数找到索引的内容:

tf.string_to_hash_bucket(tf.cast("tb",tf.string), 1000)

然后我可以遍历这里返回的权重:

for i,n in enumerate(m.linear_weights_["linear/hat_weights"]):
    print i, n

给了我:

linear/hat_weights
   0 [-0.147]
   ...

我的问题是,没有任何具有重要(abs(x)> 0.0005)权重的索引对应于我从数据集中所有帽子标签上的string_to_hash_bucket获得的任何哈希桶ID。

最后我的问题是:

我认为string_to_hash_bucket id应该对应m.linear_weights _ [“linear / hat_weights”]返回的相应数组的索引吗?

如果没有,我怎样才能获得正确的身份证?在线性模型中,是否有更简单的方法来检查特征列张量的权重,包括稀疏和实值(均不包含在.linear_weights_中)?

非常感谢!

1 个答案:

答案 0 :(得分:1)

另一种检查黑盒重量(因此保证可以工作)的方法是在一个只打开一个帽子功能的示例上评估整个模型。然后,您可以看到模型为每个要素指定的权重。