Copy data from one existing row to another existing row in Spark SQL

Date: 2019-10-07 08:49:47

Tags: apache-spark apache-spark-sql

I have a dataset A:

+---------+---------+-----+
|price    |   status| id  |
+---------+---------+-----+
| null    | offline | 1   |      
|  3.4$   | online  | 2   | 
|  4.4$   | online  | 1   |
| null    | offline | 2   |   
+---------+---------+-----+

I want to create a new dataset B from A, in which the price of each offline row is replaced by the price of the online row with the same id.

My expected output is:

+---------+---------+-----+
|price    |   status| id  |
+---------+---------+-----+
|  4.4$   | offline | 1   |      
|  3.4$   | online  | 2   | 
|  4.4$   | online  | 1   |
|  3.4$   | offline | 2   |   
+---------+---------+-----+

How can I achieve this?

4 Answers:

Answer 0 (score: 1)

I believe you can achieve it with the approach below:

    val input_df = List((null, "offline", "1"), ("3.4$", "online", "2"), ("4.4$", "online", "1"), (null, "offline", "2")).toDF("price", "status", "id")
    input_df.createOrReplaceTempView("TABLE1")

    // For each null-price (offline) row, borrow the price from the row with the same id and a
    // different status, then append the rows that already have a price.
    spark.sql("""select case when a.price is null then b.price end as price, a.status, a.id
                 from table1 a inner join table1 b on a.id = b.id
                 where a.status <> b.status and a.price is null
                 union all
                 select * from table1 where price is not null""").show()
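
The same self-join can also be written with the DataFrame API; here is a minimal sketch, assuming the input_df defined above:

    import org.apache.spark.sql.functions.col

    // Split the data into rows that already carry a price and rows that need one.
    val priced   = input_df.filter(col("price").isNotNull)
    val unpriced = input_df.filter(col("price").isNull)

    // Borrow the price from the row with the same id, then put both halves back together.
    val filled = unpriced.drop("price")
      .join(priced.select(col("id"), col("price")), Seq("id"))
      .select("price", "status", "id")

    filled.union(priced.select("price", "status", "id")).show()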

Answer 1 (score: 1)

If the (status, id) pairs are unique, you can achieve this with a Window function, copying the price across the rows of the same id like this:

    import org.apache.spark.sql.DataFrame

    val dfx: DataFrame = Seq(
      (Some(4), "online", 1),
      (None, "offline", 1),
      (Some(3), "online", 2),
      (None, "offline", 2)
    ).toDF("price", "status", "id")

    dfx.show()
     +-----+-------+---+
     |price| status| id|
     +-----+-------+---+
     |    4| online|  1|
     | null|offline|  1|
     |    3| online|  2|
     | null|offline|  2|
     +-----+-------+---+

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, lag, coalesce}

    // Ordering by status puts "offline" before "online" within each id, so lag with a
    // negative offset (i.e. a lead) fetches the online price for the offline row.
    val windowPrice = Window.partitionBy(col("id")).orderBy("status")
    val dfx1 = dfx.withColumn("price2", lag(col("price"), -1) over windowPrice)
        .withColumn("correctedPrice", coalesce(col("price"), col("price2")))
        .drop("price", "price2")
    dfx1.show()
     +-------+---+--------------+
     | status| id|correctedPrice|
     +-------+---+--------------+
     |offline|  1|             4|
     | online|  1|             4|
     |offline|  2|             3|
     | online|  2|             3|
     +-------+---+--------------+
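
If you would rather not rely on the alphabetical ordering of status, a hedged variant of the same window idea is to take the single non-null price per id with max over an unordered partition (this assumes at most one online price per id, as in the question):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, max}

    // max skips nulls, so each id picks up its online price regardless of row order.
    val byId = Window.partitionBy(col("id"))
    dfx.withColumn("correctedPrice", max(col("price")) over byId)
       .drop("price")
       .show()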

Answer 2 (score: 1)

Adding an answer in Python Spark SQL as well, using a self-join.

from pyspark.sql.functions import *
from pyspark.sql.types import *

values = [
  (None,"offline",1), 
  ("3.4$","online",2), 
  ("4.4$","online",1), 
  (None,"offline",2)
]

rdd = sc.parallelize(values)
schema = StructType([
    StructField("price", StringType(), True),
    StructField("status", StringType(), True),
    StructField("id", IntegerType(), True)
])

data = spark.createDataFrame(rdd, schema)

data.show(20,False)
data.createOrReplaceTempView("data")

# Self-join: b keeps only the rows that already have a price, so every row of a
# picks up the price of the row with the same id (note both case branches return b.price).
spark.sql("""
select case when a.price is null then b.price else b.price end as price,
       a.status,
       b.id
from data as a inner join (select * from data where price is not null) b
on a.id = b.id
order by a.id
""").show(20,False)

Result:

+-----+-------+---+
|price|status |id |
+-----+-------+---+
|null |offline|1  |
|3.4$ |online |2  |
|4.4$ |online |1  |
|null |offline|2  |
+-----+-------+---+

+-----+-------+---+
|price|status |id |
+-----+-------+---+
|4.4$ |offline|1  |
|4.4$ |online |1  |
|3.4$ |online |2  |
|3.4$ |offline|2  |
+-----+-------+---+

Answer 3 (score: 1)

Assuming you have a Dataset<Product> of simple POJOs, I solved this with groupByKey and flatMapGroups. The idea is as follows:

  1. Group the data by id
  2. In each group, find the product with online status and take its price
  3. Return all products of the group with that price

Here is the code:

Dataset<Product> transformed =
            data.groupByKey((MapFunction<Product, Integer>) product -> product.getId(), Encoders.INT())
                    .flatMapGroups(new FlatMapGroupsFunction<Integer, Product, Product>() {
                        @Override
                        public Iterator<Product> call(Integer integer, Iterator<Product> iterator)
                                throws Exception {

                            // get price
                            Double onlinePrice = null;

                            // prepare list to return
                            List<Product> emittedProducts = new ArrayList<>();

                            while (iterator.hasNext()) {
                                Product next = iterator.next();
                                emittedProducts.add(next);
                                if (next.getStatus().equals("online")) {
                                    onlinePrice = next.getPrice();
                                }

                            }

                            Double finalOnlinePrice = onlinePrice;
                            emittedProducts.stream().forEach(p -> p.setPrice(finalOnlinePrice));

                            return emittedProducts.iterator();
                        }
                    }, Encoders.bean(Product.class));

Product is just a POJO:

public static class Product implements Serializable {

    private Double price;
    private String status;
    private int id;

    public Product(){}
    public Double getPrice() {
        return price;
    }

    public void setPrice(Double price) {
        this.price = price;
    }

    public String getStatus() {
        return status;
    }

    public void setStatus(String status) {
        this.status = status;
    }

    public int getId() {
        return id;
    }

    public void setId(int id) {
        this.id = id;
    }
}

Setup:

    List<Product> p = new ArrayList<>();
    Product p1 = new Product();
    p1.setId(1);
    p1.setPrice(null);
    p1.setStatus("offline");

    Product p2 = new Product();
    p2.setId(2);
    p2.setPrice(3.4d);
    p2.setStatus("online");

    Product p3 = new Product();
    p3.setId(1);
    p3.setPrice(4.4d);
    p3.setStatus("online");

    Product p4 = new Product();
    p4.setId(2);
    p4.setPrice(null);
    p4.setStatus("offline");

    p.add(p1);
    p.add(p2);
    p.add(p3);
    p.add(p4);

    final Dataset<Product> data = spark.createDataset(p, Encoders.bean(Product.class));