I'm trying to implement DDPG in TensorFlow, but I've run into a problem where a bunch of my gradients are 0 and the weights never update. Specifically, when I use tf.gradients to find the gradients of the actor's output with respect to its weights, a lot of them are 0. My network is fully connected (Dense) with two ReLU layers followed by a tanh, so I don't see how any of the gradients could be 0. I also tried a linear output layer just to check, and hit the same problem.
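For reference, my understanding of the ReLU derivative is that it is exactly 0 wherever the pre-activation is negative. A minimal numpy sketch (toy values, not my actual network) of what I expect the gradient mask through one ReLU layer to look like:

```python
import numpy as np

# Toy example: gradient of ReLU is 1 where the pre-activation z > 0,
# and exactly 0 where z <= 0.
x = np.array([[1.0, -2.0, 0.5]])      # one input sample
W = np.array([[1.0, -1.0],
              [0.5, -0.5],
              [1.0, -1.0]])           # 3 inputs -> 2 hidden units
z = x @ W                             # pre-activations: [[ 0.5 -0.5]]
a = np.maximum(z, 0.0)                # ReLU output
relu_grad = (z > 0).astype(float)     # [[1. 0.]]: hard 0 for the negative unit
print(relu_grad)
```

So a single negative pre-activation gives a 0 for that unit, but I don't see why whole columns of my gradients would be 0.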
This is how I take the gradients. When I print it out, value_grads looks fine.
self.value_grads = tf.gradients(self.critic.output, self.critic.input)
self.param_grads = tf.gradients(
    self.actor.output, self.actor.trainable_weights,
    grad_ys=self.value_grads[0][:, 6:])
self.updateActor = self.Adam.apply_gradients(
    list(zip(self.param_grads, self.actor.trainable_weights)))
The last part of the critic's input is the actor's output (the action), which is what the array slice is for. If I leave grad_ys at its default, I get the same problem.
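My understanding of grad_ys is that tf.gradients(y, w, grad_ys=v) computes the vector-Jacobian product, which for DDPG is the chain rule dQ/dθ = (dQ/da)(da/dθ). A numpy sketch of what I expect it to compute for a toy linear "actor" (hypothetical small shapes, not my real 6-input network):

```python
import numpy as np

# Toy check: for a linear actor a = x @ W, the DDPG actor gradient
# is dQ/dW = x^T @ (dQ/da), which is what grad_ys should produce.
x = np.array([[0.2, -1.0, 0.7]])   # state, batch of 1
W = np.array([[0.1, 0.2],
              [0.3, 0.4],
              [0.5, 0.6]])         # actor weights: 3 state dims -> 2 actions
a = x @ W                          # actions
dQ_da = np.array([[0.5, -0.3]])    # critic's gradient w.r.t. the action
                                   # (the part I slice out of value_grads)
dQ_dW = x.T @ dQ_da                # chain rule; same shape as W
print(dQ_dW.shape)                 # (3, 2)
```

None of the entries here can be 0 unless the state or dQ/da has a 0 in it, which is why the zeros in my real gradients confuse me.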
My models are defined in Keras as
def actor_model_init(self, lr):
    model = Sequential()
    model.add(Dense(400, input_dim=6, activation='relu'))
    model.add(Dense(16, activation='relu'))
    model.add(Dense(2, activation='tanh'))
    return model

def critic_model_init(self, lr):
    model = Sequential()
    model.add(Dense(400, input_dim=8, activation='relu'))
    model.add(Dense(16, activation='relu'))
    model.add(Dense(1, activation='linear'))
    return model
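To quantify how bad the problem is, I count the exactly-zero entries in each fetched gradient array with a small numpy helper (the arrays below are hypothetical stand-ins for the fetched param_grads):

```python
import numpy as np

def fraction_zero(grads):
    """Fraction of exactly-zero entries in each gradient array."""
    return [float(np.mean(g == 0)) for g in grads]

# e.g. with fetched gradient arrays (toy values for illustration):
grads = [np.array([[0.1, 0.0], [0.0, 0.0]]), np.array([0.0, 0.2])]
print(fraction_zero(grads))  # [0.75, 0.5]
```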
When I print out self.actor_weights and grad_ys, I get:

Gradients:
[[[-0.0157675687 0.0120204324 0 ... 0 0.0208223816 -0.0349246264]
  [-0.00571825635 0.00396730239 0 ... 0 0.00937585346 -0.0138687501]
  [-0.0158380773 0.0105671259 0 ... 0 0.0250782985 -0.0364054628]
  [-0.0158380773 0.0105671259 0 ... 0 0.0250782985 -0.0364054628]
  [-0.0288313236 0.0195191056 0 ... 0 0.0462683141 -0.0670839325]
  [-0.0282119 0.0190116 0 ... 0 0.0459179357 -0.0661953315]],

 [-0.00791903865 0.00528356293 0 ... 0 0.0125391493 -0.0182027314],

 [[0.0113580525 -0.171025231 0 ... 0 0 0.00811411068]
  [0.00957783218 -0.102332 0 ... 0 0 0.00356913568]
  [0 0 0 ... 0 0 0]
  ...
  [0 0 0 ... 0 0 0]
  [0.0111253224 -0.128508702 0 ... 0 0 0.00531342905]
  [0.00626153918 -0.124018401 0 ... 0 0 0.00682871882]],

 [0.0142762065 -0.213435426 0 ... 0 0 0.00935945194],

 [[-0.000175745212 0.00259550055]
  [-0.112493247 0.958190382]
  [0 0]
  ...
  [0 0]
  [0 0]
  [-5.41314e-05 0.000472108193]],

 [-0.0436145514 0.362336487]]
Actor weights:
[[[0.0054333806 0.12122912 0.0479801297 ... 0.0306232758 0.104769051 -0.0758013055]
  [0.0759132877 -0.0792496726 -0.104717404 ... 0.0820670947 0.0845757276 0.115226477]
  [0.0852553844 0.114217594 0.00587947667 ... -0.109474026 0.0511197932 -0.00158583815]
  [0.0148470057 -0.0479621775 0.0104149431 ... 0.0410911106 -0.100606494 -0.0555957966]
  [0.0467446335 0.0142703867 -0.099080056 ... -0.041407384 0.11107491 0.106032036]
  [0.101174556 0.0360424556 0.0209118873 ... -0.0135738291 0.0136910854 0.0971500501]],

 [-0.00377752353 0.00614813669 0 ... -0.00220888574 -0.00626849663 0.00305078411],

 [[-0.0534874797 -8.69747601e-05 -0.066062957 ... -0.00342122815 0.104449756 0.0936678797]
  [0.0770198926 0.103114337 0.0361573808 ... 0.0649240315 -0.0637648925 -0.000988217071]
  [0.084307462 -0.0977794454 -0.0771790445 ... 0.0954664052 0.0909464359 -0.0377970114]
  ...
  [-0.072436884 0.0706674233 0.0323147923 ... -0.00516427308 0.0530067384 0.094952]
  [-0.0498984344 -0.0488674156 -0.0538246594 ... -0.00791313685 0.0444330312 0.0678459927]
  [-0.11838764 0.0993790403 0.0238035 ... 0.0809683651 0.109502122 -0.110104255]],

 [-0.000307052454 0.0118757728 -0.00224457239 ... -0.00128154433 -0.00476269191 -0.00587580958],

 [[0.249329627 0.235500798]
  [0.501975 -0.528630257]
  [0.0883877 -0.53192538]
  ...
  [-0.509910524 0.0830168799]
  [-0.418986678 -0.384885371]
  [-0.399918109 0.196372628]],

 [0.00718917884 0.0016854346]]
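Given the columns of zeros above, one thing I've been trying to check is whether some hidden units are "dead", i.e. their ReLU output is zero for every state in a batch. A numpy sketch of the check (W1 and b1 are hypothetical stand-ins for the first Dense layer's fetched weights; the bias here is deliberately extreme to show the effect):

```python
import numpy as np

rng = np.random.default_rng(0)
states = rng.normal(size=(64, 6))      # batch of states (6 inputs)
W1 = rng.normal(size=(6, 400))         # stand-in for the layer-1 kernel
b1 = np.full(400, -10.0)               # deliberately bad bias: kills most units
h = np.maximum(states @ W1 + b1, 0.0)  # first ReLU layer
dead = np.all(h == 0.0, axis=0)        # units that never activate on this batch
print(dead.sum(), "of", dead.size, "units are dead")
```

A dead unit would get a zero gradient for every incoming weight, which matches the column pattern I'm seeing, but I don't understand why so many units would be dead with freshly initialized weights.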