我正在尝试在模型中应用tf.GradientTape
。在此之前,我尝试使用一个玩具示例。
import numpy as np
import tensorflow as tf
X = tf.range(10.)
Y = 50.*X
class CGMM(object):
def __init__(self):
self.beta = tf.Variable(1. , dtype=np.float32)
@tf.function
def objfun(self):
beta_mu = self.beta
obj = tf.reduce_mean(tf.square(beta_mu*self.X - self.Y))
return obj
def build_model(self,X,Y):
self.X,self.Y=X,Y
opt = tf.optimizers.Adam(0.001)
for i in range(100):
with tf.GradientTape() as tape:
loss = self.objfun
vars = self.beta
grads = tape.gradient(loss, vars)
processed_grads = [process_gradient(g) for g in grads]
opt.apply_gradients(zip(processed_grads, vars))
opt_beta_mu =self.beta
return opt_beta_mu
model =CGMM()
opt_beta = model.build_model(X,Y)
print(opt_beta)
但是,我遇到此错误-
<ipython-input-17-381a82ea919f> in <module>
66
67 model =CGMM()
---> 68 opt_beta = model.build_model(X,Y)
69 print(opt_beta)
<ipython-input-17-381a82ea919f> in build_model(self, X, Y)
51 loss = self.objfun
52 vars = self.beta
---> 53 grads = tape.gradient(loss, vars)
54
55 # Process the gradients, for example cap them, etc.
/Users/Mine/Python/tf2_env/lib/python3.6/site-packages/tensorflow/python/eager/backprop.py in gradient(self, target, sources, output_gradients, unconnected_gradients)
1016 flat_targets = []
1017 for t in nest.flatten(target):
-> 1018 if not backprop_util.IsTrainable(t):
1019 logging.vlog(
1020 logging.WARN, "The dtype of the target tensor must be "
/Users/Mine/Python/tf2_env/lib/python3.6/site-packages/tensorflow/python/eager/backprop_util.py in IsTrainable(tensor_or_dtype)
28 else:
29 dtype = tensor_or_dtype
---> 30 dtype = dtypes.as_dtype(dtype)
31 return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
32 dtypes.complex64, dtypes.complex128,
/Users/Mine/Python/tf2_env/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py in as_dtype(type_value)
641
642 raise TypeError("Cannot convert value %r to a TensorFlow DType." %
--> 643 (type_value,))
TypeError: Cannot convert value <tensorflow.python.eager.def_function.Function object at 0x14905d128> to a TensorFlow DType.
您能帮我解决这个问题吗?