如何在tf.scatter_nd中使用tf.meshgrid?

时间:2019-03-17 10:55:46

标签: python tensorflow

我正在尝试使用形状为[I, B, num_topics, vocab_size_1]的映射张量将形状为[I, B, vocab_size_2]的张量映射为[num_topics, vocab_size_1],其中每个条目都指向vocab_size_2中的索引,其中vocab_size_1中的条目应该会显示。

这是一个例子:

# Our mapping file of shape [num_topics=2, vocab_size_1=2]
mapping = [[0, 2], [1, 3]]  # ie. the [0, 1] entry should go into index 2 
mapping = np.asarray(mapping)


# Our source file of shape [I=1, B=1, num_topics=2, vocab_size_1=2]
source = np.arange(4).reshape((1, 1, 2, 2))


# Our target should have shape [I=1, B=1, vocab_size_2=4]
...

如果我理解正确,那么可以将tf.scatter_ndtf.meshgrid一起使用以生成适当的映射:

# In this example the indices tensor should look like this:
# [[[[[ [0, 0, 0] vocab=0], [[0, 0, 2] vocab=1] topic=0], [ [ [0, 0, 1] vocab=0], [[0, 0, 3] vocab=1] topic=1] b=0] i=0]]

# Of shape [I=1, B=1, num_topics=2, vocab_size_1=2, 3]

target = tf.scatter_nd(source, indices, shape=[1, 1, 4])

我尝试用tf.meshgrid生成映射,但是在制作索引文件时遇到了问题。有人知道如何解决这个问题吗?我也非常愿意管理这种映射。

谢谢!

1 个答案:

答案 0 :(得分:0)

所以我找到了使用meshgrid,stack和tf.scatter_nd的解决方案。我发布了numpy版本,尽管它可以通过仅用“ tf”替换“ np”直接翻译成tensorflow。请注意,topic_vocab_size = vocab_size_1full_vocab_size = vocab_size_2

import numpy as np
import tensorflow as tf

I = 1
B = 1
NUM_TOPICS = 2
TOPIC_VOCAB_SIZE = 2
FULL_VOCAB_SIZE = 4

# mapping of shape [num_topics=2, topic_vocab_size=2]
mapping = np.arange(NUM_TOPICS * TOPIC_VOCAB_SIZE)
np.random.shuffle(mapping)
mapping = mapping.reshape((NUM_TOPICS, TOPIC_VOCAB_SIZE))

mapping[0, 0] = 0
mapping[1, 0] = 0

print("Mapping:")
print(mapping)

# Source of shape [I=2, B=1, num_topics=2, topic_vocab_size=2]
source = np.arange(I * B * NUM_TOPICS * TOPIC_VOCAB_SIZE).reshape((I, B, NUM_TOPICS, TOPIC_VOCAB_SIZE))

print("Source:")
print(source)
# Now we want to project the source into shape [I=1, B=1, full_vocab_size=4] using mapping
# For tf.scatter_nd, mapping has to be [I=1, B=1, num_topic=2, topic_vocab_size=2, [I=2, B=1, full_vocab=4]]
# So our aim is to have a tensor:
# [[[[[ [0, 0, 0] vocab=0], [[0, 0, 2] vocab=1] topic=0], [ [ [0, 0, 1] vocab=0], [[0, 0, 3] vocab=1] topic=1] b=0] i=0]


ii, bb, _, _ = np.meshgrid(np.arange(I), np.arange(B), np.arange(NUM_TOPICS), np.arange(TOPIC_VOCAB_SIZE), indexing='ij')
# shape: [I, B, num_topics, topic_vocab_size]

mapping = np.expand_dims(np.expand_dims(mapping, axis=0), axis=0)
mapping = np.tile(mapping, [I, B, 1, 1])  # Make mapping of shape [I, B, num_topics, topic_vocab_size]

print(mapping.shape)
print(ii.shape)
print(bb.shape)


idx = np.stack([ii, bb, mapping], axis=-1)

print(idx)

print(idx.shape)

target = tf.scatter_nd(idx, source, shape=[I, B, FULL_VOCAB_SIZE])  # Shape is [I, B, FULL_VOCAB_SIZE]

sess = tf.Session()
with sess.as_default():
    print(target.eval())