
14: Spark Streaming Source Code Walkthrough: State Management, Demystifying updateStateByKey and mapWithState

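    First, a quick explanation of what state management is, using wordcount as the example. Each batchInterval computes the word counts for the current batch. But what if we need the number of times each word has appeared from the start of the stream up to now? Spark Streaming provides two methods for this: updateStateByKey and mapWithState. mapWithState was added in version 1.6 and is marked experimental there. According to the official claims, it performs up to 10x better than updateStateByKey. Let's look at how each of them is actually implemented.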
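To make the distinction concrete, here is a minimal plain-Scala sketch (no Spark involved; the object and method names are hypothetical) that counts words per batch and then folds those counts into a running total. The running total is the "state" that updateStateByKey and mapWithState keep for us across batches.

```scala
object StateIdea {
  // Word counts for a single batch, like an ordinary per-batch wordcount
  def batchCounts(words: Seq[String]): Map[String, Int] =
    words.groupBy(identity).map { case (w, ws) => (w, ws.size) }

  // Merge one batch's counts into the running (stateful) totals
  def mergeInto(state: Map[String, Int], batch: Map[String, Int]): Map[String, Int] =
    batch.foldLeft(state) { case (acc, (w, n)) => acc.updated(w, acc.getOrElse(w, 0) + n) }

  // Running totals after each batch, starting from an empty state
  def runningTotals(batches: Seq[Seq[String]]): Seq[Map[String, Int]] =
    batches.map(batchCounts).scanLeft(Map.empty[String, Int])(mergeInto).tail
}
```

For the batches `a b a` and `b c`, the second batch on its own counts `b` once, while the running total reports `b -> 2`.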

1. updateStateByKey Analysis

1.1 An updateStateByKey usage example

First, let's look at an example that uses the updateStateByKey function:

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object UpdateStateByKeyDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("UpdateStateByKeyDemo")
    val ssc = new StreamingContext(conf, Seconds(20))
    // A checkpoint directory must be set to use updateStateByKey
    ssc.checkpoint("/checkpoint/")
    val socketLines = ssc.socketTextStream("localhost", 9999)

    socketLines.flatMap(_.split(",")).map(word => (word, 1))
      .updateStateByKey(
        (currValues: Seq[Int], preValue: Option[Int]) => {
          val currValue = currValues.sum // sum the values from the current batch
          Some(currValue + preValue.getOrElse(0)) // add the previous (historical) state
        }).print()

    ssc.start()
    ssc.awaitTermination()
    ssc.stop()
  }
}


1.2 updateStateByKey source analysis

     We know that map returns a MappedDStream, which has no updateStateByKey method, and neither does its parent class DStream. However, the DStream companion object defines an implicit conversion:

implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
    (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null):
  PairDStreamFunctions[K, V] = {
  new PairDStreamFunctions[K, V](stream)
}
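This is Scala's enrichment pattern: key-value operations live in a separate wrapper class, and an implicit conversion defined in the companion object is found automatically (no import needed) whenever the element type is a pair. Below is a minimal sketch of the same pattern; MiniStream and MiniPairFunctions are hypothetical stand-ins for DStream and PairDStreamFunctions.

```scala
import scala.language.implicitConversions

// The "stream" class itself knows nothing about key-value operations
class MiniStream[T](val items: Seq[T]) {
  def map[U](f: T => U): MiniStream[U] = new MiniStream(items.map(f))
}

// Key-value operations live in a separate wrapper class
class MiniPairFunctions[K, V](self: MiniStream[(K, V)]) {
  def reduceByKey(f: (V, V) => V): Map[K, V] =
    self.items.groupBy(_._1).map { case (k, kvs) => (k, kvs.map(_._2).reduce(f)) }
}

object MiniStream {
  // Defined in the companion object, so it is in implicit scope without an import,
  // but it only applies when the element type is a pair (K, V)
  implicit def toPairFunctions[K, V](s: MiniStream[(K, V)]): MiniPairFunctions[K, V] =
    new MiniPairFunctions(s)
}
```

Calling `reduceByKey` on a `MiniStream[String]` would not compile, just as `updateStateByKey` is only available on a `DStream[(K, V)]`.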

The source of updateStateByKey in PairDStreamFunctions is as follows:

def updateStateByKey[S: ClassTag](
    updateFunc: (Seq[V], Option[S]) => Option[S]
  ): DStream[(K, S)] = ssc.withScope {
  updateStateByKey(updateFunc, defaultPartitioner())
}

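Here updateFunc is the argument we pass in. It is a function: Seq[V] holds all of the current key's values in this batch, Option[S] is the key's previous state, and the return value is the new state (returning None removes the key from the state). Following the overloads down, the call eventually reaches the version below, which takes an iterator-based update function: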
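Between the two overloads shown here, Spark passes through an intermediate overload taking `(updateFunc, partitioner)` that, roughly, lifts the per-key function into the iterator-based form: it calls the user function on each (key, values, previous state) triple and drops keys for which it returns None. A plain-Scala sketch of that lifting; `UpdateLifting`, `lift` and `dropIdle` are hypothetical names.

```scala
object UpdateLifting {
  // Per-key update function, same shape as the one passed to updateStateByKey
  type UpdateFunc[V, S] = (Seq[V], Option[S]) => Option[S]

  // Lift it to an iterator-based function over (key, newValues, previousState);
  // keys whose update returns None disappear from the output
  def lift[K, V, S](f: UpdateFunc[V, S]): Iterator[(K, Seq[V], Option[S])] => Iterator[(K, S)] =
    iter => iter.flatMap { case (k, vs, prev) => f(vs, prev).map(s => (k, s)) }

  // The word-count update function from the example above
  val wordCount: UpdateFunc[Int, Int] = (vs, prev) => Some(vs.sum + prev.getOrElse(0))

  // Variant that drops a key once a batch brings no new values for it
  val dropIdle: UpdateFunc[Int, Int] = (vs, prev) =>
    if (vs.isEmpty) None else Some(vs.sum + prev.getOrElse(0))
}
```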


def updateStateByKey[S: ClassTag](
    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
    partitioner: Partitioner,
    rememberPartitioner: Boolean
  ): DStream[(K, S)] = ssc.withScope {
  new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
}

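This creates a StateDStream object. Its compute method first fetches the RDD computed for the previous batch, prevStateRDD (the cumulative word counts from the start of the program through the previous batch), and then the RDD that StateDStream's parent DStream computes for the current batch, parentRDD (the word counts of this batch). When both are available, it calls computeUsingPreviousRDD: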

private [this] def computeUsingPreviousRDD (
    parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
  // Define the function for the mapPartition operation on cogrouped RDD;
  // first map the cogrouped tuple to tuples of required type,
  // and then apply the update function
  val updateFuncLocal = updateFunc
  val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
    val i = iterator.map { t =>
      val itr = t._2._2.iterator
      val headOption = if (itr.hasNext) Some(itr.next()) else None
      (t._1, t._2._1.toSeq, headOption)
    }
    updateFuncLocal(i)
  }
  val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
  val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
  Some(stateRDD)
}
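The important property of this cogroup is that it yields every key present in either RDD, so the update function runs for every key already held in the state, including keys with no new data in this batch (they receive an empty Seq). Each batch's cost therefore grows with the total state size, not just the batch size, which is the main motivation for mapWithState. A plain-Scala simulation with Maps in place of RDDs; `CogroupUpdate.step` is a hypothetical name and partitioning is ignored.

```scala
object CogroupUpdate {
  // Simulates parentRDD.cogroup(prevStateRDD) followed by finalFunc, on plain Maps.
  // batch: this batch's (key, value) pairs; prevState: state after the previous batch.
  def step[K, V, S](
      batch: Seq[(K, V)],
      prevState: Map[K, S],
      updateFunc: (Seq[V], Option[S]) => Option[S]): Map[K, S] = {
    val newValues: Map[K, Seq[V]] =
      batch.groupBy(_._1).map { case (k, kvs) => (k, kvs.map(_._2)) }
    // cogroup yields every key found on either side
    val allKeys = newValues.keySet ++ prevState.keySet
    allKeys.toSeq.flatMap { k =>
      // Like finalFunc: new values as a Seq, previous state as an Option
      val values = newValues.getOrElse(k, Seq.empty[V])
      updateFunc(values, prevState.get(k)).map(s => (k, s))
    }.toMap
  }
}
```

In the usage below, the second batch contains only `c`, yet the update function is invoked three times because `a` and `b` are already in the state.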


2. mapWithState Analysis

2.1 A mapWithState usage example:

import org.apache.spark.SparkConf
import org.apache.spark.streaming._

object StatefulNetworkWordCount {
  def main(args: Array[String]): Unit = {
    if (args.length < 2) {
      System.err.println("Usage: StatefulNetworkWordCount <hostname> <port>")
      System.exit(1)
    }

    // Logging helper from Spark's examples module (package org.apache.spark.examples.streaming)
    StreamingExamples.setStreamingLogLevels()

    val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
    // Create the context with a 1 second batch size
    val ssc = new StreamingContext(sparkConf, Seconds(1))
    ssc.checkpoint(".")

    // Initial state RDD for mapWithState operation
    val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

    // Create a ReceiverInputDStream on target ip:port and count the
    // words in input stream of \n delimited text (eg. generated by 'nc')
    val lines = ssc.socketTextStream(args(0), args(1).toInt)
    val words = lines.flatMap(_.split(" "))
    val wordDstream = words.map(x => (x, 1))

    // Update the cumulative count using mapWithState
    // This will give a DStream made of state (which is the cumulative count of the words)
    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
      val output = (word, sum)
      state.update(sum)
      output
    }

    val stateDstream = wordDstream.mapWithState(
      StateSpec.function(mappingFunc).initialState(initialRDD))
    stateDstream.print()
    ssc.start()
    ssc.awaitTermination()
  }
}
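Two behavioral differences from updateStateByKey show up in this example: the mapping function is called once per incoming record (not once per key with a Seq of values), and, apart from timed-out keys, only keys that appear in the current batch are visited and emitted. The sketch below mimics this with a hypothetical `SimpleState` class that models only a subset of what Spark's `State` offers (exists, getOption, update, remove); it is a simulation, not the real State API.

```scala
// Hypothetical stand-in for Spark's State[S]; records whether it was updated or removed
final class SimpleState[S](private var value: Option[S]) {
  private var updated = false
  private var removed = false
  def exists: Boolean = value.isDefined
  def getOption: Option[S] = value
  def update(s: S): Unit = { value = Some(s); updated = true; removed = false }
  def remove(): Unit = { value = None; removed = true; updated = false }
  def isUpdated: Boolean = updated
  def isRemoved: Boolean = removed
}

object MapWithStateSemantics {
  // Same logic as mappingFunc in the example above
  val mappingFunc = (word: String, one: Option[Int], state: SimpleState[Int]) => {
    val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
    state.update(sum)
    (word, sum)
  }

  // Runs one batch: the function is applied per record, and only keys in the batch are touched
  def runBatch(batch: Seq[(String, Int)], states: Map[String, Int]): (Seq[(String, Int)], Map[String, Int]) = {
    var current = states
    val emitted = batch.map { case (word, one) =>
      val st = new SimpleState(current.get(word))
      val out = mappingFunc(word, Some(one), st)
      if (st.isRemoved) current -= word
      else if (st.isUpdated) current = current.updated(word, st.getOption.get)
      out
    }
    (emitted, current)
  }
}
```

Starting from the initial state `hello -> 1, world -> 1`, a batch `hello spark hello` emits `(hello,2)`, `(spark,1)`, `(hello,3)`; `world` is neither visited nor emitted, but its state is kept.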

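2.2 mapWithState source analysis

mapWithState is also defined in PairDStreamFunctions (reached through the same implicit conversion). It simply wraps the stream and the StateSpec in a MapWithStateDStreamImpl: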

def mapWithState[StateType: ClassTag, MappedType: ClassTag](
    spec: StateSpec[K, V, StateType, MappedType]
  ): MapWithStateDStream[K, V, StateType, MappedType] = {
  new MapWithStateDStreamImpl[K, V, StateType, MappedType](
    self,
    spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
  )
}

MapWithStateDStreamImpl creates an InternalMapWithStateDStream object named internalStream, and MapWithStateDStreamImpl's compute method calls internalStream's getOrCompute method.

/** Internal implementation of the [[MapWithStateDStream]] */
private[streaming] class MapWithStateDStreamImpl[
    KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
    dataStream: DStream[(KeyType, ValueType)],
    spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
  extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {

  private val internalStream =
    new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)

  override def slideDuration: Duration = internalStream.slideDuration

  override def dependencies: List[DStream[_]] = List(internalStream)

  override def compute(validTime: Time): Option[RDD[MappedType]] = {
    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
  }

  // ... other members omitted
}
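In other words, the internal stream produces one MapWithStateRDDRecord per partition, and the public stream exposes only the mappedData part of each record. MapWithStateDStream also offers stateSnapshots(), which exposes the state side instead. A plain-Scala sketch of these two views over a list of records; `RecordViews` and its `Record` class are hypothetical stand-ins.

```scala
object RecordViews {
  // Hypothetical stand-in for MapWithStateRDDRecord: one record per partition
  final case class Record[K, S, E](stateMap: Map[K, S], mappedData: Seq[E])

  // What MapWithStateDStreamImpl.compute exposes: only the mapping function's outputs
  def mappedView[K, S, E](partitions: Seq[Record[K, S, E]]): Seq[E] =
    partitions.flatMap(_.mappedData)

  // The full per-key state, in the spirit of stateSnapshots()
  def snapshotView[K, S, E](partitions: Seq[Record[K, S, E]]): Seq[(K, S)] =
    partitions.flatMap(_.stateMap.toSeq)
}
```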

InternalMapWithStateDStream does not define getOrCompute itself; the call goes to the getOrCompute method inherited from its parent class DStream, which eventually calls InternalMapWithStateDStream's compute method:
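The division of labor is: getOrCompute looks up the RDD already generated for a batch time and calls compute only on a miss (the real DStream.getOrCompute also validates the time and handles caching and checkpointing, all omitted here). Because InternalMapWithStateDStream.compute calls getOrCompute(validTime - slideDuration) on itself, each batch's state is built from the previous batch's already-generated state. A plain-Scala sketch; `MiniDStream` and `RunningSum` are hypothetical, and times are plain Long milliseconds.

```scala
import scala.collection.mutable

// Hypothetical miniature of the DStream getOrCompute/compute split
abstract class MiniDStream[T] {
  private val generated = mutable.HashMap.empty[Long, T]
  var computeCalls = 0

  // Subclasses produce the value for one batch time
  protected def compute(time: Long): Option[T]

  // Returns the cached value for this time, or computes and caches it
  final def getOrCompute(time: Long): Option[T] =
    generated.get(time).orElse {
      computeCalls += 1
      val result = compute(time)
      result.foreach(r => generated.put(time, r))
      result
    }
}

// A state-like stream: compute(t) asks for its own value at t - slide
class RunningSum(input: Long => Int, slide: Long) extends MiniDStream[Int] {
  override protected def compute(time: Long): Option[Int] = {
    val prev = if (time <= 0) 0 else getOrCompute(time - slide).getOrElse(0)
    Some(prev + input(time))
  }
}
```

Asking for time 30 with a slide of 10 computes times 0, 10, 20 and 30 once each; asking again hits the cache, and asking for time 40 adds exactly one more compute call.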

/** Method that generates a RDD for the given time */
override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
  // Get the previous state or create a new empty state RDD
  val prevStateRDD = getOrCompute(validTime - slideDuration) match {
    case Some(rdd) =>
      if (rdd.partitioner != Some(partitioner)) {
        // If the RDD is not partitioned the right way, let us repartition it using the
        // partition index as the key. This is to ensure that state RDD is always partitioned
        // before creating another state RDD using it
        MapWithStateRDD.createFromRDD[K, V, S, E](
          rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
      } else {
        rdd
      }
    case None =>
      MapWithStateRDD.createFromPairRDD[K, V, S, E](
        spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
        partitioner,
        validTime
      )
  }

  // Compute the new state RDD with previous state RDD and partitioned data RDD
  // Even if there is no data RDD, use an empty one to create a new state RDD
  val dataRDD = parent.getOrCompute(validTime).getOrElse {
    context.sparkContext.emptyRDD[(K, V)]
  }
  val partitionedDataRDD = dataRDD.partitionBy(partitioner)
  val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
    (validTime - interval).milliseconds
  }
  Some(new MapWithStateRDD(
    prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}

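This method generates a MapWithStateRDD for the given time. It first obtains the previous state RDD, prevStateRDD (repartitioning it if its partitioner does not match, or building it from the initial state RDD, or an empty RDD, when there is no previous state), and the current batch's data RDD, dataRDD. It then repartitions dataRDD with the same partitioner as the state RDD to obtain partitionedDataRDD. Finally, it passes prevStateRDD, partitionedDataRDD, the mapping function obtained from the StateSpec, the batch time and the timeout threshold to a new MapWithStateRDD and returns it. Next, let's look at MapWithStateRDD's compute method: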
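Repartitioning dataRDD with the state's partitioner is what lets MapWithStateRDD pair partition i of the data with partition i of the state: a key's new records and its stored state always land at the same partition index, so the state itself does not need another shuffle. The sketch below shows this with a hash partitioner in the style of Spark's HashPartitioner (non-negative hashCode modulo the partition count, null keys to partition 0); `CoPartitioning` is a hypothetical name.

```scala
object CoPartitioning {
  // Non-negative hashCode modulo numPartitions, in the style of HashPartitioner
  def partitionOf(key: Any, numPartitions: Int): Int =
    if (key == null) 0
    else {
      val raw = key.hashCode % numPartitions
      if (raw < 0) raw + numPartitions else raw
    }

  // Splits key-value pairs into numPartitions buckets using partitionOf
  def partitionBy[K, V](data: Seq[(K, V)], numPartitions: Int): Vector[Seq[(K, V)]] =
    Vector.tabulate(numPartitions)(p => data.filter { case (k, _) => partitionOf(k, numPartitions) == p })
}
```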

override def compute(
    partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {

  val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
  val prevStateRDDIterator = prevStateRDD.iterator(
    stateRDDPartition.previousSessionRDDPartition, context)
  val dataIterator = partitionedDataRDD.iterator(
    stateRDDPartition.partitionedDataRDDPartition, context)

  // prevRecord holds this partition's entire previous state: each state RDD
  // partition contains exactly one MapWithStateRDDRecord
  val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
  val newRecord = MapWithStateRDDRecord.updateRecordWithData(
    prevRecord,
    dataIterator,
    mappingFunction,
    batchTime,
    timeoutThresholdTime,
    removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
  )
  Iterator(newRecord)
}

Each MapWithStateRDDRecord corresponds to one partition of MapWithStateRDD:

private[streaming] case class MapWithStateRDDRecord[K, S, E](
    var stateMap: StateMap[K, S], var mappedData: Seq[E])

Here stateMap stores the state of each key, and mappedData stores the values returned by the mapping function. Now let's look at MapWithStateRDDRecord's updateRecordWithData method:

def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
    dataIterator: Iterator[(K, V)],
    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
    batchTime: Time,
    timeoutThresholdTime: Option[Long],
    removeTimedoutData: Boolean
  ): MapWithStateRDDRecord[K, S, E] = {

  // Create a new state map by copying the previous record's map (if it exists);
  // otherwise create an empty StateMap
  val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }

  val mappedData = new ArrayBuffer[E]
  val wrappedState = new StateImpl[S]()

  // Call the mapping function on each record in the data iterator, and accordingly
  // update the states touched, and collect the data returned by the mapping function
  dataIterator.foreach { case (key, value) =>
    // Wrap the current state of this key
    wrappedState.wrap(newStateMap.get(key))
    val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
    // Maintain the state map according to what the mapping function did with the state
    if (wrappedState.isRemoved) {
      newStateMap.remove(key)
    } else if (wrappedState.isUpdated
        || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
      newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
    }
    mappedData ++= returned
  }

  // Get the timed out state records, call the mapping function on each and collect the
  // data returned
  if (removeTimedoutData && timeoutThresholdTime.isDefined) {
    newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
      wrappedState.wrapTimingOutState(state)
      val returned = mappingFunction(batchTime, key, None, wrappedState)
      mappedData ++= returned
      newStateMap.remove(key)
    }
  }

  MapWithStateRDDRecord(newStateMap, mappedData)
}

Finally, the MapWithStateRDDRecord object is returned to MapWithStateRDD's compute function, which wraps it in an Iterator and returns it.
