Keras - 权重初始化为nans

时间:2018-01-10 23:40:37

标签: keras keras-rl

我正在尝试为基于策略的RL创建神经网络。我编写了类来构建网络并生成如下操作:

class Oracle(object):

def __init__(self, input_dim, output_dim, hidden_dims=None):
    if hidden_dims is None:
        hidden_dims = [32, 32]
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.__build_network(input_dim,output_dim,hidden_dims)
    self.__build_train_fn()


def __build_network(self,input_dim, output_dim, hidden_dims):
    """Create a base network"""
    inputs = Input(shape=(input_dim,))
    net = inputs
    # a layer instance is callable on a tensor, and returns a tensor
    for h_dim in hidden_dims:
        net = Dense(h_dim, activation='relu',kernel_initializer='RandomNormal',bias_initializer='zeros')(net)

    net = Dense(output_dim, activation='softmax',kernel_initializer='RandomNormal',bias_initializer='zeros')(net)

    # This creates a model that includes
    # the Input layer and three Dense layers
    self.model = Model(inputs=inputs, outputs=net)
    return self.model

def __build_train_fn(self):
    """Create a train function
            It replaces `model.fit(X, y)` because we use the output of model and use it for training.
            For example, we need action placeholder
            called `action_one_hot` that stores, which action we took at state `s`.
            Hence, we can update the same action.
            This function will create
            `self.train_fn([state, action_one_hot, discount_reward])`
            which would train the model.
    """
    action_prob_placeholder = self.model.output
    action_onehot_placeholder = K.placeholder(shape=(None, self.output_dim),
                                              name="action_onehot")
    discount_reward_placeholder = K.placeholder(shape=(None,),
                                                name="discount_reward")

    action_prob = K.sum(action_prob_placeholder * action_onehot_placeholder, axis=1)
    log_action_prob = K.log(action_prob)

    loss = - log_action_prob * discount_reward_placeholder
    loss = K.mean(loss)

    adam = optimizers.Adam()

    updates = adam.get_updates(params=self.model.trainable_weights,
                               constraints=[],
                               loss=loss)

    self.train_fn = K.function(inputs=[self.model.input,
                                       action_onehot_placeholder,
                                       discount_reward_placeholder],
                               outputs=[],
                               updates=updates)

def get_action(self, state):
    """Returns an action at given `state`
    Args:
        state (1-D or 2-D Array): It can be either 1-D array of shape (state_dimension, )
            or 2-D array shape of (n_samples, state_dimension)
    Returns:
        action: an integer action value ranging from 0 to (n_actions - 1)
    """
    shape = state.shape

    if len(shape) == 1:
        assert shape == (self.input_dim,), "{} != {}".format(shape, self.input_dim)
        state = np.expand_dims(state, axis=0)

    elif len(shape) == 2:
        assert shape[1] == (self.input_dim), "{} != {}".format(shape, self.input_dim)

    else:
        raise TypeError("Wrong state shape is given: {}".format(state.shape))

    action_prob = np.squeeze(self.model.predict(state))
    assert len(action_prob) == self.output_dim, "{} != {}".format(len(action_prob), self.output_dim)
    print(state)
    print(state.shape)
    weights = self.model.get_weights()
    print(weights)
    return np.random.choice(np.arange(self.output_dim), p=action_prob)

我想在基于策略的RL中使用它。问题是即使我将权重初始化为Random normal(或其他初始值设定项),权重输出也有很多nans。此外,action_prob也以nan形式出现。权重的代表性输出如下。谁能告诉我这是如何解决的?

[array([[  1.97270699e-02,              nan,  -1.53264655e-02,
                     nan,              nan,   9.83271226e-02,
                     nan,   1.67111661e-02,              nan,
         -5.40489666e-02,              nan,  -3.19434591e-02,
                     nan,  -8.62319861e-03,              nan,
          3.90832238e-02,              nan,              nan,
                     nan,  -3.34417708e-02,              nan,
          4.17598374e-02,   1.23961531e-02,   1.13383524e-01,
          1.52971387e-01,  -7.35234842e-02,   4.81316447e-03,
                     nan,              nan,   9.02018696e-02,
         -5.64984754e-02,              nan],
       [  3.42946462e-02,              nan,  -2.32576765e-02,
                     nan,              nan,  -1.62454545e-02,
                     nan,   7.62931630e-02,              nan,
          7.09382221e-02,              nan,  -9.45277140e-02,
                     nan,   6.81431815e-02,              nan,
          5.43346964e-02,              nan,              nan,
                     nan,  -5.25366806e-04,              nan,
         -3.03930230e-02,   1.90449376e-02,  -6.84814155e-02,
         -4.24950942e-02,  -4.82842028e-02,   3.00289365e-03,
                     nan,              nan,   1.14762083e-01,
         -1.53483404e-02,              nan],
       [  1.11763954e-01,              nan,  -2.40741558e-02,
                     nan,              nan,  -2.25515720e-02,
                     nan,   8.37199837e-02,              nan,
          8.01791809e-03,              nan,   4.11959179e-02,
                     nan,  -8.09677169e-02,              nan,
          1.09827537e-02,              nan,              nan,
                     nan,   3.24306265e-03,              nan,
         -4.61481474e-02,  -4.44600247e-02,   5.97798042e-02,
         -2.80357362e-03,   4.99138907e-02,  -3.16888206e-02,
                     nan,              nan,   4.79343869e-02,
         -3.04902103e-02,              nan],
       [  9.96000832e-04,              nan,   7.03881904e-02,
                     nan,              nan,   3.29129435e-02,
                     nan,   2.59399302e-02,              nan,
          3.94702554e-02,              nan,   5.41977606e-05,
                     nan,  -8.05872083e-02,              nan,
          7.35593066e-02,              nan,              nan,
                     nan,  -3.20138596e-02,              nan,
         -4.88653146e-02,  -3.05510052e-02,   1.61004122e-02,
          3.60239707e-02,  -2.89578568e-02,  -8.55704099e-02,
                     nan,              nan,  -4.69469689e-02,
          5.44301942e-02,              nan],
       [  2.39880346e-02,              nan,   1.02485856e-02,
                     nan,              nan,  -3.28975841e-02,
                     nan,   3.20423655e-02,              nan,
          7.26358453e-03,              nan,  -3.04405931e-02,
                     nan,   1.31638274e-02,              nan,
         -6.58982694e-02,              nan,              nan,
                     nan,  -8.48279800e-03,              nan,
          5.07000796e-02,  -3.43187563e-02,   1.69583317e-02,
          5.02665602e-02,   6.59292564e-02,   5.91163523e-03,
                     nan,              nan,   1.64841004e-02,
          1.03674673e-01,              nan],
       [  2.22617369e-02,              nan,  -9.83130708e-02,
                     nan,              nan,  -8.62144455e-02,
                     nan,  -1.24993315e-03,              nan,
         -3.39315496e-02,              nan,  -3.71638462e-02,
                     nan,  -2.51251217e-02,              nan,
         -3.30121554e-02,              nan,              nan,
                     nan,   6.95239231e-02,              nan,
          3.96330692e-02,  -7.67886639e-02,   3.19798961e-02,
         -7.02575818e-02,   5.36917103e-03,  -7.84784183e-02,
                     nan,              nan,  -1.12238321e-02,
          5.90852983e-02,              nan],
       [ -1.23783462e-02,              nan,   8.54373630e-03,
                     nan,              nan,   2.71492247e-02,
                     nan,  -4.39056493e-02,              nan,
          1.54177221e-02,              nan,   8.08294937e-02,
                     nan,  -2.47991290e-02,              nan,
         -4.90374281e-04,              nan,              nan,
                     nan,  -2.03785431e-02,              nan,
         -2.94432435e-02,  -4.85701524e-02,  -5.98664656e-02,
          5.03640659e-02,  -1.06101505e-01,  -5.01858108e-02,
                     nan,              nan,   1.59794372e-02,
         -5.52875735e-03,              nan],
       [ -6.50038645e-02,              nan,  -2.88410280e-02,
                     nan,              nan,   5.70952846e-03,
                     nan,   2.29494330e-02,              nan,
          2.96308636e-03,              nan,  -1.30019784e-02,
                     nan,   1.38891954e-02,              nan,
          9.82243866e-02,              nan,              nan,
                     nan,  -4.53725718e-02,              nan,
          7.28782360e-03,  -1.97060239e-02,   1.30356764e-02,
         -1.77630689e-02,  -5.27498014e-02,  -5.70283793e-02,
                     nan,              nan,  -4.40920331e-03,
         -8.47700890e-03,              nan],
       [ -7.09274644e-03,              nan,  -2.85792332e-02,
                     nan,              nan,   1.90456193e-02,
                     nan,   2.33339947e-02,              nan,
         -7.10851625e-02,              nan,  -2.07360443e-02,
                     nan,  -8.23910628e-03,              nan,
          1.53461788e-02,              nan,              nan,
                     nan,   8.74896254e-03,              nan,
         -1.04130013e-02,  -8.23952537e-03,   3.29020806e-02,
         -8.53802171e-03,  -5.38858548e-02,   2.94392351e-02,
                     nan,              nan,   2.28152424e-03,
          3.86046581e-02,              nan],
       [  6.32084534e-02,              nan,   1.79775548e-03,
                     nan,              nan,  -5.96092641e-02,
                     nan,   1.74504239e-03,              nan,
          9.05414373e-02,              nan,  -3.55534554e-02,
                     nan,  -3.89753282e-02,              nan,
          8.71098042e-03,              nan,              nan,
                     nan,   7.47531727e-02,              nan,
          5.26362322e-02,   1.46157984e-02,   3.21042910e-03,
         -7.87475239e-03,   4.22325032e-03,   1.58537421e-02,
                     nan,              nan,   3.45352525e-03,
          9.88092553e-03,              nan],
       [  8.60697851e-02,              nan,   7.76077956e-02,
                     nan,              nan,   1.35996595e-01,
                     nan,   7.12691769e-02,              nan,
         -2.70256456e-02,              nan,   9.95257962e-03,
                     nan,  -2.21844148e-02,              nan,
          4.18028049e-02,              nan,              nan,
                     nan,   6.15538433e-02,              nan,
         -3.34422104e-02,   7.96959698e-02,   3.36392457e-03,
         -9.79953539e-03,   1.52911739e-02,  -9.56133530e-02,
                     nan,              nan,   3.26185785e-02,
         -5.18142292e-03,              nan],
       [ -7.14878365e-02,              nan,   3.30364555e-02,
                     nan,              nan,  -7.56359026e-02,
                     nan,  -8.38122815e-02,              nan,
          3.50784622e-02,              nan,   6.51308149e-02,
                     nan,  -8.44882503e-02,              nan,
          1.97267421e-02,              nan,              nan,
                     nan,  -4.02851999e-02,              nan,
         -3.84002179e-02,   3.23568434e-02,   9.30055231e-03,
          2.97283176e-02,  -3.93995969e-03,   1.24160219e-02,
                     nan,              nan,  -5.86424842e-02,
         -5.61306179e-02,              nan],
       [  5.52838258e-02,              nan,  -2.10575890e-02,
                     nan,              nan,  -1.46265700e-02,
                     nan,  -6.19944222e-02,              nan,
         -4.26368900e-02,              nan,  -1.77203845e-02,
                     nan,   7.23404884e-02,              nan,
          1.19749429e-02,              nan,              nan,
                     nan,  -1.97013188e-02,              nan,
         -9.93668661e-03,  -1.43543081e-02,  -1.89676192e-02,
         -3.46484780e-02,  -2.41095871e-02,   2.64016148e-02,
                     nan,              nan,   3.39512643e-03,
         -2.40868814e-02,              nan],
       [  4.85769324e-02,              nan,  -2.96661835e-02,
                     nan,              nan,  -1.16411140e-02,
                     nan,  -9.32439044e-03,              nan,
         -2.47888379e-02,              nan,  -2.11149845e-02,
                     nan,   1.55771989e-02,              nan,
         -3.60703245e-02,              nan,              nan,
                     nan,  -8.21380615e-02,              nan,
          7.12675974e-02,   3.52902263e-02,   5.15214726e-03,
          4.55725230e-02,  -3.67484652e-02,  -1.13544762e-02,
                     nan,              nan,  -3.86700444e-02,
         -3.91620398e-02,              nan],
       [ -5.83947077e-03,              nan,   5.90741597e-02,
                     nan,              nan,  -4.57256138e-02,
                     nan,  -8.41458961e-02,              nan,
         -7.60969743e-02,              nan,   2.50754189e-02,
                     nan,   2.75974572e-02,              nan,
          2.27455739e-02,              nan,              nan,
                     nan,  -1.64209884e-02,              nan,
         -2.64473110e-02,  -1.31150903e-02,   3.04512922e-02,
         -5.81411598e-03,   1.68283712e-02,  -1.44851422e-02,
                     nan,              nan,  -2.56322809e-02,
          1.11139610e-01,              nan],
       [  8.34780037e-02,              nan,   6.61360845e-03,
                     nan,              nan,  -1.08085848e-01,
                     nan,  -1.87303626e-03,              nan,
         -2.97805574e-02,              nan,  -4.96098958e-02,
                     nan,  -2.47526560e-02,              nan,
          5.78494631e-02,              nan,              nan,
                     nan,   9.74192936e-03,              nan,
         -4.88330796e-02,   1.02368537e-02,  -2.99407393e-02,
         -3.94638889e-02,  -1.45375028e-01,  -8.38985574e-03,
                     nan,              nan,  -2.59864815e-02,
         -5.39724007e-02,              nan],
       [  2.34477259e-02,              nan,   6.47758618e-02,
                     nan,              nan,  -2.06562635e-02,
                     nan,  -1.50227742e-02,              nan,
         -4.99106087e-02,              nan,  -8.75398964e-02,
                     nan,  -1.91738885e-02,              nan,
          9.81663391e-02,              nan,              nan,
                     nan,   8.30503032e-02,              nan,
         -6.02204986e-02,  -5.43463342e-02,  -2.73545366e-02,
         -3.97464111e-02,  -1.08450698e-03,   1.27358735e-02,
                     nan,              nan,  -6.65350258e-02,
         -7.63151273e-02,              nan],
       [ -1.75849702e-02,              nan,   5.18983677e-02,
                     nan,              nan,   2.52664816e-02,
                     nan,  -7.14112073e-03,              nan,
          2.89890468e-02,              nan,  -3.46427821e-02,
                     nan,   1.85990240e-02,              nan,
         -4.50296048e-03,              nan,              nan,
                     nan,  -5.50862215e-02,              nan,
          1.02454759e-01,   9.34040993e-02,   1.45452050e-02,
          2.90963929e-02,   3.19026299e-02,   1.89037640e-02,
                     nan,              nan,  -1.68684160e-03,
          9.94853582e-03,              nan],
       [ -9.39413719e-03,              nan,  -3.46053950e-03,
                     nan,              nan,   3.13128680e-02,
                     nan,  -2.45536752e-02,              nan,
          4.08208035e-02,              nan,   2.67537422e-02,
                     nan,   8.34849998e-02,              nan,
         -2.65908819e-02,              nan,              nan,
                     nan,  -2.63154972e-03,              nan,
          4.54281829e-02,   1.24697601e-02,   5.25561944e-02,
          5.75856939e-02,  -8.61058664e-03,   2.86082458e-02,
                     nan,              nan,  -4.48538922e-02,
          6.58497736e-02,              nan],
       [ -4.35961820e-02,              nan,   5.22863083e-02,
                     nan,              nan,  -8.59688129e-03,
                     nan,  -5.25927730e-02,              nan,
          7.24843144e-02,              nan,  -4.00458984e-02,
                     nan,  -2.85069328e-02,              nan,
          2.43122727e-02,              nan,              nan,
                     nan,   1.57326814e-02,              nan,
          4.99758229e-04,   1.23931235e-02,   1.90575924e-02,
         -4.64425469e-03,   5.54191284e-02,   2.38004271e-02,
                     nan,              nan,  -7.39056617e-03,
          3.59723084e-02,              nan],
       [  6.80808276e-02,              nan,  -1.49172200e-02,
                     nan,              nan,  -1.84247848e-02,
                     nan,   7.11160824e-02,              nan,
          4.74170335e-02,              nan,  -8.48565064e-03,
                     nan,   6.96734041e-02,              nan,
          1.07453577e-01,              nan,              nan,
                     nan,   3.21782194e-02,              nan,
          3.53086367e-02,  -2.57775784e-02,  -3.70149538e-02,
          8.49922895e-02,   4.88188267e-02,   4.43161186e-03,
                     nan,              nan,   7.35458219e-03,
         -4.75145914e-02,              nan],
       [ -1.23953104e-01,              nan,  -4.27762084e-02,
                     nan,              nan,   2.04169434e-02,
                     nan,   5.78987077e-02,              nan,
         -6.60712123e-02,              nan,  -2.07597148e-02,
                     nan,   3.00809499e-02,              nan,
          1.40863642e-01,              nan,              nan,
                     nan,  -4.05914113e-02,              nan,
         -4.87232655e-02,   1.49445562e-02,   3.01859360e-02,
          2.01087426e-02,   7.96428975e-03,   2.58545913e-02,
                     nan,              nan,  -3.26734572e-03,
          2.30945610e-02,              nan]], dtype=float32), array([  0.,  nan,   0.,  nan,  nan,   0.,  nan,   0.,  nan,   0.,  nan,
         0.,  nan,   0.,  nan,   0.,  nan,  nan,  nan,   0.,  nan,   0.,
         0.,   0.,   0.,   0.,   0.,  nan,  nan,   0.,   0.,  nan], dtype=float32), array([[        nan,         nan,         nan, ...,         nan,
                nan,  0.08562656],
       [        nan,         nan,         nan, ...,         nan,
                nan, -0.03227361],
       [        nan,         nan,         nan, ...,         nan,
                nan, -0.1371294 ],
       ..., 
       [        nan,         nan,         nan, ...,         nan,
                nan,  0.01600872],
       [        nan,         nan,         nan, ...,         nan,
                nan, -0.0156843 ],
       [        nan,         nan,         nan, ...,         nan,
                nan, -0.036583  ]], dtype=float32), array([ nan,  nan,  nan,  nan,  nan,  nan,   0.,   0.,  nan,   0.,   0.,
         0.,   0.,   0.,  nan,  nan,  nan,   0.,  nan,   0.,   0.,   0.,
        nan,   0.,  nan,  nan,  nan,  nan,  nan,  nan,  nan,   0.], dtype=float32), array([[ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan],
       [ nan,  nan,  nan]], dtype=float32), array([ nan,  nan,  nan], dtype=float32)]

1 个答案:

答案 0 :(得分:0)

我面临着同样的问题。当我尝试使用Keras实现3层GRU堆叠在一起时,我发现每当其中一层具有nan值时。因此,即使计算出的损失也是微不足道的。初始化程序为'glorot_uniform'。我暂时无法解决问题。但是最近,当我使用命令更新keras和tensorflow时-

pip install keras --upgrade
pip install --upgrade tensorflow-gpu

问题得到解决,然后我可以将MSCOCO数据集上的损失减少到大约1.3。

问题可能比与不同版本的兼容性更深刻。但是这样做对我有帮助,认为可能对您有帮助。