如何在keras中定义部分分类层?

时间:2020-06-24 18:52:37

标签: tensorflow keras keras-layer

我正在尝试定义一个可以执行部分​​分类功能的“激活层”。我不知道如何描述它,因此请参见下面的示例:

function merge(a, b, prop){
    var reduced =  a.filter( aitem =>  b.find ( bitem => aitem[prop] === bitem[prop]) );
  return reduced;
}
console.log( "ES6", merge(odd, even, "name") );

从索引3中,我选择长度为2、3、2的子数组,并将每个子数组的最大值设置为1,将其他子数组的最大值设置为0。 我有以下代码:

if I have a tensor:
[1.1, 2.2, 2.5, 3.1, 4.1, 5.0, 3.2, 4.9, 50.5, 10.2], 
with cate_dims = [2, 3, 2], 
start = 3, I want to convert it to: 
[1.1, 2.2, 2.5, 0. , 1. , 1. , 0. , 0. , 1. , 0. ]

如果我提供带有值的张量,它会很好地工作。但是,如果我想将此层放在顺序模型中,就像这样:

from keras.layers import Layer
import tensorflow as tf
import numpy as np

def findmax(l):
    sess=tf.Session()
    sess.run(tf.compat.v1.global_variables_initializer())

    arr = l.eval(session=sess)
    MAX = max(arr)
    for i in range(len(arr)):
        if arr[i] == MAX:
            arr[i] = 1
        else:
            arr[i] = 0
    return tf.convert_to_tensor(arr)

class BinarizeCategorical(Layer):
    def __init__(self, start, cate_dims, **kwargs):
        super(BinarizeCategorical, self).__init__(**kwargs)
        self.cate_dims = cate_dims
        self.start = start

    def call(self, inputs):
        if self.cate_dims != []:
            con = tf.slice(inputs, [0], [self.start])
            out = con
            new_start = self.start
            for cate in self.cate_dims:
                t = tf.slice(inputs, [new_start], [cate])
                new_start += cate
                t = findmax(t)
                out = tf.concat([out, t], 0)
            return out
        else:
            return inputs

    def get_config(self):
        config = super(BinarizeCategorical, self).get_config()
        return config

    def compute_output_shape(self, input_shape):
        return input_shape

它将运行到错误:

start = 5
cate_dims = [2, 2, 2, 2]

def build_generator(start, cate_dims):
    
    model = Sequential()

    model.add(Dense(64, activation="relu", input_dim=10))
    model.add(Dense(64 * 2, activation='relu'))
    model.add(Dense(64 * 4, activation='relu'))
    model.add(Dense(13))
    model.add(BinarizeCategorical(start, cate_dims))

    model.summary() 
    
    noise = Input(shape=(10,))
    
    item = model(noise)
        
    return Model(noise, item)
generator = build_generator(5, [2,2,2,2])

我认为这是因为Input中没有任何值。它只是一个占位符。但是我不知道如何重写我的代码。有没有人可以帮助我? 谢谢!

0 个答案:

没有答案