我目前有单拍编码,我想使用嵌入。但是当我打电话时
embed=tf.nn.embedding_lookup(embeddings, train_data)
print(embed.get_shape())
嵌入数据形状(11,32,729,128)
这个形状应该是(11,32,128),但它给了我错误的尺寸,因为train_data是onehot编码的。
train_data2=tf.matmul(train_data,tf.range(729))
给我错误:
ValueError: Shape must be rank 2 but is rank 3
请帮帮我!感谢。
答案 0 :(得分:2)
对您示例的一个小修复:
encoding_size = 4
one_hot_batch = tf.constant([[0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0]])
one_hot_indexes = tf.matmul(one_hot_batch, np.array([range(encoding_size)],
dtype=np.int32).T)
with tf.Session() as session:
print one_hot_indexes.eval()
另一种方式:
batch_size = 3
one_hot_batch = tf.constant([[0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0]])
one_hot_indexes = tf.where(tf.not_equal(one_hot_batch, 0))
one_hot_indexes = one_hot_indexes[:, 1]
one_hot_indexes = tf.reshape(one_hot_indexes, [batch_size, 1])
with tf.Session() as session:
print one_hot_indexes.eval()
两种情况都有结果:
[[3]
[1]
[0]]