Tensorflow中采样器的行为不明确

时间:2016-08-15 20:09:40

标签: tensorflow

对于在tensorflow中实现的采样器,例如tf.nn.fixed_unigram_candidate_sampler。该行为在文档中没有明确定义。例如,我希望从采样池中排除true_classes中指定的标签,并对每个批次进行采样。但根据我的实验,以上都不是真的。

请考虑以下代码:

import tensorflow as tf

labels_matrix = tf.reshape(tf.constant([1, 2, 3, 4], dtype=tf.int64), [-1, 1])

sampled_ids, _, _ = tf.nn.fixed_unigram_candidate_sampler(
true_classes = labels_matrix,
num_true = 1,
num_sampled = 1,
unique = True,
range_max = 5,
distortion = 0.0,
unigrams = range(5)
)

init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
print sess.run([sampled_ids])

输出可以是3,实际上属于真实类的集合。 - 此外,输出具有维度[1],这基本上意味着采样只进行一次,而不是每批次。

有人可以帮忙解释一下吗?

1 个答案:

答案 0 :(得分:0)

documentation for fixed_unigram_candidate_sampler确实提到可以对真实标签进行抽样。您在代码中标记为_的其中一项事实上是采样真实标签的预期比率。