如何在keras中重新排列张量的轴?

时间:2019-01-21 05:32:13

标签: tensorflow machine-learning keras deep-learning

我的代码就像,

 user_item_matrix = K.constant(user_item_matrix)
    # Input variables
    user_input = Input(shape=(1,), dtype='int32', name='user_input')
    item_input = Input(shape=(1,), dtype='int32', name='item_input')
    # Embedding layer
    user_rating = Lambda(lambda x: tf.gather(user_item_matrix, tf.to_int32(x), axis=0))(user_input)
    item_rating = Lambda(lambda x: tf.gather(user_item_matrix, tf.to_int32(x), axis=1))(item_input)

其中user_item_matrix是6040 * 3706矩阵。假设user_rating和item_rating的形状为(?,3706)和(?,6040)。但是,实际情况是:

user_rating:  (?, 1, 3706)
item_rating:  (6040, ?, 1)

我对为什么6040发生在应该位于的轴0上感到困惑? (批量大小)。我尝试使用Permute和Reshape解决此问题,但仍然无法正常工作。有解决这个问题的好方法吗?谢谢。

1 个答案:

答案 0 :(得分:0)

您可以看到document关于tf.gather()的信息:

  

产生形状为 params.shape [:axis] +的输出张量   indexs.shape + params.shape [轴+ 1:]

您的参数形状为(6040,3706),索引的形状为(?,1)

因此,当设置params.shape[:0] + indices.shape + params.shape[1:]时,输出形状为() + (?,1) + (3706,) = axis=0

当设置params.shape[:1] + indices.shape + params.shape[2:]时,输出形状为(6040,) + (?,1) + () = axis=1

您可以使用tf.transpose()重新排列轴。

import tensorflow as tf
import keras.backend as K
from keras.layers import Input,Lambda
import numpy as np

user_item_matrix = K.constant(np.zeros(shape=(6040,3706)))
# Input variables
user_input = Input(shape=(1,), dtype='int32', name='user_input')
item_input = Input(shape=(1,), dtype='int32', name='item_input')
# Embedding layer
user_rating = Lambda(lambda x: tf.gather(user_item_matrix, tf.to_int32(x), axis=0))(K.squeeze(user_input,axis=1))
item_rating = Lambda(lambda x: tf.transpose(tf.gather(user_item_matrix, tf.to_int32(x), axis=1),(1,0)))(K.squeeze(item_input,axis=1))
print(user_rating.shape)
print(item_rating.shape)

# print
(?, 3706)
(?, 6040)