Hive UDAF to find the most frequently occurring column value

Date: 2016-07-17 20:37:53

Tags: java hadoop hive aggregate-functions

I am trying to create a Hive UDAF that finds the most frequently occurring column (string) value (the exact column value, not a single character or substring). Suppose the following is my table named my_table (the dashes are only there to visually separate the columns).

User_Id - Item  - Count
 1  - A - 1
 1  - B - 1
 1  - A - 1
 1  - A - 1
 1  - A - 1
 1  - C - 1
 2  - C - 1
 2  - C - 1
 2  - A - 1
 2  - C - 1

Suppose I run the following query:

Select User_Id, findFrequent(*) from my_table group by User_Id

I should get the output below, because for User_Id = 1, A appears 4 times while B and C each appear only once, so the most frequent value for User_Id = 1 is A. Similarly, the most frequent value for User_Id = 2 is C. In other words, each unique User_Id should yield exactly one most frequent column value.

1 - A
2 - C

I created a class following this example: https://github.com/rathboma/hive-extension-examples/blob/master/src/main/java/com/matthewrathbone/example/TotalNumOfLettersGenericUDAF.java, but no luck so far. Here is my code:

@Description(name = "FindMostCommonString", value = "_FUNC_(expr) - Returns most commonly found string of a column.")
public class FindMostCommonString extends AbstractGenericUDAFResolver {

@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
        throws SemanticException {
    if (parameters.length != 1) {
        throw new UDFArgumentTypeException(parameters.length - 1,
                "Exactly one argument is expected.");
    }

    ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);

    if (oi.getCategory() != ObjectInspector.Category.PRIMITIVE){
        throw new UDFArgumentTypeException(0,
                        "Argument must be PRIMITIVE, but "
                        + oi.getCategory().name()
                        + " was passed.");
    }

    PrimitiveObjectInspector inputOI = (PrimitiveObjectInspector) oi;

    if (inputOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING){
        throw new UDFArgumentTypeException(0,
                        "Argument must be String, but "
                        + inputOI.getPrimitiveCategory().name()
                        + " was passed.");
    }

    return new MostCommonStringEvaluator();
}

public static class MostCommonStringEvaluator extends GenericUDAFEvaluator {

    PrimitiveObjectInspector inputOI;
    ObjectInspector outputOI;
    MapObjectInspector mapOI;

    HashMap<String, Integer> total = new HashMap<String, Integer>();


    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
            throws HiveException {

        assert (parameters.length == 1);
        super.init(m, parameters);

        // init input object inspectors

        if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
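            // PARTIAL1 (map side) and COMPLETE (no reduce) receive the original string column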
            inputOI = (PrimitiveObjectInspector) parameters[0];
        }
        else{
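            // PARTIAL2 and FINAL receive the partial aggregation emitted by terminatePartial()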
            mapOI =  (MapObjectInspector) parameters[0];
        }

        outputOI = ObjectInspectorFactory.getReflectionObjectInspector(String.class,
                ObjectInspectorOptions.JAVA);


        return outputOI;

    }


    static class StringCountAgg implements AggregationBuffer {
        HashMap<String, Integer> strCount; 
        void add(String str){

            if(strCount.containsKey(str)){
                strCount.put(str,strCount.get(str)+1);
            }
            else{
                strCount.put(str,1);
            }
        }

        StringCountAgg(){
            strCount = new HashMap<String, Integer>();
        }
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
        StringCountAgg result = new StringCountAgg();
        return result;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
        StringCountAgg myagg = new StringCountAgg();
    }

    private boolean warned = false;

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
            throws HiveException {
        assert (parameters.length == 1);
        if (parameters[0] != null) {
            StringCountAgg myagg = (StringCountAgg) agg;
            Object p1 = ((PrimitiveObjectInspector) inputOI).getPrimitiveJavaObject(parameters[0]);
            myagg.add((String)p1);
        }
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
        StringCountAgg myagg = (StringCountAgg) agg;
        appendToHashMap(total, myagg.strCount);
        return total;
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial)
            throws HiveException {
        if (partial != null) {

            StringCountAgg myagg1 = (StringCountAgg) agg;

            HashMap<String, Integer>  partialRes = (HashMap<String, Integer> ) mapOI.getMap(partial);

            appendToHashMap(myagg1.strCount, partialRes);
        }
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
        StringCountAgg myagg = (StringCountAgg) agg;
        appendToHashMap(total, myagg.strCount);
        String result = null;
        int maxCount = 0;

        for(String key: total.keySet()){

            if(total.get(key) > maxCount){
                maxCount = total.get(key);
                result = key;
            }
        }

        return result;
    }


    private void appendToHashMap(HashMap<String, Integer> main, HashMap<String, Integer> strCount) {
        for(String key: strCount.keySet()){
            if(main.containsKey(key)){
                main.put(key,main.get(key)+strCount.get(key));
            }
            else{
                main.put(key, strCount.get(key));
            }
        }
    }

}
}
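
For reference, in the GenericUDAFEvaluator contract the ObjectInspector returned by init() must describe what the evaluator emits in that mode: a map<string,int> partial result in PARTIAL1/PARTIAL2 and the final string in FINAL/COMPLETE. The snippet below is only an illustrative sketch of that mode handling, reusing the field names from the code above; it is one assumed way to wire the inspectors, not a verified fix.

@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
    super.init(m, parameters);
    if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
        // raw input: the single string column being aggregated
        inputOI = (PrimitiveObjectInspector) parameters[0];
    } else {
        // partial input: the map emitted by terminatePartial()
        mapOI = (MapObjectInspector) parameters[0];
    }
    if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
        // partial output: per-string counts
        return ObjectInspectorFactory.getStandardMapObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector);
    }
    // final output: the single most frequent string
    return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
}

// terminatePartial() would then return the buffer's own map rather than shared evaluator state:
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
    return ((StringCountAgg) agg).strCount;
}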

2 Answers:

Answer 0 (score: 1)

Not sure where you are stuck, but I solved your problem with this approach:

Input:

Update: SQL-based solution:

hive> Select * from test;
OK
1   A   1
1   B   1
1   A   1
1   A   1
1   A   1
1   C   1
2   C   1
2   C   1
2   A   1
2   C   1
Time taken: 0.15 seconds, Fetched: 10 row(s)

Update: removed the irrelevant solution for future readers.

Answer 1 (score: 1)

select User_Id,Item from HiveTable;
+---------+------+
| User_Id | Item |
+---------+------+
|       1 | A    |
|       1 | B    |
|       1 | A    |
|       1 | A    |
|       1 | A    |
|       1 | C    |
|       2 | C    |
|       2 | C    |
|       2 | C    |
|       2 | A    |
|       2 | C    |
+---------+------+

Query:

select User_Id, Item
from (
  select User_Id, count(*) as total, Item
  from HiveTable
  group by User_Id, Item
  order by total desc
) q3
group by User_Id;

Output:

+---------+------+
| User_Id | Item |
+---------+------+
|       1 | A    |
|       2 | C    |
+---------+------+

Hope this helps.