如何将张量传递给tf.tensordot axes变量?

时间:2018-01-25 22:44:16

标签: python tensorflow

我不知道这是否是张量流的错误,或者我是否在我的代码中做错了。无论如何,我试图将张量传递到tf.tensordot轴变量,并且在运行代码时似乎得到奇怪的输出。我的代码如下:

<script src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.1/jquery.min.js"></script>

如果我运行此代码,我会从print语句中获得输出:

t1 = tf.ones([2,3,4], dtype=tf.float32)
t2 = tf.ones([2,2,3,4], dtype=tf.float32)

axes = tf.constant([[2],[3]], dtype=tf.int32)
out = tf.tensordot(t1, t2, axes)
print(out)

with tf.Session() as sess:
    sess.run(out)

显然,计算已经完成,程序并没有抱怨任何错误。然而,创建的张量似乎没有任何形状。至少没有注册形状。如果我再次运行上面相同的代码,但现在将axes变量切换为:

Tensor("Tensordot:0", dtype=float32)

我从print语句中得到以下输出:

axes = [[2],[3]]

我试过四处看看是否有任何我做错的事,或者我是否认为这是错误的,但我还没有找到关于这个主题的更多信息。我怀疑这是张量流中的一个错误,如果是,那么这个问题是否有任何变通方法,我仍然可以使用张量来定义轴,或者轴从张量信息中得到它的形状?

为了给你一个我来自哪里的暗示:在我的主算法中,我试图将t1和t2的最后一个维度(当然,当他们有任意等级时)与tensordot相乘

1 个答案:

答案 0 :(得分:0)

在查看tf.tensordot's source code之后,我确实设法解决了我的上一个问题。显然,张量等级可以由tf.rank和tf.get_shape()。n_dims访问。因此,我可以编写以下代码:

t1 = tf.ones([2,3,4], dtype=tf.float32)
t2 = tf.ones([2,2,3,4], dtype=tf.float32)

t1_rank = t1.get_shape().ndims
t2_rank = t2.get_shape().ndims

axes = [[t1_rank-1], [t2_rank-1]]
out = tf.tensordot(t1, t2, axes)
print(out)

with tf.Session() as sess:
    sess.run(out)

从print语句中获取以下输出:

Tensor("Tensordot:0", shape=(2, 3, 2, 2, 3), dtype=float32)

我有一个小小的怀疑,这是一种效率低下的方法,因为张量等级现在必须通过python才能插入轴并最终输入tf.tensordot。然而,这似乎是他们解决排名计算的方式,也是设置图表的一部分,因此它可能只是一次性计算。我不太了解图表的设置方式,所以不要相信我的意思。