spark-scala中的库存利润计算

时间:2016-09-15 07:20:50

标签: scala apache-spark

5元组(PRODUCT_ID, TRANSACTION_TYPE, QUANTITY, PRICE, DATE)的表格。 Transaction_Type可以是"购买"或"出售"。 对于Quantity上显示的PriceDate是购买或出售产品的实例数。

销售的产品与已经存在的库存相抵消,这也是该库存的最早实例。

净利润的计算方法是将已售出的库存与最早的买入库存相抵消,如果没有完全解决,则使用下一个买入库存,依此类推。

例如,请考虑以下表值:

1, Buy, 10, 100.0, Jan 1

2, Buy, 20, 200.0, Jan 2

1, Buy, 15, 150.0, Jan 3

1, Sell, 5, 120.0, Jan 5

1, Sell, 10, 125.0, Jan 6 

HDFS上已存有数百个文件,具有上面显示的架构。

然后利润计算应该如下工作:

  • 当产品1于1月5日出售时,这5个单位应抵消 1月1日买入交易首先(导致利润为 5 *(120.0-100.0))。
  • 然后,当产品1在1月6日进一步销售时 出售的单位数量超过1月1日的剩余数量。1月3日的买入数量可以考虑其余部分。
  • 即1月6日销售产品1的利润为5 *(125.0-100.0)+ 5 *(125.00-150.0)。
  • 因此,1月6日交易的利润值为= 5 *(25)+ 5 *( - 25)= 125 - 125 = 0。 截至1月6日的净利润为100(从1月5日交易)+ 0(从1月6日交易)= 100。
  • 计算该数据中最后一个日期的最终利润。

以下是代码段。但它无法获得NullPointer异常。还有更好的建议吗?

import org.apache.spark.SparkContext._
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.rdd._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row


case class Inventory(PRODUCT_ID: Int, TRANSACTION_TYPE: String, QUANTITY: Long, PRICE: Double, DATE: String)

object MatchingInventory{
    def main(args:Array[String])= {

        val conf = new SparkConf().setAppName("XYZ")
        val sc = new SparkContext(conf)


        val sqlcontext = new SQLContext(sc)
        // Create a schema RDD of Inventory objects from the data that has any number of text file.
        import sqlcontext.implicits._
        val dfInvent= sc.textFile("Invent.txt")
        .map(_.split(","))
        .map(p => Inventory(p(0).trim.toInt, p(1).trim, p(2).trim.toLong, p(3).trim.toDouble, p(4).trim))
        .toDF().cache()
        dfInvent.show()

        val idDF =  dfInvent.map{row => row.getInt(0)}.distinct 
        //idDF.show()
        val netProfit = sc.accumulator(0.0)
        idDF.foreach{id =>
        val sellDF = dfInvent.filter((dfInvent("PRODUCT_ID").contains(id)) && (dfInvent("TRANSACTION_TYPE").contains("Sell")))
        val buyDF = dfInvent.filter((dfInvent("PRODUCT_ID").contains(id)) && (dfInvent("TRANSACTION_TYPE").contains("Buy")))    
         var soldQ:Long = sellDF.map{row => row.getLong(2)}.reduce(_+_) 
         var sellPrice:Double = sellDF.map{row => row.getLong(2)*row.getDouble(3)}.reduce(_+_) //reduce sends the result back to driver
         var profit:Double = 0.0
         // profit for each bought item
         buyDF.foreach{row => 
                           if((soldQ > 0) && (soldQ < row.getLong(2))){profit += sellPrice -(soldQ*row.getDouble(3));soldQ = 0}
                           else if((soldQ > 0) && (soldQ > row.getLong(2))){profit += sellPrice - (row.getLong(2)*row.getDouble(3));soldQ = soldQ - row.getLong(2)}
                                else{}} 
        netProfit += profit}
        println("Inventory net Profit" + netProfit)
    }

}

2 个答案:

答案 0 :(得分:0)

我试过这样的事情。这是一个可行的代码,唯一的问题是我在后期使用collect来进行买卖之间的同步,这将导致大数据的内存问题。

from pyspark.sql import  SQLContext
from pyspark import SparkConf
from pyspark import SparkContext
import sys
from pyspark.sql.functions import *

if __name__ == "__main__":

    sc = SparkContext()

    sqlContext = SQLContext(sc)
    df = sqlContext.read.format('com.databricks.spark.csv').options(header='false', inferschema='true').load('test.csv')

    df = df.withColumn("C1", ltrim(df.C1))

    df.registerTempTable("tempTable")
    df = sqlContext.sql("select * from tempTable order by C0")

    dt = df.map(lambda s: (str(s[0])+'-'+ s[1], str(s[2]) + ',' +str(s[3])))
    dt = dt.reduceByKey(lambda a, b : a + '-' + b)

    ds = dt.collect()

    dicTran = {}
    for x in ds:
        key = (x[0].split('-'))[0]
        tratype = (x[0].split('-'))[1]


        val = {}
        if key in dicTran:
            val = dicTran[key]

        val[tratype] = x[1]
        dicTran[key] = val

    profit = 0

    for key, value in dicTran.iteritems():
        if 'Sell' in value:
            buy = value['Buy']
            sell = value['Sell']

            ls = sell.split('-')
            sellAmount = 0
            sellquant = 0
            for x in ls:
                y = x.split(',')
                sellAmount= sellAmount + float(y[0]) * float(y[1])
                sellquant = sellquant + float(y[0])

            lb = buy.split('-')
            for x in lb:
                y = x.split(',')

                if float(y[0]) >= sellquant:
                    profit += sellAmount - sellquant * float(y[1])
                else:
                    sellAmount -= float(y[0]) * float(y[1])
                    sellquant -= float(y[0])

    print 'profit', profit    



    #

这是我认为的逻辑

1)对于所有相同的ID和交易类型,我通过分隔符汇总数量和价格 2)然后我收集并拆分它们以计算利润

我知道这会在大型数据集上崩溃,因为使用了collect但却无法做到更好。我也会尝试你的解决方案。

答案 1 :(得分:0)

所以我在这里提出了一个解决方案

import org.apache.spark.SparkContext._
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.rdd._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import java.text.SimpleDateFormat
import java.sql.Date
import scala.math.Ordering


//Defining Schema
case class Inventory(PRODUCT_ID: Int, TRANSACTION_TYPE: String, QUANTITY: Long, PRICE: Double, pDate:java.sql.Date)


object MatchingInventory{
    def main(args:Array[String])= {

        val conf = new SparkConf().setAppName("XYZ")
        val sc = new SparkContext(conf)


        val sqlcontext = new SQLContext(sc)

        import sqlcontext.implicits._

        val format = new SimpleDateFormat("MMM d")
        //Read data from directory which has multple files
        val dfInvent= sc.textFile("data/*.txt")
        .map(_.split(","))
        .map(p => Inventory(p(0).trim.toInt, p(1).trim, p(2).trim.toLong, p(3).trim.toDouble, new Date(format.parse(p(4)).getTime)))
        .cache()

        def calculateProfit(data:Iterable[Inventory]):Double  = {
            var soldQ:Long = 0
            var sellPrice:Double = 0
            var profit:Double = 0
            val v = data

            for(i <- v ){
                if(i.TRANSACTION_TYPE == "Sell")
                {
                  soldQ = soldQ + i.QUANTITY
                  profit = profit+ i.PRICE*i.QUANTITY

                }
            }

            for(i <- v){
                if(i.TRANSACTION_TYPE == "Buy")
                {
                    if((soldQ > 0) && (soldQ < i.QUANTITY || soldQ == i.QUANTITY)){profit = profit -(soldQ*i.PRICE);soldQ = 0}
                    else if((soldQ > 0) && (soldQ > i.QUANTITY)){profit = profit - (i.QUANTITY*i.PRICE);soldQ = soldQ - i.QUANTITY}
                    else{}
                }
            }
           profit
        }

        val key: RDD[((Int), Iterable[Inventory])] = dfInvent.keyBy(r => (r.PRODUCT_ID)).groupByKey
        val values: RDD[((Int), List[Inventory])] = key.mapValues(v => v.toList.sortBy(_.pDate.getTime))


        val pro = values.map{ case(k,v) => (k, calculateProfit(v))}
        val netProfit = pro.map{ case(k,v) => v}.reduce(_+_)
        println("Inventory NetProfit" + netProfit)

    }