Spark Scala:将子类型传递给接受父类型的函数

时间:2016-11-09 21:50:03

标签: scala oop apache-spark rdd

假设我有一个抽象类A。我还有班级BC,它们继承自班级A

abstract class A {
  def x: Int
}
case class B(i: Int) extends A {
  override def x = -i
}
case class C(i: Int) extends A {
  override def x = i
}

鉴于这些类,我构建了以下RDD:

val data = sc.parallelize(Seq(
      Set(B(1), B(2)),
      Set(B(1), B(3)),
      Set(B(1), B(5))
    )).cache
      .zipWithIndex
      .map {case(k, v) => (v, k)}

我还有以下函数将RDD作为输入并返回每个元素的计数:

def f(data: RDD[(Long, Set[A])]) = {
  data.flatMap({
    case (k, v) => v map { af =>
      (af, 1)
    }
  }).reduceByKey(_ + _)
}

请注意,RDD正在接受类型A。现在,我希望val x = f(data)按预期返回计数,因为BA的子类型,但我收到以下编译错误:

type mismatch;
 found   : org.apache.spark.rdd.RDD[(Long, scala.collection.immutable.Set[B])]
 required: org.apache.spark.rdd.RDD[(Long, Set[A])]
    val x = f(data)

如果我将函数签名更改为f(data: RDD[(Long, Set[B])]),则此错误消失;但是,我不能这样做,因为我想在RDD中使用其他子类(如C)。

我也尝试了以下方法:

def f[T <: A](data: RDD[(Long, Set[T])]) = {
  data.flatMap({
    case (k, v) => v map { af =>
      (af, 1)
    }
  }) reduceByKey(_ + _)
}

但是,这也给了我以下运行时错误:

value reduceByKey is not a member of org.apache.spark.rdd.RDD[(T, Int)]
possible cause: maybe a semicolon is missing before `value reduceByKey'?
      }) reduceByKey(_ + _)

我很感激你的帮助。

1 个答案:

答案 0 :(得分:2)

Set[T]T上不变,这意味着A B的{​​{1}}子类型不是Set[A]的子类型,也不是Set[B]的超类型} RDD[T] TCollection[+T]进一步限制选项时也是不变的,因为即使使用了协变List[+T](例如ClassTag),也会出现同样的情况。

我们可以使用该方法的多态形式来替代: 上面版本中缺少的是import scala.reflect.{ClassTag} def f[T:ClassTag](data: RDD[(Long, Set[T])]) = { data.flatMap({ case (k, v) => v map { af => (af, 1) } }) reduceByKey(_ + _) } ,Spark需要在擦除后保留类信息。

这应该有效:

val intRdd = sparkContext.parallelize(Seq((1l, Set(1,2,3)), (2L, Set(4,5,6))))
val res1= f(intRdd).collect
// Array[(Int, Int)] = Array((4,1), (1,1), (5,1), (6,1), (2,1), (3,1))

val strRdd = sparkContext.parallelize(Seq((1l, Set("a","b","c")), (2L, Set("d","e","f"))))
val res2 = f(strRdd).collect
// Array[(String, Int)] = Array((d,1), (e,1), (a,1), (b,1), (f,1), (c,1))

让我们看看:

  <%= form_for(user) do |f| %>
  <% if user.errors.any? %>
    <div id="error_explanation">
      <h2><%= pluralize(user.errors.count, "error") %> prohibited this user from being saved:</h2>
      <ul>
      <% user.errors.full_messages.each do |message| %>
        <li><%= message %></li>
      <% end %>
      </ul>
    </div>
  <% end %>

  <div class="field">
    <%= f.label :first_name %>
    <%= f.text_field :first_name %>
  </div>

  <div class="field">
    <%= f.label :last_name %>
    <%= f.text_field :last_name %>
  </div>

  <div class="field">
    <%= f.label :email %>
    <%= f.email_field :email %>
  </div>

  <div class="field">
    <%= f.label :display_name %>
    <%= f.text_field :display_name %>
  </div>
  <div class="field">
    <%= f.label :password %>
    <%= f.password_field :password %>
  </div>
  <div class="field">
    <%= f.label :password_confirmation, "Confirmation" %>
    <%= f.password_field :password_confirmation %>
  </div>
  <div class="actions">
    <%= f.submit %>
  </div>
<% end %>