How can we access the SQLContext inside a flatMap function?

Asked: 2015-07-28 14:59:02

Tags: apache-spark

When the SQLContext is accessed through the singleton class below, it works fine in local mode, but when the job is submitted to a Spark master it is null and a NullPointerException is thrown. How can this be solved? In our use case, a FlatMapFunction needs to query another DStream, and the returned results are used to create a new stream.

I have extended the JavaStatefulNetworkWordCount example to print the changes to the state. I need to use the SQLContext to access the RDDs of the stateful DStream from another DStream, in order to create a further DStream. How can this be achieved?

import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

import scala.Tuple2;

import com.google.common.base.Optional;
import com.google.common.collect.Lists;

public class JavaStatefulNetworkWordCount {
  private static final Pattern SPACE = Pattern.compile(" ");

  public static void main(String[] args) {
    if (args.length < 2) {
      System.err.println("Usage: JavaStatefulNetworkWordCount <hostname> <port>");
      System.exit(1);
    }

    // Update the cumulative count function
    final Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction =
        new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
          @Override
          public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {

            Integer newSum = state.or(0);
            for (Integer value : values) {
              newSum += value;
            }
            return Optional.of(newSum);
          }
        };

    // Create the context with a 1 second batch size
    SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount");
//    sparkConf.setMaster("local[5]");
//  sparkConf.set("spark.executor.uri", "target/rkspark-0.0.1-SNAPSHOT.jar");
    JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
    ssc.checkpoint(".");
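    // Initialize the SQLContext singleton here, on the driver JVM.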
    SQLContext sqlContext = JavaSQLContextSingleton.getInstance(ssc.sparkContext().sc());
    // Initial RDD input to updateStateByKey
    List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
            new Tuple2<String, Integer>("world", 1));
    JavaPairRDD<String, Integer> initialRDD = ssc.sc().parallelizePairs(tuples);

    JavaReceiverInputDStream<String> lines = ssc.socketTextStream(
            args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2);

    JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
      @Override
      public Iterable<String> call(String x) {
        return Lists.newArrayList(SPACE.split(x));
      }
    });

    JavaPairDStream<String, Integer> wordsDstream = words.mapToPair(
        new PairFunction<String, String, Integer>() {
          @Override
          public Tuple2<String, Integer> call(String s) {
            return new Tuple2<String, Integer>(s, 1);
          }
        });

    // This will give a Dstream made of state (which is the cumulative count of the words)
    JavaPairDStream<String, Integer> stateDstream = wordsDstream.updateStateByKey(updateFunction,
            new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD);
    JavaDStream<WordCount> countStream = stateDstream.map(
        new Function<Tuple2<String, Integer>, WordCount>() {
          @Override
          public WordCount call(Tuple2<String, Integer> v1) throws Exception {
            return new WordCount(v1._1(), v1._2());
          }
        });
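    // foreachRDD runs its function on the driver for every batch, so the
    // SQLContext singleton is usable here to register the counts as a temp table.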
    countStream.foreachRDD(new Function<JavaRDD<WordCount>,Void>() {
        @Override
        public Void call(JavaRDD<WordCount> rdd) {
          SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context());
          DataFrame wordsDataFrame = sqlContext.createDataFrame(rdd, WordCount.class);
          wordsDataFrame.registerTempTable("words");
          return null;
        }
      });
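    // Unlike foreachRDD above, the function passed to map below is serialized
    // to the executors, where the singleton's static field was never set.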
    wordsDstream.map(new Function<Tuple2<String,Integer>,String>(){

        @Override
        public String call(Tuple2<String, Integer> v1) throws Exception {
            // This SQLContext is null when running against a cluster master
            // instead of locally.
            SQLContext sqlContext = JavaSQLContextSingleton.getInstance();
            DataFrame counterpartyIds = sqlContext.sql("select * from words where word ='"+v1._1()+"'");
            Row[] rows = counterpartyIds.cache().collect();
            if(rows.length>0){
                Row row = rows[0];
                return row.getInt(0)+"-"+ row.getString(1);
            } else {
                return "";
            }
        }
    }).print();
    ssc.start();
    ssc.awaitTermination();
  }
}
class JavaSQLContextSingleton {
  static private transient SQLContext instance = null;
  static public SQLContext getInstance(SparkContext sparkContext) {
    if (instance == null) {
      instance = new SQLContext(sparkContext);
    }
    return instance;
  }
  // No-argument accessor used inside transformations. On an executor JVM the
  // static field is never initialized, so this returns null there.
  static public SQLContext getInstance() {
    return instance;
  }
}
import java.io.Serializable;

public class WordCount implements Serializable {
    String word;
    int count;

    public WordCount(String word, int count) {
        this.word = word;
        this.count = count;
    }

    public String getWord() {
        return word;
    }

    public void setWord(String word) {
        this.word = word;
    }

    public int getCount() {
        return count;
    }

    public void setCount(int count) {
        this.count = count;
    }
}

1 Answer:

Answer (score: 3)

The SparkContext (and with it the SQLContext) is only available on the driver; it is not serialized and shipped to the workers. Your program works in local mode because everything runs inside the driver's JVM, where the context is available. On a cluster, the body of your map function executes on executor JVMs, where the singleton's static field was never initialized, so getInstance() returns null.
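
One common way around this is to keep every SQLContext call on the driver. Below is a minimal sketch of that idea, not a tested drop-in: it reuses wordsDstream, JavaSQLContextSingleton, and the words temp table from the question, assumes the foreachRDD above has already registered that table for the batch, and uses illustrative names such as batch, state, and joined. The function passed to transform() executes on the driver once per batch, so the SQLContext is available there; only the resulting join runs on the executors.

JavaDStream<String> joined = wordsDstream.transform(
    new Function<JavaPairRDD<String, Integer>, JavaRDD<String>>() {
      @Override
      public JavaRDD<String> call(JavaPairRDD<String, Integer> batch) {
        // Runs on the driver once per batch, where the singleton is initialized.
        SQLContext sqlContext = JavaSQLContextSingleton.getInstance(batch.context());
        JavaPairRDD<String, Integer> state = sqlContext
            .sql("select word, count from words")
            .javaRDD()
            .mapToPair(new PairFunction<Row, String, Integer>() {
              @Override
              public Tuple2<String, Integer> call(Row row) {
                return new Tuple2<String, Integer>(row.getString(0), row.getInt(1));
              }
            });
        // The join is a plain RDD operation; it runs on the executors
        // without ever touching the SQLContext.
        return batch.join(state).map(
            new Function<Tuple2<String, Tuple2<Integer, Integer>>, String>() {
              @Override
              public String call(Tuple2<String, Tuple2<Integer, Integer>> t) {
                return t._2()._2() + "-" + t._1();
              }
            });
      }
    });
joined.print();

Collecting the query result on the driver and broadcasting it would also work for small lookup tables, but a join keeps the lookup distributed.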