我正在尝试使用最新的估算器API在tensorflow中实现类似于word2vec的模型。 我遇到的问题是当我尝试评估模型时。模型本身应该是推荐系统的基础。现在,我想要使用的评估指标是Hitrate指标,如下所示:
现在我已经设置了这样的模型:
# Map Embeddings Ids to One Hot Tensors
ref_embedding_ids = tf.feature_column.categorical_column_with_identity(
key='Reference',
num_buckets=params['dict_size'],
default_value=0
)
# Map One Hot Tensors to Dense Embeddings
ref_embeddings = tf.feature_column.embedding_column(
ref_embedding_ids,
dimension=params['embedding_size']
)
# Actually create the input to the model
input_layer = tf.feature_column.input_layer(features, feature_columns=[ref_embeddings])
在估算器API教程中完成。现在我理解了这段代码,input_layer
已经只包含了使用键Reference
在功能字典中引用的嵌入。这很棒,因为我们不需要将所有嵌入保留在内存中。
但是现在如果我想在评估模式中计算邻域,我需要访问所有可能输入的嵌入向量来计算相似性。但是我没有任何参考,因为管道设置只加载了必要的部件。我已经尝试找出包含嵌入的变量的名称,并使用tf.get_variable
显式加载,但这也不起作用。
所以我的问题是我如何计算给定嵌入id的邻域?
此外,在能够计算邻域之后,我需要使用度量来继续计算整个评估集的度量,因为对来自数据集的每个批次调用函数。但我想这是一个不同的问题,我只是提到它的背景。
答案 0 :(得分:0)
我可以通过将cols_to_vars
函数使用tf.feature_column.input_layer
参数来使此工作正常。
基本上,您需要将空字典传递给tf.feature_column.input_layer,该字典将填充在输入层中生成的变量。
下面的示例仅包含model_fn函数的模型范围。希望对您有帮助。
def my_model_fn(features, labels, mode, params):
with tf.name_scope('model'):
num_vocabulary = len(params['vocabulary'])
#create embedding vectors for history feature
vocabulary_lookup = tf.contrib.lookup.index_table_from_tensor(
name='vocabulary_lookup',
mapping=params['vocabulary'],
default_value=-1,
num_oov_buckets=1
)
#print("lookup: {}".format(vocabulary_lookup))
#create a bais matrix
rank_biases = tf.get_variable(name='rank_biases', shape=[num_vocabulary])
# input layer
input_variable_ref = {}
net = tf.feature_column.input_layer(features, params['feature_columns'], cols_to_vars=input_variable_ref)
embedding_metrix = input_variable_ref[params['feature_columns'][0]][0]
print("embedding metrix => {}".format(embedding_metrix))
# hidden layers
for units in params['hidden_units']:
net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
# output layer
with tf.name_scope('DNN_output'):
logits = tf.layers.dense(net, params['out_layer_dim'], activation=None)
#print("logits shape: {}".format(logits.shape))