使用输入列表来训练Keras模型

时间:2019-06-18 15:33:49

标签: python python-3.x tensorflow keras tensorflow2.0

我有以下模型:

class VDNMixer(kr.Model):
    def __init__(self, hidden_units, observation_shape, action_dim):
        super().__init__('VDNMixer')
        self.action_dim = action_dim
        input_layer = kl.Input(shape=observation_shape)
        dense = MLP(hidden_units, 0, observation_shape)(input_layer)
        stream_adv, stream_val = tf.split(dense, 2, axis=1)
        advantage = kl.Dense(action_dim, activation=None, use_bias=None)(stream_adv)
        advantage = tf.subtract(advantage, tf.reduce_mean(advantage, axis=1, keepdims=True))
        value = kl.Dense(1, activation=None, use_bias=None)(stream_val)
        Qout = value + advantage
        self.q = kr.Model(inputs=input_layer, outputs=[Qout])

    def __call__(self, list_of_obs_act, training=True):
        """
        :param list_of_obs_act: list of n arrays with dimensions [None, obs_dim+1]
        :param training: (bool)
        :return: n
        """
        result = []
        for obs in list_of_obs_act:
            # n_agents, obs_shape
            act = obs[:, -1]
            obs = obs[:, :-1]
            qout = self.q(obs)
            result.append(tf.reduce_sum(tf.one_hot(tf.cast(act, tf.int32), self.action_dim) * qout))
        return tf.stack(result)

我可以用它来精确地计算损失函数:

def td_loss(targets, qtot):
    return tf.keras.losses.mean_squared_error(targets, qtot)

# model params
observation_shape = (100, )
action_dim = 5
q_kwargs = {'hidden_units': [512, 256, 128],
            'observation_shape': observation_shape,
            'action_dim': action_dim}
net = VDNMixer(**q_kwargs)

# create dummy data
batch_size = 10
agents = np.random.randint(0, 8, size=batch_size)
list_of_obs_act = [np.hstack((np.random.rand(a, observation_shape[0]), np.random.randint(0, action_dim, size=(a, 1))))
                   for a in agents]
t_targets = np.random.rand(batch_size)

# compute loss
print(td_loss(net(list_of_obs_act), t_targets))
>> tf.Tensor(1.0777278029536537, shape=(), dtype=float64)

但是,如果我使用keras高级API来训练模型:

net.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=td_loss,
            metrics=[])
net.train_on_batch(list_of_obs_act, t_targets)

我收到以下错误:

Traceback (most recent call last):
  File "/home/joao/anaconda3/envs/tf2/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3296, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-10-ecf35d2000b8>", line 62, in <module>
    net.train_on_batch(list_of_obs_act, t_targets)
  File "/home/joao/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 896, in train_on_batch
    extract_tensors_from_dataset=True)
  File "/home/joao/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 2335, in _standardize_user_data
    self._set_inputs(cast_inputs)
  File "/home/joao/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 2553, in _set_inputs
    outputs = self(inputs, **kwargs)
  File "<ipython-input-10-ecf35d2000b8>", line 37, in __call__
    qout = self.q(obs)
  File "/home/joao/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 662, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/joao/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 753, in call
    return self._run_internal_graph(inputs, training=training, mask=mask)
  File "/home/joao/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 895, in _run_internal_graph
    output_tensors = layer(computed_tensors, **kwargs)
  File "/home/joao/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 662, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/joao/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 169, in wrapper
    raise e.ag_error_metadata.to_exception(type(e))
ValueError: in converted code:
    relative to /home/joao:
    github/artificial_life/models/base.py:25 call  *
        inputs = layer(inputs)
    anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:611 __call__
        self.name)
    anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/input_spec.py:163 assert_input_compatibility
        layer_name + ' is incompatible with the layer: '
    ValueError: Input 0 of layer dense_15 is incompatible with the layer: its rank is undefined, but the layer requires a defined rank.

有什么想法吗?

我正在使用Tensorflow:v2.0.0-beta0-16-g1d91213fe7 2.0.0-beta1

Python 3.6.8

Ubuntu 18.04

0 个答案:

没有答案