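I'm new to TensorFlow (especially anything beyond the built-in loss/training/etc. functions), and I'm having trouble implementing a custom loss function. I've written a simple two-dimensional simulation of an idealized glider, and I want to train a neural network to fly it as far as possible. The input to the model is an array of state variables (position, pitch, and their derivatives), and the desired output is a control variable that changes the pitch (basically simulating the angle of the elevator flaps). To get the training I want, the loss function simulates a flight with the model providing the control, and returns the negative of the distance traveled. However, when I try to train the model, the computed gradients all come back as None. What am I doing wrong, and am I going about this the right way?
My code: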
import numpy as np
import tensorflow as tf

def fall(control_model):
    # initialize physics constants and state variables
    dt, g = 1/25, 9.805
    x, y, theta = 0, 100, np.radians(-15)
    vx, vy, vtheta = 0, 0, 0
    while y > 0:  # for each time step until we hit the ground:
        # preliminary calculations for aerodynamics
        vsq, vang = vx*vx + vy*vy, np.arctan2(vy, vx)
        aoa = theta - vang  # must be computed after vang has been assigned
        # wrap the angle of attack into (-pi, pi]
        while aoa <= -np.pi:
            aoa += 2*np.pi
        while aoa > np.pi:
            aoa -= 2*np.pi
        aero, aeroang = 1*vsq*np.square(np.sin(aoa)), aoa%np.pi + np.pi/2 + vang
        # make an array of state variables and pass it to the model to get the control variable c
        state = np.asarray([[x/100, y/100, theta/np.pi, vx/10, vy/10, vtheta/np.pi]], dtype = np.float32)
        c = control_model(state).numpy()[0][0]
        # integrate acceleration into speed into position
        vx += aero*np.cos(aeroang)*dt
        vy += (aero*np.sin(aeroang) - g)*dt
        vtheta += (
            0.1*vsq*np.cos(aoa)*0.5*np.sin(2*np.radians(c))  # control term
            -0.05*vsq*np.square(np.sin(aoa))*np.sign(aoa)  # angle of attack tends to zero
            -0.8*vtheta)*dt  # damping
        x += vx*dt
        y += vy*dt
        theta += vtheta*dt
    return -x  # the loss is the negative of distance traveled

control = tf.keras.Sequential()  # simple model for MWE
control.add(tf.keras.layers.Dense(4, activation = "relu", input_shape = (6,)))
control.add(tf.keras.layers.Dense(1, activation = "sigmoid"))

with tf.GradientTape() as tape:
    loss2 = tf.Variable(fall(control))
gradients = tape.gradient(loss2, control.trainable_variables)
print(gradients)  # prints [None, None, None, None]
Answer 0 (score: 1)
You need to call g.watch on each tensor you want gradients with respect to, as in the example below. See: https://www.tensorflow.org/api_docs/python/tf/GradientTape
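# assumes `model` and `input_images_numpy` are already defined
input_images_tensor = tf.constant(input_images_numpy)
with tf.GradientTape() as g:
    g.watch(input_images_tensor)  # a constant is only tracked once it is watched
    output_tensor = model(input_images_tensor)
gradients = g.gradient(output_tensor, input_images_tensor)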
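To make the snippet above concrete, here is a minimal, self-contained sketch. The small `model` and the all-ones `state` input are hypothetical stand-ins, not taken from the question. It compares three cases: watching a constant explicitly, differentiating with respect to the model's trainable variables (which the tape watches automatically), and computing the loss through `.numpy()`, as the question's `fall` function does. In the last case the value leaves TensorFlow, so the tape has nothing connecting the loss to the weights and the gradients come back as None; this is likely also relevant to the question's result, independently of `watch`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in model (not the question's exact model).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(6,)),
    tf.keras.layers.Dense(4, activation="tanh"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# Hypothetical input: a single state vector of ones.
state = tf.constant(np.ones((1, 6), dtype=np.float32))

# 1) A constant has to be watched explicitly before the tape
#    can return a gradient with respect to it.
with tf.GradientTape() as g:
    g.watch(state)
    out = model(state)
grad_wrt_input = g.gradient(out, state)

# 2) Trainable variables are watched automatically, provided the
#    loss is built from TensorFlow ops only.
with tf.GradientTape() as g:
    loss_tf = -tf.reduce_sum(model(state))
grads_tf = g.gradient(loss_tf, model.trainable_variables)

# 3) Converting to NumPy/Python floats leaves the tape. Wrapping the
#    result back into a tensor does not restore the connection.
with tf.GradientTape() as g:
    loss_np = tf.constant(-float(model(state).numpy()[0][0]))
grads_np = g.gradient(loss_np, model.trainable_variables)

print(grad_wrt_input.shape)
print([x is not None for x in grads_tf])
print([x is None for x in grads_np])
```

If this reading is right, getting usable gradients out of a simulation like `fall` would mean expressing each step with TensorFlow ops (for example `tf.sin`, `tf.cos`, `tf.atan2`) and keeping the model output as a tensor rather than calling `.numpy()` on it.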