带有Tensorflow 2的opt_einsum

时间:2020-04-16 13:54:13

标签: python tensorflow tensorflow2.0

我想用张量流计算多个复杂的张量收缩。我在opt_einsum的“交错”输入表单中有收缩。当我尝试使用张量流进行区分时得到警告

WARNING:tensorflow:AutoGraph could not transform <function _gcd_import at 0x000001ED736ABEA0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Unable to locate the source code of <function _gcd_import at 0x000001ED736ABEA0>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code

一个说明问题的最小示例是

import numpy
import tensorflow as tf
import tensorflow.keras.optimizers as opt
import opt_einsum

a = tf.Variable(numpy.zeros((2, 2)))
b = tf.Variable(numpy.zeros((2, 2)))

opt = opt.Adam()        
opt.minimize(
    tf.function(
        lambda : opt_einsum.contract(a, (1, 2), b, (1, 2), backend = "tensorflow")
    ), var_list = [a, b]
)

我尝试遵循opt_einsum软件包中的说明,但它们似乎是为Tensorflow 1编写的。有没有简单的方法可以使此工作正常进行。

0 个答案:

没有答案