从神经网络中查找最重要的输入

时间:2019-01-20 03:25:15

标签: python-3.x tensorflow keras

我训练了一个具有37个输入的神经网络。它具有约85%的准确性。我有可能找出哪个Input最有效。我尝试了这段代码,但无法弄清楚如何找到最重要的输入

weights = model.layers[0].get_weights()[0]
biases = model.layers[0].get_weights()[1]

1 个答案:

答案 0 :(得分:1)

一种可能的解决方案是用keras.wrappers.scikit_learn包装模型,然后在scikit-learn中使用Recursive Feature elimination

var output = {
    name: "start",
    children: []
};
var len = rawData.length;
for (var i = 0; i < len; i++) {
    rawChild = rawData[i];
    cat = createJson({}, rawChild.dimension.filter(n => n), rawChild.metric[0]);
    if (i == 0)
        output.children.push(cat);
    else {
        mergeData(output, output.children, cat);
    }
}


function mergeData(parent, child, cat) {
    if (child) {
        for (var index = 0; index < child.length; index++) {
            var element = child[index];

            if (cat.children) {
                if (element.name == cat.name) {
                    parent = mergeData(element, element.children, cat.children[0]);
                    return parent;
                } else {
                    continue;
                }
            } else {
                if (element.name == cat.name) {
                    parent = mergeData(element, element.children, cat);
                    return parent;
                } else {
                    continue;
                }
            }

        }
        parent.children.push(cat);
        return parent;
    } else {
        return;
    }
}
console.log(util.inspect(output, false, null, true));

function createJson(mainObj, names, value) {
    if (!Array.isArray(names)) {
        mainObj.name = names;
        mainObj.value = value;
        return mainObj;
    } else {
        for (var index = 0; index < names.length; index++) {
            if (index == names.length - 1) {
                mainObj = createJson(mainObj, names[index], value);
            } else {
                mainObj.name = names[index];
                newarr = names;
                newarr.shift();
                mainObj.children = [createJson({}, newarr, value)];
            }
        }
    }
    return mainObj;
}

ranking pixel with rfe

如果您需要可视化权重,请参见here