如何使用Java中的lambda函数对每列和每行的列表列表求和?

时间:2017-11-08 08:33:31

标签: java lambda java-8

我有一个像这样的列表列表:

List<List<Wrapper>> listOfLists = new ArrayList<>();

class Wrapper {
   private int value = 0;
   public int getValue() {
      return value;
   }
}

所以看起来像这样:

[
[Wrapper(3), Wrapper(4), Wrapper(5)],
[Wrapper(1), Wrapper(2), Wrapper(9)],
[Wrapper(4), Wrapper(10), Wrapper(11)],
]

是否有一种简洁的方法可以使用Java 8中的lambda函数展平下面的列表列表:

(per column): [Wrapper(8), Wrapper(16), Wrapper(25)]
(per row): [Wrapper(12), Wrapper(12), Wrapper(25)]

可能它适用于不同大小的内部列表:

[
[Wrapper(3), Wrapper(5)],
[Wrapper(1), Wrapper(2), Wrapper(9)],
[Wrapper(4)],
]

这将导致:

(per column): [Wrapper(8), Wrapper(7), Wrapper(9)]
(per row): [Wrapper(8), Wrapper(11), Wrapper(4)]

似乎比以下更复杂: Turn a List of Lists into a List Using Lambdas3 ways to flatten a list of lists. Is there a reason to prefer one of them?

我最初的做法与列表类似: https://stackoverflow.com/a/36878011/986160

谢谢!

5 个答案:

答案 0 :(得分:7)

每行实际上非常简单:

List<Wrapper> perRow = listOfLists.stream()
            .map(x -> x.stream().mapToInt(Wrapper::getValue).sum())
            .map(Wrapper::new)
            .collect(Collectors.toList());

另一方面,每列不是那么简单:

private static List<Wrapper> perColumn(List<List<Wrapper>> listOfList) {
    int depth = listOfList.size();
    int max = listOfList.stream().map(List::size).max(Comparator.naturalOrder()).get();
    return IntStream.range(0, max)
            .map(x -> IntStream.range(0, depth)
                    .map(y -> listOfList.get(y).size() < y ? 0 : listOfList.get(y).get(x).getValue())
                    .sum())
            .mapToObj(Wrapper::new)
            .collect(Collectors.toList());
}

答案 1 :(得分:3)

如果不是List<List<Wrapper>>而是Guava Table,那么解决方案就会轻而易举:

Table<Integer, Integer, Wrapper> table = HashBasedTable.create();

table.put(0, 0, new Wrapper(3));
table.put(0, 1, new Wrapper(5));
// TODO fill the table 

List<Wrapper> byRow = table.rowMap().values().stream()
    .map(row -> row.values().stream().mapToInt(Wrapper::getValue).sum())
    .map(Wrapper::new)
    .collect(Collectors.toList());

List<Wrapper> byColumn = table.columnMap().values().stream()
    .map(column -> column.values().stream().mapToInt(Wrapper::getValue).sum())
    .map(Wrapper::new)
    .collect(Collectors.toList());

此代码假定Wrapper具有接受int值作为参数的构造函数。有关详细信息,请参阅Table.rowMapTable.columnMap文档。

答案 2 :(得分:2)

如果他们可能有帮助(想要分享评论,文本太久),请使用List<Wrapper>类型的resultPerRow和resultPerColumn:

IntStream.range(0, listOfLists.size()).forEach(
            i -> resultPerRow.add(new Wrapper(listOfLists.get(i).stream().mapToInt(Wrapper::getValue).sum())));

IntStream.range(0, listOfLists.stream().mapToInt(List::size).max().orElse(0))
            .forEach(i -> resultPerColumn.add(new Wrapper(IntStream.range(0, listOfLists.size())
                    .map(y -> listOfLists.get(y).size() < y ? 0 : listOfLists.get(y).get(i).getValue())
                    .sum())));

答案 3 :(得分:2)

检查Matrix.straightSum(listOfLists)Matrix.crossSum(listOfLists)方法。

class Matrix {

    public static List<Wrapper> crossSum(List<List<Wrapper>> input){
        return input.stream().reduce(Matrix::sum).orElse(Collections.emptyList());
    }

    public static List<Wrapper> straightSum(List<List<Wrapper>> input){
        return input.stream().map(Matrix::sum).collect(Collectors.toList());
    }

    private static List<Wrapper> sum(List<Wrapper> big, List<Wrapper> small){
        if (big.size() < small.size()){
            return sum(small, big);
        }
        return IntStream.range(0, big.size()).boxed()
                .map(idx -> sum(big, small, idx))
                .collect(Collectors.toList());
    }

    private static Wrapper sum(List<Wrapper> big, List<Wrapper> small, int idx){
        if (idx < small.size()){
            return new Wrapper(big.get(idx).getValue() + small.get(idx).getValue());
        }
        return big.get(idx);
    }

    private static Wrapper sum(List<Wrapper> line){
        return new Wrapper(line.stream().mapToInt(Wrapper::getValue).sum());
    }
}

答案 4 :(得分:1)

使用自定义函数List sum(List,List)

解决方案的两分钱
static class Wrapper
{
    public int value = 0;

    public Wrapper( int value )
    {
        super();
        this.value = value;
    }

    public int getValue()
    {
        return value;
    }

    public void setValue( int value )
    {
        this.value = value;
    }

    @Override
    public String toString()
    {
        return "Wrapper [value=" + value + "]";
    }

}

public static List<Wrapper> sum( List<Wrapper> l1, List<Wrapper> l2 )
{
    int l1Size = l1.size();
    int l2Size = l2.size();
    int size = Math.max( l1.size(), l2.size() );
    List<Wrapper> result = new ArrayList<>( size );
    int v1 = 0;
    int v2 = 0;
    for ( int i = 0; i < size; i++ )
    {
        v1 = i < l1Size && l1.get( i ) != null ? l1.get( i ).getValue() : 0;
        v2 = i < l2Size && l2.get( i ) != null ? l2.get( i ).getValue() : 0;
        result.add( new Wrapper( v1 + v2 ) );
    }
    return result;
}

public static void main( String[] args )
    throws Exception
{

    List<List<Wrapper>> listOfLists = new ArrayList<>();
    List<Wrapper> list1 = new ArrayList<>();
    list1.add( new Wrapper( 3 ) );
    list1.add( new Wrapper( 4 ) );
    list1.add( new Wrapper( 5 ) );

    List<Wrapper> list2 = new ArrayList<>();
    list2.add( new Wrapper( 1 ) );
    list2.add( new Wrapper( 2 ) );
    list2.add( new Wrapper( 9 ) );

    List<Wrapper> list3 = new ArrayList<>();
    list3.add( new Wrapper( 4 ) );
    list3.add( new Wrapper( 10 ) );
    list3.add( new Wrapper( 11 ) );

    listOfLists.add( list1 );
    listOfLists.add( list2 );
    listOfLists.add( list3 );

    List<Wrapper> colSum = listOfLists.stream().reduce( Collections.emptyList(), ( l1, l2 ) -> sum( l1, l2 ) );

    List<Wrapper> rowSum =
        listOfLists.stream().map( list -> list.stream().reduce( new Wrapper( 0 ),
                                                                ( w1, w2 ) -> new Wrapper( w1.value
                                                                    + w2.value ) ) ).collect( Collectors.toList() );
    System.out.println( "Sum by column : " );
    colSum.forEach( System.out::println );

    System.out.println( "Sum by row : " );
    rowSum.forEach( System.out::println );

}