什么TensorFlow hash_bucket_size很重要

时间:2017-07-20 03:58:27

标签: tensorflow

我正在创建一个稀疏列的DNNclassifier。训练数据如下所示,

samples        col1                         col2          price label
  eg1    [[0,1,0,0,0,2,0,1,0,3,...]    [[0,0,4,5,0,...]    5.2    0
  eg2    [0,0,...]                     [0,0,...]            0     1
  eg3    [0,0,...]]                    [0,0,...]            0     1

以下代码段可以成功运行,

import tensorflow as tf

sparse_feature_a = tf.contrib.layers.sparse_column_with_hash_bucket('col1', 3, dtype=tf.int32)
sparse_feature_b = tf.contrib.layers.sparse_column_with_hash_bucket('col2', 1000, dtype=tf.int32)

sparse_feature_a_emb = tf.contrib.layers.embedding_column(sparse_id_column=sparse_feature_a, dimension=2)
sparse_feature_b_emb = tf.contrib.layers.embedding_column(sparse_id_column=sparse_feature_b, dimension=2)
feature_c = tf.contrib.layers.real_valued_column('price')

estimator = tf.contrib.learn.DNNClassifier(
    feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb, feature_c],
    hidden_units=[5, 3],
    n_classes=2,
    model_dir='./tfTmp/tfTmp0')

# Input builders
def input_fn_train(): # returns x, y (where y represents label's class index).
    features = {'col1': tf.SparseTensor(indices=[[0, 1], [0, 5], [0, 7], [0, 9]],
                                  values=[1, 2, 1, 3],
                                  dense_shape=[3, int(250e6)]),
                'col2': tf.SparseTensor(indices=[[0, 2], [0, 3]],
                                    values=[4, 5],
                                    dense_shape=[3, int(100e6)]),
                        'price': tf.constant([5.2, 0, 0])}
    labels = tf.constant([0, 1, 1])
    return features, labels

estimator.fit(input_fn=input_fn_train, steps=100)

但是,我对这句话有疑问,

sparse_feature_a = tf.contrib.layers.sparse_column_with_hash_bucket('col1', 3, dtype=tf.int32)

其中3表示 hash_bucket_size = 3 ,但此稀疏张量包含4个非零值,

'col1': tf.SparseTensor(indices=[[0, 1], [0, 5], [0, 7], [0, 9]],
                              values=[1, 2, 1, 3],
                              dense_shape=[3, int(250e6)])

似乎 has_bucket_size 在这里什么都不做。无论你的稀疏张量有多少非零值,你只需要用整数>设置它。 1,它工作正常。

我知道我的理解可能不对。任何人都可以解释 has_bucket_size 的工作原理吗?非常感谢!

1 个答案:

答案 0 :(得分:2)

hash_bucket_size通过获取原始索引,将它们散列到指定大小的空间,并使用散列索引作为特征来工作。

这意味着您可以在了解所有可能的指数之前指定您的模型,但代价可能是某些指数可能会发生冲突。