Keras中的图卷积

时间:2019-01-08 09:20:53

标签: graph keras neural-network

如何在 Keras 中实现图卷积?理想情况下,它应该是一个接受两个输入的层:一是节点序列(按时间维度排列),二是每个节点的邻居在时间维度上的整数索引(其时间维度长度与节点序列相同)。

1 个答案:

答案 0 :(得分:0)

只要先把每个节点及其邻居的数据收集(gather)成卷积层所需的形状,就可以直接使用普通卷积来完成图卷积。
收集这一步可以用下面这个 Keras 层实现(内部使用 TensorFlow 的 gather 操作)。

class GatherFromIndices(Layer):
    """
    Gather each node's neighbours so that a graph convolution can be run as a plain convolution.

    To have a graph convolution (over a fixed-degree kernel) from a given sequence of nodes, we
    need to gather the data of each node's neighbours before running a simple Conv1D/Conv2D, which
    then effectively becomes the graph convolution (``TimeDistributed(Dense())`` can be used as
    well, depending on the output data format). This layer does exactly that gathering.

    Inputs (a list of two tensors):
        inp:  [..., batches, sequence_len, features] node data.
        inds: [..., batches, sequence_ind_len, neighbours] integer indices into the
              ``sequence_len`` axis of ``inp``; leading dimensions must match ``inp``.

    Output:
        [..., batches, sequence_ind_len, neighbours (+1 if include_self), features], or with the
        last two axes merged into one when ``flatten_indices_features`` is True.

    Only integer indices are supported. Indices less than 0 are masked, meaning every feature of
    that neighbour is set to ``mask_value``. Indices >= sequence_len wrap around (modulo
    sequence_len).

    Args:
        mask_value: value written into every feature of a masked (negative-index) neighbour.
        include_self: if True, the node's own features are appended as the LAST entry of the
            neighbours axis. This requires sequence_ind_len == sequence_len.
        flatten_indices_features: if True, merge the neighbours and features axes into one.
    """
    def __init__(self, mask_value=0, include_self=True, flatten_indices_features=False, **kwargs):
        Layer.__init__(self, **kwargs)
        self.mask_value = mask_value
        self.include_self = include_self
        self.flatten_indices_features = flatten_indices_features

    def get_config(self):
        """Return the layer config (constructor arguments) for Keras serialization."""
        config = {'mask_value': self.mask_value,
                  'include_self': self.include_self,
                  'flatten_indices_features': self.flatten_indices_features,
                  }
        base_config = super(GatherFromIndices, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """
        Return the static output shape for ``[inp_shape, inds_shape]``.

        Unknown (None) neighbour or feature dimensions propagate as None instead of raising.
        """
        inp_shape, inds_shape = input_shape
        indices = inds_shape[-1]
        features = inp_shape[-1]
        if indices is not None and self.include_self:
            indices += 1
        leading = list(inds_shape[:-1])
        if self.flatten_indices_features:
            if indices is None or features is None:
                return tuple(leading + [None])
            return tuple(leading + [indices * features])
        return tuple(leading + [indices, features])

    def call(self, inputs, training=None):
        """Gather neighbour features of ``inp`` at ``inds``; see the class docstring for shapes."""
        inp, inds = inputs
        # assumes inp=[..., batches, sequence_len, features] and
        #  inds=[..., batches, sequence_ind_len, neighbours] indexing into inp;
        # the output is [..., batches, sequence_ind_len, neighbours, features]

        # the leading dimensions must be the same (only the last two may differ)
        assert_shapes = tf.Assert(tf.reduce_all(tf.equal(tf.shape(inp)[:-2], tf.shape(inds)[:-2])), [inp])
        assert_positive_ins_shape = tf.Assert(tf.reduce_all(tf.greater(tf.shape(inds), 0)), [inds])
        with tf.control_dependencies([assert_shapes, assert_positive_ins_shape]):
            inp_shape = tf.shape(inp)
            inds_shape = tf.shape(inds)
            seq_len = inp_shape[-2]

            features_dim = -1
            # TODO: allow an axis other than the last one to hold the features.

            # flatten all leading dims: inp_p is [rows, features], ins_p is [rows_ind, neighbours]
            inp_p = tf.reshape(inp, [-1, inp_shape[features_dim]])
            ins_p = tf.reshape(inds, [-1, inds_shape[features_dim]])

            # Flattening loses the batch dimensions (gather_nd is used unbatched), so turn each
            # per-sequence index into a row index of inp_p by adding the offset of the sequence it
            # belongs to: every run of sequence_ind_len rows of ins_p belongs to one sequence that
            # occupies seq_len rows of inp_p. Integer floor division keeps this exact. A float32
            # reciprocal may round to just below a whole number and cannot represent large row
            # counts exactly.
            row_ids = tf.range(tf.shape(ins_p)[0])
            seq_ids = row_ids // inds_shape[-2]
            seq_offsets = tf.expand_dims(seq_ids * seq_len, -1)  # [rows_ind, 1]

            # Negative indices are masked by redirecting them to an extra row, filled with
            # mask_value, that is appended to the end of inp_p.
            mask = tf.greater_equal(ins_p, 0)
            # apply the modulo before offsetting, so indices >= seq_len wrap within their own sequence
            offset_ins_p = ins_p % seq_len + seq_offsets  # broadcasts over the neighbours axis
            # index of the row about to be appended (the current row count of inp_p)
            mask_row_index = tf.fill(tf.shape(ins_p), tf.shape(inp_p)[0])
            offset_ins_p = tf.where(mask, offset_ins_p, mask_row_index)

            # Build the mask row in inp's own dtype. An int mask_value or a non-float32 input
            # would otherwise cause a dtype mismatch.
            mask_row = tf.fill(tf.stack([1, inp_shape[features_dim]]),
                               tf.cast(self.mask_value, inp_p.dtype))
            inp_p = tf.concat([inp_p, mask_row], axis=0)

            # expand dims so that each index gathers a separate row instead of one n-length slice
            neighb_p = tf.gather_nd(inp_p, tf.expand_dims(offset_ins_p, -1))  # [rows_ind, neighbours, features]

            out_shape = tf.concat([inds_shape, inp_shape[features_dim:]], axis=-1)
            neighb = tf.reshape(neighb_p, out_shape)
            # ^^ [..., batches, sequence_ind_len, neighbours, features]

            if self.include_self:
                # append the node's own features as the LAST entry of the neighbours axis
                # (requires sequence_ind_len == sequence_len)
                self_originals = tf.expand_dims(inp, axis=features_dim - 1)
                # ^^ [..., batches, sequence_len, 1, features]
                neighb = tf.concat([neighb, self_originals], axis=features_dim - 1)

            if self.flatten_indices_features:
                neighb = tf.reshape(neighb, tf.concat([inds_shape[:-1], [-1]], axis=-1))

            return neighb

使用可调试的交互式测试:

def allow_tf_debug(func):
    """
    Decorator for tests that use tensorflow, making them more breakpoint-friendly.

    Runs ``func`` inside a ``tf.InteractiveSession`` so that ``.eval()`` can be called on
    tensors immediately (e.g. from a debugger). The session is closed even if ``func`` raises.
    Arguments are forwarded to ``func`` and its return value is returned. The wrapper keeps
    ``func``'s name and docstring, so test runners still report the original test name.
    """
    import functools

    @functools.wraps(func)
    def interactive_wrapper(*args, **kwargs):
        sess = tf.InteractiveSession()
        try:
            return func(*args, **kwargs)
        finally:
            # always release the session, also when the wrapped test fails
            sess.close()
    return interactive_wrapper


@allow_tf_debug
def test_gather_from_indices():
    """
    Check GatherFromIndices against a pure-numpy reference gather.

    Covers 3-D (batch, seq, ...) and 4-D (extra leading batch axis) inputs with
    include_self=False and flatten_indices_features=False. The include_self=True and
    flatten_indices_features=True paths are not tested here.
    """
    gat = GatherFromIndices(include_self=False, flatten_indices_features=False)

    seq = [  # batch of sequences
        # sequences of 2d features
        [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8]],
        [[10, 1], [11, 2], [12, 3], [13, 4], [14, 5], [15, 6], [16, 7], [17, 8]]
        ]

    ids = [  # batch of sequences
        # sequences of 3 ids of each item in sequence
        [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [5, 5, 5], [6, 6, 6], [7, 7, 7]],
        [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [5, 6, 7], [6, 7, 0], [7, 0, -1]]
        # minus one should mean masking
        ]

    def numpy_reference(seq, ids):
        """Gather seq[..., ids, :] with numpy; negative ids yield all-zero features."""
        result = np.empty(list(ids.shape) + [seq.shape[-1]])
        # visit every index position, whatever the number of leading batch axes
        for pos in np.ndindex(*ids.shape):
            neighbour = ids[pos]
            if neighbour < 0:
                result[pos] = 0  # matches the layer's default mask_value=0
            else:
                # pos[:-2] are the batch axes; take element `neighbour` of that sequence
                result[pos] = seq[pos[:-2] + (neighbour,)]
        return result

    def compute_assert_2ways_gathers(seq, ids):
        seq = np.array(seq, dtype=np.float32)
        ids = np.array(ids, dtype=np.int32)
        result_np = numpy_reference(seq, ids)

        output_shape_kerascomputed = gat.compute_output_shape([seq.shape, ids.shape])
        assert isinstance(output_shape_kerascomputed, tuple)
        assert list(output_shape_kerascomputed) == list(result_np.shape)

        # the default session is the InteractiveSession opened by @allow_tf_debug
        sess = tf.get_default_session()
        gat.build(seq.shape)
        result = gat.call([tf.constant(seq), tf.constant(ids)])
        tf_result = sess.run(result)

        assert list(tf_result.shape) == list(output_shape_kerascomputed)
        assert np.array_equal(tf_result, result_np)

    compute_assert_2ways_gathers(seq, ids)
    # NOTE: seq and ids are Python lists, so `* 5` repeats the batch 5 times (a batch of 10);
    # it does not scale the values.
    compute_assert_2ways_gathers(seq * 5, ids * 5)
    # adds one more leading batch axis of size 3 (4-D input)
    compute_assert_2ways_gathers([seq] * 3, [ids] * 3)

每个节点5个邻居的用法示例:

# node features: (batch, nodes, 10); neighbour ids: (batch, nodes, 5) integer indices
fields_input = Input(shape=(None, 10), name='nodedata')
neighbours_ids_input = Input(shape=(None, 5), name='nodes_neighbours_ids', dtype='int32')

# Gather each node's 5 neighbours plus the node itself and flatten them with the features,
# giving (batch, nodes, (5 + 1) * 10). include_self needs the same node count in both inputs.
fields_input_with_neighbours = GatherFromIndices(mask_value=0,
                                                 include_self=True,
                                                 flatten_indices_features=True)(
    [fields_input, neighbours_ids_input])

fields = Conv1D(128, kernel_size=5, padding='same',
                activation='relu')(fields_input_with_neighbours)  # data_format="channels_last"