
Spark Source Code Reading Notes: Broadcast (Part 2)

Broadcast's HTTP transport mechanism is implemented by HttpBroadcastFactory and HttpBroadcast.

The HttpBroadcastFactory code:

class HttpBroadcastFactory extends BroadcastFactory {
  override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
    HttpBroadcast.initialize(isDriver, conf, securityMgr)
  }

  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
    new HttpBroadcast[T](value_, isLocal, id)

  override def stop() { HttpBroadcast.stop() }

  /**
   * Remove all persisted state associated with the HTTP broadcast with the given ID.
   * @param removeFromDriver Whether to remove state from the driver
   * @param blocking Whether to block until unbroadcasted
   */
  override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
    HttpBroadcast.unpersist(id, removeFromDriver, blocking)
  }
}
           

HttpBroadcastFactory's initialize method delegates to HttpBroadcast.initialize, whose main job is to obtain a root directory based on the configuration and, on the driver, start an HTTP server rooted at that directory. The HttpBroadcast.initialize code:

def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
    synchronized {
      if (!initialized) {
        bufferSize = conf.getInt("spark.buffer.size", 65536)
        compress = conf.getBoolean("spark.broadcast.compress", true)
        securityManager = securityMgr
        if (isDriver) {
          createServer(conf)
          conf.set("spark.httpBroadcast.uri", serverUri)
        }
        serverUri = conf.get("spark.httpBroadcast.uri")
        cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
        compressionCodec = CompressionCodec.createCodec(conf)
        initialized = true
      }
    }
  }
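
For orientation, here is an illustrative driver-side configuration touching the keys read above, plus the factory switch that selects the HTTP implementation in the first place; the values are sketched defaults, not requirements:

import org.apache.spark.SparkConf

// Illustrative configuration sketch (not from Spark's source).
val conf = new SparkConf()
  .set("spark.broadcast.factory",
    "org.apache.spark.broadcast.HttpBroadcastFactory") // select HttpBroadcast
  .set("spark.broadcast.compress", "true") // compress broadcast files
  .set("spark.buffer.size", "65536")       // I/O buffer size, in bytes
  .set("spark.broadcast.port", "0")        // 0 = let the HTTP server pick a free port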
           

The HttpBroadcast.createServer function called by HttpBroadcast.initialize:

private def createServer(conf: SparkConf) {
    broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast")
    val broadcastPort = conf.getInt("spark.broadcast.port", 0)
    server =
      new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
    server.start()
    serverUri = server.uri
    logInfo("Broadcast server started at " + serverUri)
  }
           

The HttpServer class is a wrapper around a Jetty server.
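
A full reading of HttpServer is beyond this note, but as a rough sketch (invented names, not Spark's actual class) the core of such a wrapper is a Jetty Server with a ResourceHandler pointed at the broadcast directory, so every file placed in it becomes downloadable over HTTP:

import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.handler.ResourceHandler

// Minimal Jetty file-server sketch; Spark's HttpServer additionally handles
// security, port retries, and its own thread pool.
object MiniBroadcastServer {
  def main(args: Array[String]): Unit = {
    val server = new Server(8080)              // Spark takes the port from config instead
    val handler = new ResourceHandler
    handler.setResourceBase("/tmp/broadcast")  // directory whose files are served
    server.setHandler(handler)
    server.start()
    println("Serving files at http://localhost:8080/")
    server.join()                              // block until the server stops
  }
}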

HttpBroadcastFactory's unbroadcast method delegates to HttpBroadcast.unpersist, which removes the broadcast blocks stored on each node and, depending on the removeFromDriver parameter, decides whether to also delete the corresponding file under the HTTP server's root directory. The HttpBroadcast.unpersist code:

/**
   * Remove all persisted blocks associated with this HTTP broadcast on the executors.
   * If removeFromDriver is true, also remove these persisted blocks on the driver
   * and delete the associated broadcast file.
   */
  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
    if (removeFromDriver) {
      val file = getFile(id)
      files.remove(file)
      deleteBroadcastFile(file)
    }
  }
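
To place unpersist in the call chain, here is a hypothetical driver-side snippet, assuming bc was created with sc.broadcast(...); the comments sketch how the public Broadcast API reaches this function:

// unpersist removes only the executor-side copies; destroy additionally
// removes the driver-side block and the broadcast file.
bc.unpersist(blocking = false) // Broadcast.doUnpersist -> unpersist(id, removeFromDriver = false, ...)
bc.destroy()                   // Broadcast.doDestroy   -> unpersist(id, removeFromDriver = true, ...)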
           

The HttpBroadcast.deleteBroadcastFile function it calls:

private def deleteBroadcastFile(file: File) {
    try {
      if (file.exists) {
        if (file.delete()) {
          logInfo("Deleted broadcast file: %s".format(file))
        } else {
          logWarning("Could not delete broadcast file: %s".format(file))
        }
      }
    } catch {
      case e: Exception =>
        logError("Exception while deleting broadcast file: %s".format(file), e)
    }
  }
           

Two points deserve attention when reading HttpBroadcast: first, its caching mechanism; second, its serialization and deserialization mechanism. The HttpBroadcast code:

private[spark] class HttpBroadcast[T: ClassTag](
    @transient var value_ : T, isLocal: Boolean, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

  override protected def getValue() = value_

  private val blockId = BroadcastBlockId(id)

  /*
   * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster
   * does not need to be told about this block, since no node other than the driver needs to
   * know about this data block.
   */
  HttpBroadcast.synchronized {
    SparkEnv.get.blockManager.putSingle(
      blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
  }

  if (!isLocal) {
    HttpBroadcast.write(id, value_)
  }

  /**
   * Remove all persisted state associated with this HTTP broadcast on the executors.
   */
  override protected def doUnpersist(blocking: Boolean) {
    HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
  }

  /**
   * Remove all persisted state associated with this HTTP broadcast on the executors and driver.
   */
  override protected def doDestroy(blocking: Boolean) {
    HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
  }

  /** Used by the JVM when serializing this object. */
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    assertValid()
    out.defaultWriteObject()
  }

  /** Used by the JVM when deserializing this object. */
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    HttpBroadcast.synchronized {
      SparkEnv.get.blockManager.getSingle(blockId) match {
        case Some(x) => value_ = x.asInstanceOf[T]
        case None => {
          logInfo("Started reading broadcast variable " + id)
          val start = System.nanoTime
          value_ = HttpBroadcast.read[T](id)
          /*
           * We cache broadcast data in the BlockManager so that subsequent tasks using it
           * do not need to re-fetch. This data is only used locally and no other node
           * needs to fetch this block, so we don't notify the master.
           */
          SparkEnv.get.blockManager.putSingle(
            blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
          val time = (System.nanoTime - start) / 1e9
          logInfo("Reading broadcast variable " + id + " took " + time + " s")
        }
      }
    }
  }
}
           

When an HttpBroadcast is created on the driver (by calling SparkContext's broadcast function), it calls HttpBroadcast.write to write the broadcast's data into the HTTP server's root directory, where executors can download it. The HttpBroadcast.write code:

private def write(id: Long, value: Any) {
    val file = getFile(id)
    val fileOutputStream = new FileOutputStream(file)
    try {
      val out: OutputStream = {
        if (compress) {
          compressionCodec.compressedOutputStream(fileOutputStream)
        } else {
          new BufferedOutputStream(fileOutputStream, bufferSize)
        }
      }
      val ser = SparkEnv.get.serializer.newInstance()
      val serOut = ser.serializeStream(out)
      serOut.writeObject(value)
      serOut.close()
      files += file
    } finally {
      fileOutputStream.close()
    }
  }

def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
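
As a usage sketch (application code, not Spark source; assumes the HTTP factory has been configured as above, and that the job runs on a cluster so isLocal is false):

import org.apache.spark.{SparkConf, SparkContext}

// Illustrative driver program: sc.broadcast(...) constructs an HttpBroadcast,
// whose constructor calls HttpBroadcast.write, serializing the value into a
// file under broadcastDir for executors to download.
val sc = new SparkContext(new SparkConf()
  .setAppName("http-broadcast-demo")
  .set("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory"))
val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))
val total = sc.parallelize(Seq("a", "b", "a")).map(w => lookup.value(w)).sum()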
           

When an HttpBroadcast is serialized, the value it carries is not serialized; only the broadcast's id is. On deserialization, readObject uses that id to first call the BlockManager's getSingle function and look the value up locally; if it is not there, the file backing the broadcast is downloaded over HTTP and read into memory. Either way, once the value has been obtained, putSingle caches it in the BlockManager. Note that the cache write sets tellMaster = false, i.e. the master is not notified. The master therefore does not know that this node holds the broadcast, so other executors can never obtain the value through the BlockManager and must fetch it over HTTP instead. In this way the BlockManager provides purely local caching, while remote transfer of the broadcast goes exclusively through the HTTP server.
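The pattern is easy to isolate. Below is a minimal, self-contained sketch of the same idea with invented names: only a small handle (the id) crosses the wire, and the receiving side rebuilds the value from a local cache, falling back to a remote fetch:

import java.io.{ObjectInputStream, ObjectOutputStream}

// Sketch of HttpBroadcast's serialization trick: `value` is @transient, so
// defaultWriteObject serializes only `id`; readObject rebuilds `value` from a
// local cache (standing in for the BlockManager), otherwise via a remote
// fetch (standing in for HttpBroadcast.read).
class LazyHandle[T](@transient var value: T, val id: Long) extends Serializable {
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()                 // writes id only; value is skipped
  }
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    val cached = LazyHandle.cache.getOrElseUpdate(id, LazyHandle.fetch(id))
    value = cached.asInstanceOf[T]
  }
}

object LazyHandle {
  val cache = scala.collection.mutable.Map.empty[Long, Any] // stand-in for BlockManager
  def fetch(id: Long): Any = s"remote-value-$id"            // stand-in for the HTTP download
}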

The readObject function calls HttpBroadcast.read, which builds the broadcast's URL from its id and the HTTP server's URI, then downloads and reads the file:

private def read[T: ClassTag](id: Long): T = {
    logDebug("broadcast read server: " +  serverUri + " id: broadcast-" + id)
    val url = serverUri + "/" + BroadcastBlockId(id).name

    var uc: URLConnection = null
    if (securityManager.isAuthenticationEnabled()) {
      logDebug("broadcast security enabled")
      val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
      uc = newuri.toURL.openConnection()
      uc.setConnectTimeout(httpReadTimeout)
      uc.setAllowUserInteraction(false)
    } else {
      logDebug("broadcast not using security")
      uc = new URL(url).openConnection()
      uc.setConnectTimeout(httpReadTimeout)
    }
    Utils.setupSecureURLConnection(uc, securityManager)

    val in = {
      uc.setReadTimeout(httpReadTimeout)
      val inputStream = uc.getInputStream
      if (compress) {
        compressionCodec.compressedInputStream(inputStream)
      } else {
        new BufferedInputStream(inputStream, bufferSize)
      }
    }
    val ser = SparkEnv.get.serializer.newInstance()
    val serIn = ser.deserializeStream(in)
    val obj = serIn.readObject[T]()
    serIn.close()
    obj
  }
           
