Spark | ML | Random Forest | 从 RandomForestClassificationModel 的 toDebugString 输出(.txt 文件)加载训练好的模型

时间:2016-12-16 05:22:04

标签: apache-spark serialization random-forest apache-spark-ml

我使用 Spark 1.6 和 ML 库,把训练好的 RandomForestClassificationModel 的 toDebugString() 结果保存了下来:

 // Cast pipeline stage 2 to the random forest model so that toDebugString is reachable.
 val forest: RandomForestClassificationModel =
   model.stages(2).asInstanceOf[RandomForestClassificationModel]
 // Textual dump of every tree in the forest.
 val stringModel: String = forest.toDebugString
 // save stringModel into a file in the driver in format .txt

所以我的想法是:将来读取这个 .txt 文件,并从中加载训练好的随机森林,这可行吗?

谢谢!

3 个答案:

答案 0 :(得分:0)

这样行不通。toDebugString 只是调试信息,用来帮助理解模型是如何计算的。

如果你想把模型保留下来供以后使用,可以像我们一样(不过我们用的是纯 Java)直接序列化 RandomForestModel 对象。默认的 Java 序列化可能存在版本不兼容的问题,因此我们使用 Hessian 来完成序列化。它在版本升级后依然可用——我们从 Spark 1.6.1 开始使用,到 Spark 2.0.2 仍然有效。

答案 1 :(得分:0)

如果你不是非用 ml 不可,那就直接使用 mllib 的实现:通过 mllib 得到的 RandomForestModel 带有 save 函数。

答案 2 :(得分:0)

至少对于 Spark 2.1.0,您可以使用下面的 Java 代码(抱歉,没有 Scala 版本)来做到这一点。不过,依赖一种未明确规定、也没有文档说明的格式,可能并不是最明智的做法。

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.net.URL;
import java.util.*;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static java.nio.charset.StandardCharsets.US_ASCII;

/**
 * Random forest rebuilt from the text produced by Spark's
 * {@code RandomForestClassificationModel.toDebugString()}.
 *
 * <p>The constructor parses the dump into one {@link Node} tree per "Tree" section.
 * Subclasses decide in {@link #predict(float[])} how the per-tree results are combined.
 * Lines that look like a split or a prediction but cannot be parsed are rejected with
 * an {@link IOException} that quotes the offending line.
 */
public abstract class RandomForest {

    private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);

    /** Split line, e.g. {@code If (feature 3 <= 1.5)} or {@code If (feature 0 in {1.0,2.0})}. */
    private static final Pattern IF_PATTERN =
            Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");

    /** Leaf line, e.g. {@code Predict: 1.0}; also accepts a sign and any exponent form. */
    private static final Pattern PREDICT_PATTERN =
            Pattern.compile("Predict: (-?\\d+(?:\\.\\d+)?(?:[eE][-+]?\\d+)?)");

    /** Root node of every tree, in the order they appear in the model file. */
    protected final List<Node> trees = new ArrayList<>();

    /**
     * Reads all trees from the given model dump.
     *
     * @param model model file (format is Spark's RandomForestClassificationModel toDebugString())
     * @throws IOException if the file cannot be read, contains a malformed line,
     *                     or contains no tree at all
     */
    public RandomForest(final URL model) throws IOException {
        try (final BufferedReader reader = new BufferedReader(new InputStreamReader(model.openStream(), US_ASCII))) {
            Node node;
            while ((node = load(reader)) != null) {
                trees.add(node);
            }
        }
        if (trees.isEmpty()) {
            throw new IOException("Failed to read trees from " + model);
        }
        LOG.debug("Found {} trees.", trees.size());
    }

    /**
     * Reads the next tree from the reader.
     *
     * <p>Lines are trimmed, so indentation is irrelevant. The structure is recovered from
     * the order of the lines: each "If" opens a node whose first following subtree is the
     * left child and whose second is the right child. "Else" lines carry no information
     * and are ignored, as is any other line that is not a header, split or prediction.
     *
     * @param reader source positioned at, or just after, a tree header
     * @return the root of the next tree, or {@code null} when the input holds no further tree
     * @throws IOException on read failure or when a split/prediction line is malformed
     */
    private static Node load(final BufferedReader reader) throws IOException {
        Node root = null;
        // Path from the root to the node currently being filled.
        final List<Node> stack = new ArrayList<>();
        String line;
        while ((line = reader.readLine()) != null) {
            final String trimmed = line.trim();
            if (trimmed.startsWith("RandomForest")) {
                // Model header line; nothing to extract.
                continue;
            }
            if (trimmed.startsWith("Tree")) {
                if (root == null) {
                    // Header of the tree we are about to read.
                    continue;
                }
                // Header of the following tree: the current one is complete.
                break;
            }
            if (trimmed.startsWith("If")) {
                final Node node = parseSplit(trimmed);
                if (root == null) {
                    root = node;
                } else {
                    insert(stack, node);
                }
                stack.add(node);
            } else if (trimmed.startsWith("Predict")) {
                final Float prediction = parsePrediction(trimmed);
                if (root == null) {
                    // The whole tree is a single leaf.
                    root = Node.leaf(prediction);
                } else {
                    insert(stack, prediction);
                }
            }
        }
        return root;
    }

    /**
     * Turns an "If (feature N op operand)" line into an inner node without children.
     *
     * @throws IOException if the line does not match the expected layout or holds an invalid number
     */
    private static Node parseSplit(final String line) throws IOException {
        final Matcher m = IF_PATTERN.matcher(line);
        if (!m.matches()) {
            throw new IOException("Unrecognized split line: " + line);
        }
        final String operator = m.group(2);
        final String operand = m.group(3);
        try {
            final int featureIndex = Integer.parseInt(m.group(1));
            final Predicate<Float> predicate;
            switch (operator) {
                case "<=":
                    predicate = new LessOrEqual(Float.parseFloat(operand));
                    break;
                case ">":
                    predicate = new Greater(Float.parseFloat(operand));
                    break;
                case "in":
                    predicate = new In(parseFloatArray(operand));
                    break;
                case "not in":
                    predicate = new NotIn(parseFloatArray(operand));
                    break;
                default:
                    // Unreachable while IF_PATTERN only admits the four operators above.
                    throw new IOException("Unsupported operator in split line: " + line);
            }
            return new Node(featureIndex, predicate);
        } catch (NumberFormatException e) {
            throw new IOException("Invalid number in split line: " + line, e);
        }
    }

    /**
     * Extracts the value of a "Predict: x" line.
     *
     * @throws IOException if the line does not match the expected layout
     */
    private static Float parsePrediction(final String line) throws IOException {
        final Matcher m = PREDICT_PATTERN.matcher(line);
        if (!m.matches()) {
            throw new IOException("Unrecognized prediction line: " + line);
        }
        return Float.parseFloat(m.group(1));
    }

    /**
     * Attaches {@code node} to the deepest node on the stack that still has a free child
     * slot, filling the left slot before the right one. Nodes that are already complete
     * are popped from the stack on the way.
     *
     * @throws IOException if every node on the stack is already complete
     */
    private static void insert(final List<Node> stack, final Object node) throws IOException {
        while (!stack.isEmpty()) {
            final Node top = stack.get(stack.size() - 1);
            if (top.getLeftChild() == null || top.getRightChild() == null) {
                break;
            }
            stack.remove(stack.size() - 1);
        }
        if (stack.isEmpty()) {
            throw new IOException("Malformed tree: no open branch left for " + node);
        }
        final Node parent = stack.get(stack.size() - 1);
        if (parent.getLeftChild() == null) {
            parent.setLeftChild(node);
        } else {
            parent.setRightChild(node);
        }
    }

    /** Parses a category set such as <code>{1.0,2.0}</code> into its values. */
    private static float[] parseFloatArray(final String set) {
        final StringTokenizer st = new StringTokenizer(set, "{,}");
        final float[] floats = new float[st.countTokens()];
        for (int i = 0; st.hasMoreTokens(); i++) {
            floats[i] = Float.parseFloat(st.nextToken());
        }
        return floats;
    }

    /** Formats values as <code>{a,b,c}</code>, the layout {@link #parseFloatArray(String)} reads. */
    private static String formatSet(final float[] values) {
        final StringBuilder sb = new StringBuilder("{");
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(values[i]);
        }
        return sb.append('}').toString();
    }

    /**
     * Combines the results of all trees into a single prediction.
     *
     * @param features feature vector, indexed by the feature numbers used in the model file
     * @return the combined prediction
     */
    public abstract float predict(final float[] features);

    /**
     * Renders every tree as text, one "Tree i:" section per tree.
     *
     * @return textual dump of the forest
     */
    public String toDebugString() {
        final StringBuilder out = new StringBuilder();
        for (int i = 0; i < trees.size(); i++) {
            out.append("Tree ").append(i).append(":\n");
            print(out, "", trees.get(i));
        }
        return out.toString();
    }

    /** Appends the subtree rooted at {@code object}, indenting children by one space per level. */
    private static void print(final StringBuilder out, final String indent, final Object object) {
        if (object instanceof Number) {
            out.append(indent).append("Predict: ").append(object).append('\n');
        } else if (object instanceof Node) {
            final Node node = (Node) object;
            if (node.predicate == null) {
                // Wrapper around a leaf-only tree: print just the prediction.
                print(out, indent, node.getLeftChild());
                return;
            }
            out.append(indent).append(node).append('\n');
            print(out, indent + " ", node.getLeftChild());
            out.append(indent).append("Else\n");
            print(out, indent + " ", node.getRightChild());
        }
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "{numTrees=" + trees.size() + "}";
    }

    /**
     * Inner node of a decision tree. A child is either another {@code Node} or a
     * {@link Number} holding the prediction of a leaf.
     */
    protected static class Node {

        private final int featureIndex;
        /** Split test; {@code null} only for the wrapper around a leaf-only tree. */
        private final Predicate<Float> predicate;
        private Object leftChild;
        private Object rightChild;

        /**
         * @param featureIndex index into the feature vector that this node tests
         * @param predicate    test applied to that feature; must not be {@code null}
         */
        public Node(final int featureIndex, final Predicate<Float> predicate) {
            Objects.requireNonNull(predicate);
            this.featureIndex = featureIndex;
            this.predicate = predicate;
        }

        /** Builds the wrapper used when a tree consists of a single prediction. */
        private Node(final Object prediction) {
            this.featureIndex = -1;
            this.predicate = null;
            this.leftChild = prediction;
            this.rightChild = prediction;
        }

        /** Wraps a bare prediction so that a leaf-only tree can be stored as a {@code Node}. */
        private static Node leaf(final Object prediction) {
            return new Node(prediction);
        }

        public void setLeftChild(final Object leftChild) {
            this.leftChild = leftChild;
        }

        public void setRightChild(final Object rightChild) {
            this.rightChild = rightChild;
        }

        public Object getLeftChild() {
            return leftChild;
        }

        public Object getRightChild() {
            return rightChild;
        }

        /**
         * Walks down from this node until something that is not a {@code Node} is reached.
         * The left child is taken when the predicate accepts the tested feature, the
         * right child otherwise.
         *
         * @param features feature vector
         * @return the first non-{@code Node} object reached
         */
        public Object eval(final float[] features) {
            Object result = this;
            do {
                final Node node = (Node) result;
                if (node.predicate == null) {
                    // Leaf-only tree: there is no feature to test.
                    result = node.leftChild;
                } else {
                    result = node.predicate.test(features[node.featureIndex]) ? node.leftChild : node.rightChild;
                }
            } while (result instanceof Node);

            return result;
        }

        @Override
        public String toString() {
            if (predicate == null) {
                return "Predict: " + leftChild;
            }
            return "If (feature " + featureIndex + " " + predicate + ")";
        }

    }

    /** Accepts values less than or equal to the threshold. */
    private static class LessOrEqual implements Predicate<Float> {
        private final float value;

        public LessOrEqual(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f <= value;
        }

        @Override
        public String toString() {
            return "<= " + value;
        }
    }

    /** Accepts values strictly greater than the threshold. */
    private static class Greater implements Predicate<Float> {
        private final float value;

        public Greater(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f > value;
        }

        @Override
        public String toString() {
            return "> " + value;
        }
    }

    /** Accepts values that are equal to one of the listed values. */
    private static class In implements Predicate<Float> {
        private final float[] array;

        public In(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            final float value = f;
            for (final float candidate : array) {
                if (candidate == value) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public String toString() {
            return "in " + formatSet(array);
        }
    }

    /** Accepts values that are equal to none of the listed values. */
    private static class NotIn implements Predicate<Float> {
        private final float[] array;

        public NotIn(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            final float value = f;
            for (final float candidate : array) {
                if (candidate == value) {
                    return false;
                }
            }
            return true;
        }

        @Override
        public String toString() {
            return "not in " + formatSet(array);
        }
    }
}

要使用该类进行分类,请使用:

import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;

/**
 * Classifier that returns the result produced by the largest number of trees.
 */
public class RandomForestClassifier extends RandomForest {

    /**
     * @param model model file, passed unchanged to the {@code RandomForest} constructor
     * @throws IOException if the superclass constructor fails to read the model
     */
    public RandomForestClassifier(final URL model) throws IOException {
        super(model);
    }

    /**
     * Evaluates every tree and returns the most frequent result.
     *
     * @param features feature vector handed to each tree
     * @return the result with the highest vote count
     */
    @Override
    public float predict(final float[] features) {
        // Tally how many trees produced each result.
        final Map<Object, Integer> votes = new HashMap<>();
        trees.forEach(tree -> votes.merge(tree.eval(features), 1, Integer::sum));

        // Pick the entry with the highest tally.
        final Object winner = votes.entrySet()
                .stream()
                .max(Map.Entry.comparingByValue())
                .get()
                .getKey();
        return (Float) winner;
    }
}

回归:

import java.io.IOException;
import java.net.URL;

/**
 * Regressor that returns the arithmetic mean of the results of all trees.
 */
public class RandomForestRegressor extends RandomForest {

    /**
     * @param model model file, passed unchanged to the {@code RandomForest} constructor
     * @throws IOException if the superclass constructor fails to read the model
     */
    public RandomForestRegressor(final URL model) throws IOException {
        super(model);
    }

    /**
     * Evaluates every tree and averages the numeric results.
     *
     * @param features feature vector handed to each tree
     * @return the mean of the per-tree results, narrowed to {@code float}
     */
    @Override
    public float predict(final float[] features) {
        final double mean = trees
                .stream()
                .map(tree -> tree.eval(features))
                .map(Number.class::cast)
                .mapToDouble(Number::doubleValue)
                .average()
                .getAsDouble();
        return (float) mean;
    }
}