我有一个嵌入张量:
phi = tf.get_variable("phi", [self.X, self.Y],
initializer = tf.random_normal_initializer(0, 0.1))
和索引的占位符:
k = tf.placeholder(tf.int32, [None])
我有以下代码,在索引值正确的情况下返回嵌入,在-1的情况下,返回带零的张量。
kappa = tf.where(tf.not_equal(k,tf.constant(-1)),
tf.nn.embedding_lookup(phi, k,name="k_cluster_look"),
tf.zeros([1,self.Y]))
运行命令:
x = session.run(kappa, feed_dict={k:[-1]})
但是,我收到错误:
InvalidArgumentError(参见上面的回溯):indices [0] = -1不是 在[0,15177] [[节点:k_cluster_look = GatherV2 [Taxis = DT_INT32, Tindices = DT_INT32,Tparams = DT_FLOAT, _device =" / job:localhost / replica:0 / task:0 / device:CPU:0"](phi / read,_arg_neg_item_0_0,k_cluster_look / axis)]]
我不明白为什么where
条件不会阻止无效(-1)值被视为嵌入查找的索引。
(使用最新版本的TensorFlow)