TypeError: Cannot convert value <tensorflow.python.eager.def_function.Function object> to a TensorFlow DType

Date: 2020-09-15 15:26:08

Tags: python tensorflow tensorflow2.0

I am trying to use tf.GradientTape in a model. Before doing that, I am trying it out on a toy example.

import numpy as np
import tensorflow as tf


X = tf.range(10.)
Y = 50.*X

class CGMM(object):
    def __init__(self):
        self.beta = tf.Variable(1., dtype=np.float32)

    @tf.function
    def objfun(self):
        beta_mu = self.beta
        obj = tf.reduce_mean(tf.square(beta_mu*self.X - self.Y))
        return obj

    def build_model(self,X,Y):
        self.X,self.Y=X,Y
        opt = tf.optimizers.Adam(0.001) 

        for i in range(100):

            with tf.GradientTape() as tape:
              loss = self.objfun
            vars = self.beta
            grads = tape.gradient(loss, vars)


            processed_grads = [process_gradient(g) for g in grads]
            opt.apply_gradients(zip(processed_grads, vars))

        opt_beta_mu = self.beta

        return opt_beta_mu


model =CGMM()
opt_beta = model.build_model(X,Y)
print(opt_beta)

However, I get this error:

<ipython-input-17-381a82ea919f> in <module>
     66 
     67 model =CGMM()
---> 68 opt_beta = model.build_model(X,Y)
     69 print(opt_beta)

<ipython-input-17-381a82ea919f> in build_model(self, X, Y)
     51               loss = self.objfun
     52             vars = self.beta
---> 53             grads = tape.gradient(loss, vars)
     54 
     55             # Process the gradients, for example cap them, etc.

/Users/Mine/Python/tf2_env/lib/python3.6/site-packages/tensorflow/python/eager/backprop.py in gradient(self, target, sources, output_gradients, unconnected_gradients)
   1016     flat_targets = []
   1017     for t in nest.flatten(target):
-> 1018       if not backprop_util.IsTrainable(t):
   1019         logging.vlog(
   1020             logging.WARN, "The dtype of the target tensor must be "

/Users/Mine/Python/tf2_env/lib/python3.6/site-packages/tensorflow/python/eager/backprop_util.py in IsTrainable(tensor_or_dtype)
     28   else:
     29     dtype = tensor_or_dtype
---> 30   dtype = dtypes.as_dtype(dtype)
     31   return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
     32                               dtypes.complex64, dtypes.complex128,

/Users/Mine/Python/tf2_env/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py in as_dtype(type_value)
    641 
    642   raise TypeError("Cannot convert value %r to a TensorFlow DType." %
--> 643                   (type_value,))

TypeError: Cannot convert value <tensorflow.python.eager.def_function.Function object at 0x14905d128> to a TensorFlow DType.
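The last frames show where this comes from: `tape.gradient` flattens its target and passes each element to `backprop_util.IsTrainable`, which, for a value that is not a tensor, treats the value as a dtype and calls `dtypes.as_dtype` on it. Because `loss = self.objfun` has no parentheses, the target is the `tf.function` object itself rather than a loss tensor, so `as_dtype` raises. The standalone sketch below assumes TensorFlow 2.x in eager mode; the module-level `objfun`, `beta`, `x` and `y` are illustrative stand-ins for the question's class attributes. It contrasts binding the function with calling it:

```python
import tensorflow as tf

beta = tf.Variable(1.0)   # trainable float32 variable, watched by the tape automatically
x = tf.range(10.0)
y = 50.0 * x

@tf.function
def objfun():
    # Mean squared error of the one-parameter model beta * x against y.
    return tf.reduce_mean(tf.square(beta * x - y))

with tf.GradientTape() as tape:
    not_called = objfun   # the tf.function object itself (what `loss = self.objfun` binds)
    loss = objfun()       # calling it returns a scalar tensor that the tape records

# Mirrors the last traceback frame: as_dtype cannot interpret a function object.
try:
    tf.dtypes.as_dtype(not_called)
    raised_type_error = False
except TypeError as err:
    raised_type_error = True
    print("TypeError:", err)

grad = tape.gradient(loss, beta)   # works, because the target is now a tensor
print(loss.numpy(), grad.numpy())
```

What matters is calling the function inside the `with` block: the tape records the operations that run while it is active, and `tape.gradient` then receives a float32 scalar tensor instead of a Python callable.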

Could you help me solve this problem?
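Applied to the question's own class, a hedged sketch (assuming TensorFlow 2.x) needs two substantive changes: call `self.objfun()` inside the tape, and keep the variables in a list. With a bare `self.beta` as the source, `tape.gradient` returns a single scalar tensor, so the list comprehension over `grads` and the `zip(..., vars)` call would fail next, since a 0-d tensor or variable cannot be iterated. The question never defines `process_gradient`, so the version here is a hypothetical clip that only stands in for whatever processing was intended; the optimizer is written as `tf.keras.optimizers.Adam`.

```python
import tensorflow as tf

X = tf.range(10.)
Y = 50. * X


def process_gradient(g):
    # Hypothetical stand-in: the question does not define process_gradient,
    # so this sketch just clips each gradient to [-100, 100].
    return tf.clip_by_value(g, -100., 100.)


class CGMM(object):
    def __init__(self):
        self.beta = tf.Variable(1., dtype=tf.float32)

    @tf.function
    def objfun(self):
        # Mean squared error of beta * X against Y.
        return tf.reduce_mean(tf.square(self.beta * self.X - self.Y))

    def build_model(self, X, Y, learning_rate=0.001, steps=100):
        self.X, self.Y = X, Y
        opt = tf.keras.optimizers.Adam(learning_rate)
        variables = [self.beta]   # a list, so gradients come back as a matching list

        for _ in range(steps):
            with tf.GradientTape() as tape:
                loss = self.objfun()   # call the function; the tape needs the resulting tensor
            grads = tape.gradient(loss, variables)   # one gradient per variable
            processed_grads = [process_gradient(g) for g in grads]
            opt.apply_gradients(zip(processed_grads, variables))

        return self.beta


model = CGMM()
opt_beta = model.build_model(X, Y)
print(opt_beta.numpy())
```

With the question's learning rate of 0.001 and 100 steps, `beta` only moves from 1.0 to about 1.1: while the gradient keeps the same sign, each Adam step is roughly the size of the learning rate. Getting close to the least-squares value of 50 needs a larger learning rate or many more steps.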
