对于在tensorflow中实现的采样器,例如tf.nn.fixed_unigram_candidate_sampler。该行为在文档中没有明确定义。例如,我希望从采样池中排除true_classes中指定的标签,并对每个批次进行采样。但根据我的实验,以上都不是真的。
请考虑以下代码:
import tensorflow as tf
labels_matrix = tf.reshape(tf.constant([1, 2, 3, 4], dtype=tf.int64), [-1, 1])
sampled_ids, _, _ = tf.nn.fixed_unigram_candidate_sampler(
true_classes = labels_matrix,
num_true = 1,
num_sampled = 1,
unique = True,
range_max = 5,
distortion = 0.0,
unigrams = range(5)
)
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
print sess.run([sampled_ids])
输出可以是3,实际上属于真实类的集合。 - 此外,输出具有维度[1],这基本上意味着采样只进行一次,而不是每批次。
有人可以帮忙解释一下吗?
答案 0 :(得分:0)
documentation for fixed_unigram_candidate_sampler确实提到可以对真实标签进行抽样。您在代码中标记为_
的其中一项事实上是采样真实标签的预期比率。