欢迎关注公众号“Tim在路上”
在Spark中shuffleWriter有三种实现,分别是bypassMergeSortShuffleWriter, UnsafeShuffleWriter和SortShuffleWriter。但是shuffleReader却只有一种实现BlockStoreShuffleReader
。
从上一讲中可以知道,这时Spark已经获取到了shuffle元数据包括每个mapId和其location信息,并将其传递给BlockStoreShuffleReader类。接下来我们来详细分析下BlockStoreShuffleReader的实现。
// BlockStoreShuffleReader
override def read(): Iterator[Product2[K, C]] = {
// [1] 初始化ShuffleBlockFetcherIterator,负责从executor中获取 shuffle 块
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.blockStoreClient,
blockManager,
mapOutputTracker,
blocksByAddress,
...
readMetrics,
fetchContinuousBlocksInBatch).toCompletionIterator
val serializerInstance = dep.serializer.newInstance()
// [2] 将shuffle 块反序列化为record迭代器
// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// Update the context task metrics for each record read.
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
},
context.taskMetrics().mergeShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
// [3] reduce端聚合数据:如果map端已经聚合过了,则对读取到的聚合结果进行聚合。如果map端没有聚合,则针对未合并的<k,v>进行聚合。
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// [4] reduce端排序数据:如果需要对key排序,则进行排序。基于sort的shuffle实现过程中,默认只是按照partitionId排序。在每一个partition内部并没有排序,因此添加了keyOrdering变量,提供是否需要对分区内部的key排序
// Sort the output if there is a sort ordering defined.
val resultIter: Iterator[Product2[K, C]] =dep.keyOrdering match {
caseSome(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data.
val sorter =
new ExternalSorter[K, C, C](context, ordering =Some(keyOrd), serializer =dep.serializer)
sorter.insertAllAndUpdateMetrics(aggregatedIter)
case None =>
aggregatedIter
}
// [5] 返回结果集迭代器
resultIter match {
case _: InterruptibleIterator[Product2[K, C]] => resultIter
case _ =>
// Use another interruptible iterator here to support task cancellation as aggregator
// or(and) sorter may have consumed previous interruptible iterator.
new InterruptibleIterator[Product2[K, C]](context, resultIter)
}
}
从上面可见,在BlockStoreShuffleReader.read()读取数据有五步:
- [1] 初始化ShuffleBlockFetcherIterator,负责从executor中获取 shuffle 块
- [2] 将shuffle 块反序列化为record迭代器
- [3] reduce端聚合数据:如果map端已经聚合过了,则对读取到的聚合结果进行聚合。如果map端没有聚合,则针对未合并的<k,v>进行聚合。
- [4] reduce端排序数据:如果需要对key排序,则进行排序。基于sort的shuffle实现过程中,默认只是按照partitionId排序。在每一个partition内部并没有排序,因此添加了keyOrdering变量,提供是否需要对分区内部的key排序
- [5] 返回结果集迭代器
下面我们详细分析下ShuffleBlockFetcherIterator是如何进行fetch数据的
ShuffleBlockFetcherIterator是如何进行fetch数据的?
当shuffle reader创建 ShuffleBlockFetcherIterator 的实例时,迭代器调用在其initialize()方法。
// ShuffleBlockFetcherIterator
private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
context.addTaskCompletionListener(onCompleteCallback)
// Local blocks to fetch, excluding zero-sized blocks.
val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
val hostLocalBlocksByExecutor =
mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
// [1] 划分数据源的请求:本地、主机本地和远程块
// Partition blocks by the different fetch modes: local, host-local, push-merged-local and
// remote blocks.
val remoteRequests = partitionBlocksByFetchMode(
blocksByAddress, localBlocks, hostLocalBlocksByExecutor, pushMergedLocalBlocks)
// [2] 以随机顺序将远程请求添加到我们的队列中
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
assert((0 ==reqsInFlight) == (0 ==bytesInFlight),
"expected reqsInFlight = 0 but found reqsInFlight = " +reqsInFlight+
", expected bytesInFlight = 0 but found bytesInFlight = " +bytesInFlight)
// [3] 发送remote fetch请求
// Send out initial requests for blocks, up to our maxBytesInFlight
fetchUpToMaxBytes()
val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum
val numFetches = remoteRequests.size -fetchRequests.size - numDeferredRequest
logInfo(s"Started$numFetches remote fetches in${Utils.getUsedTimeNs(startTimeNs)}" +
(if (numDeferredRequest > 0 ) s", deferred$numDeferredRequest requests" else ""))
// [4] 支持executor获取local和remote的merge shuffle数据
// Get Local Blocks
fetchLocalBlocks(localBlocks)
logDebug(s"Got local blocks in${Utils.getUsedTimeNs(startTimeNs)}")
// Get host local blocks if any
fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)
pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks)
}
在shuffle fetch的迭代器中,获取数据请求有下面四步:
- [1] 通过不同的获取模式对块进行分区:本地、主机本地和远程块
- [2] 以随机顺序将远程请求添加到我们的队列中
- [3] 发送remote fetch请求
- [4] 获取local blocks
- [5] 获取host blocks
- [6] 获取pushMerge的local blocks
划分数据源的请求
private[this] def partitionBlocksByFetchMode(
blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
...
val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
val localExecIds =Set(blockManager.blockManagerId.executorId, fallback)
for ((address, blockInfos) <- blocksByAddress) {
checkBlockSizes(blockInfos)
// [1] 如果是push-merged blocks, 判断其是否是主机的还是远程请求
if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
// These are push-merged blocks or shuffle chunks of these blocks.
if (address.host == blockManager.blockManagerId.host) {
numBlocksToFetch+= blockInfos.size
pushMergedLocalBlocks ++= blockInfos.map(_._1)
pushMergedLocalBlockBytes += blockInfos.map(_._2).sum
} else {
collectFetchRequests(address, blockInfos, collectedRemoteRequests)
}
// [2] 如果是localexecIds, 放入localBlocks
} else if (localExecIds.contains(address.executorId)) {
val mergedBlockInfos =mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info =>FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
numBlocksToFetch+= mergedBlockInfos.size
localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
localBlockBytes += mergedBlockInfos.map(_.size).sum
// [3] 如果是host本地,并将其放入hostLocalBlocksByExecutor
} else if (blockManager.hostLocalDirManager.isDefined &&
address.host == blockManager.blockManagerId.host) {
val mergedBlockInfos =mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info =>FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
numBlocksToFetch+= mergedBlockInfos.size
val blocksForAddress =
mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
hostLocalBlocksByExecutor += address -> blocksForAddress
numHostLocalBlocks += blocksForAddress.size
hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
// [4] 如果是remote请求,收集fetch请求, 每个请求的最大请求数据大小,是max(maxBytesInFlight / 5, 1L),这是为了提高请求的并发度,保证至少向5个不同的节点发送请求获取数据,最大限度地利用各节点的资源
} else {
val (_, timeCost) = Utils.timeTakenMs[Unit] {
collectFetchRequests(address, blockInfos, collectedRemoteRequests)
}
logDebug(s"Collected remote fetch requests for$address in$timeCost ms")
}
}
val (remoteBlockBytes, numRemoteBlocks) =
collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
pushMergedLocalBlockBytes
val blocksToFetchCurrentIteration =numBlocksToFetch- prevNumBlocksToFetch
...
this.hostLocalBlocks++= hostLocalBlocksByExecutor.values
.flatMap { infos => infos.map(info => (info._1, info._3)) }
collectedRemoteRequests
}
- [1] 如果是push-merged blocks, 判断其是否是主机的还是远程请求
- [2] 如果是localexecIds, 放入localBlocks
- [3] 如果是host本地,并将其放入hostLocalBlocksByExecutor
- [4] 如果是remote请求,收集fetch请求, 每个请求的最大请求数据大小,是max(maxBytesInFlight / 5, 1L),这是为了提高请求的并发度,保证至少向5个不同的节点发送请求获取数据,最大限度地利用各节点的资源
在划分完数据的请求类别后,会依次的进行remote fetch请求,local blocks请求,host blocks请求和获取pushMerge的local blocks。
那么数据是如何被Fetch的呢?接下来我们看下fetchUpToMaxBytes()方法。
private def fetchUpToMaxBytes(): Unit = {
// [1] 如果是延迟请求,如果可以远程块Fetch同时是未达到处理请求的字节数,进行send请求
if (deferredFetchRequests.nonEmpty) {
for ((remoteAddress, defReqQueue) <-deferredFetchRequests) {
while (isRemoteBlockFetchable(defReqQueue) &&
!isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
val request = defReqQueue.dequeue()
logDebug(s"Processing deferred fetch request for$remoteAddress with "
+ s"${request.blocks.length} blocks")
send(remoteAddress, request)
if (defReqQueue.isEmpty) {
deferredFetchRequests-= remoteAddress
}
}
}
}
// [2] 如果正常可以远程Fetch请求,直接send请求;如果达到处理请求的字节,则创建remoteAddress的延迟请求
// Process any regular fetch requests if possible.
while (isRemoteBlockFetchable(fetchRequests)) {
val request = fetchRequests.dequeue()
val remoteAddress = request.address
if (isRemoteAddressMaxedOut(remoteAddress, request)) {
logDebug(s"Deferring fetch request for$remoteAddress with${request.blocks.size} blocks")
val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
defReqQueue.enqueue(request)
deferredFetchRequests(remoteAddress) = defReqQueue
} else {
send(remoteAddress, request)
}
}
}
Fetch请求字节数据:
- [1] 如果是延迟请求,如果可以远程块Fetch同时是未达到处理请求的字节数,进行send请求
- [2] 如果正常可以远程Fetch请求,直接send请求;如果达到处理请求的字节,则创建remoteAddress的延迟请求
它会验证该请求是否应被视为延迟。如果是,则将其添加到deferredFetchRequests中。否则,它会继续并从BlockStoreClient实现发送请求(如果启用了 shuffle 服务,则为ExternalBlockStoreClient ,否则为NettyBlockTransferService)。
// ShuffleBlockFetcherIterator
private[this] def sendRequest(req: FetchRequest): Unit = {
// ...
// [1] 创建了一个**BlockFetchingListener**,在完成请求后会被调用
val blockFetchingListener = new BlockFetchingListener {
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// ...
remainingBlocks -= blockId
results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
// ...
}
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
}
}
// Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
// already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
// the data and write it to file directly.
// [2] 如果请求大小超过可以存储在内存中的请求的最大大小 ,则迭代器通过可选地定义DownloadFileManager来发送获取请求
if (req.size > maxReqSizeShuffleToMem) {
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
blockFetchingListener, this)
} else {
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
blockFetchingListener, null)
}
在sendRequest中主要进行了以下两个步骤:
- [1] 创建了一个BlockFetchingListener,在完成请求后会被调用
- [2] 如果请求大小超过可以存储在内存中的请求的最大大小 ,则迭代器通过可选地定义DownloadFileManager来发送获取请求
首先,ShuffleBlockFetcherIterator迭代器创建了一个BlockFetchingListener,在其中定义成功执行和实现执行后的回调函数,如果成功执行,它会首先为迭代器加synchronized锁,然后将块数据添加到结果变量中。如果发生错误,同样会先加synchronized锁,然后它将添加一个标记类来指示获取失败。
其次,ShuffleBlockFetcherIterator会调用BlockStoreClient的fetchBlocks方法,在调用前会判断请求的内容的大小,如果超过门限,则传参定义DownloadFileManager,它会使得shuffleData将被下载到临时文件。
下面我们看下最终的fetchBlocks是如何实现的?
@Override
public void fetchBlocks(
String host,
int port,
String execId,
String[] blockIds,
BlockFetchingListener listener,
DownloadFileManager downloadFileManager) {
checkInit();
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
// [1] 首先创建并初始化RetryingBlockFetcher类,用它加载shuffle files
int maxRetries = transportConf.maxIORetries();
RetryingBlockTransferor.BlockTransferStarter blockFetchStarter =
(inputBlockId, inputListener) -> {
// Unless this client is closed.
if (clientFactory != null) {
assert inputListener instanceof BlockFetchingListener :
"Expecting a BlockFetchingListener, but got " + inputListener.getClass();
TransportClient client = clientFactory.createClient(host, port, maxRetries > 0);
// [2] 创建OneForOneBlockFetcher,用其进行下载shuffle Data
new OneForOneBlockFetcher(client, appId, execId, inputBlockId,
(BlockFetchingListener) inputListener, transportConf, downloadFileManager).start();
} else {
logger.info("This clientFactory was closed. Skipping further block fetch retries.");
}
};
...
// [3] 调用OneForOneBlockFetcher的start方法
blockFetchStarter.createAndStart(blockIds, listener);
}
}
- [1] 首先创建并初始化RetryingBlockFetcher类,用它加载shuffle files
- [2] 创建OneForOneBlockFetcher,用其进行下载shuffle Data
OneForOneBlockFetcher进行Shuffle 数据的下载
OneForOneBlockFetcher是基于RPC通信,从各个Executor端获取shuffle数据,我们首先来简要概述下:
- 首先,fetcher 会向持有 shuffle 文件的 executor发送FetchShuffleBlocks消息;
- 其次,executor将register new Stream 同时返回StreamHandle消息到fetcher, 它带有streamId;
- 在收到StreamHandle响应后,client将stream或load 数据块;
- 如果
downloadFileManager
不为空,则会将结果写入临时文件;对于内存的场景,shuffle bytes将加载到in-memory buffer中; - 最终,基于临时文件还是基于内存都会调用sendRequest中定义的BlockFetchingListener回调函数。
获取到的shuffle data会被放入到new LinkedBlockingQueue[FetchResult],并调用next()方法。如果所有可用的块数据都已被消耗,迭代器将执行之前提供的 fetchUpToMaxBytes()。
ShuffleBlockFetcherIterator初始化完成后
在ShuffleBlockFetcherIterator初始化完成后,我们再来看看剩余的工作:
private class ShuffleFetchCompletionListener(var data: ShuffleBlockFetcherIterator)
extends TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
if (data != null) {
data.cleanup()locations(blocksByAddress)
data = null
}
}
def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context)
}
在ShuffleBlockFetcherIterator初始化完成后,会将其转换为CompletionIterator,在其中主要是进行资源的释放。然后借助于反序列化器将其将shuffle block反序列化为record迭代器。在将其包装为metricIter 同于更新task的metric。之后再将其封装为InterruptibleIterator迭代器。可中断迭代器的作用是每次执行hasNext方法时,它都会分析任务状态并最终终止托管此迭代器的任务。主要用于启用了推测执行的情况。
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
def hasNext: Boolean = {
// TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt
// is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
// (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
// introduces an expensive read fence.
context.killTaskIfInterrupted()
delegate.hasNext
}
接下来就是reduce端的聚合排序的操作, 注意这里需要在ShuffleDependency中定义, aggregator和keyOrdering,这些操作需要在PairRDDFunctions
中进行定义。
但是在SparkSQL中,它采用的是ShuffleExchangeExec并不会定义 aggregator和keyOrdering,那么Spark SQL是如何实现聚合和排序的呢?
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
...
} else {
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
val resultIter: Iterator[Product2[K, C]] =dep.keyOrdering match {
caseSome(keyOrd: Ordering[K]) =>
val sorter =
new ExternalSorter[K, C, C](context, ordering =Some(keyOrd), serializer =dep.serializer)
sorter.insertAllAndUpdateMetrics(aggregatedIter)
case None =>
aggregatedIter
}
其实通过其执行计划可以知道,其会在其中插入Sort算子来实现聚合排序。
到此为止,shuffle reader的大致过程已经走了一遍,但是还有很多的重要细节并没有展开探讨,那么这里就详细总结下整体的流程:
Fetch前的准备
- fetch reader 的调用主要是ShuffledRDD和ShuffledRowRDD中,通过传入 不同的partitionspecs给getReader传入不同的调用参数。
- 在getReader中会先通过mapOutputTracker获取mapid对应的shuffle文件的位置,然后在通过
BlockStoreShuffleReader
reader的唯一实现类进行shuffle fetch; - 在Driver端mapOutputTracker记录mapId和对应的文件位置主要由MapOutputTrackerMaster进行维护,在创建mapShuffleStage时会向master tracker中注册shuffleid, 在完成mapStage时会更新对应shuffleId中维护的mapid对应的位置信息。在Executor端从MapOutputTrackerWorker中获取位置信息,如果获取不到会向master tracker发送信息,同步信息过来;
处理Fetch请求
- 在BlockStoreShuffleReader中进行fetch时,会先创建ShuffleBlockFetcherIterator, 并将Fetch分为local, host local, remote不同方式;同时在Fetch时也会有些限制,包括每个Excutor阻塞的fetch request数和fetch shuffle数据是否大于分配的内存;如果请求的数据量过多,超过了内存限制,将通过写入临时文件实现;如果网络通信开销太大,fetcher 将停止读取,并在需要下一个 shuffle 块文件时恢复读取。
- 最终的Fetch是通过OneForOneBlockFetcher实现的,fetcher 会向持有 shuffle 文件的 executor发送FetchShuffleBlocks消息,executor将register new Stream 同时将数据封装为StreamHandle消息返回到fetcher,client最后再将加载数据块;最终调用BlockFetchingListener回调函数。
Fetch后的处理
- reduce端聚合数据:如果map端已经聚合过了,则对读取到的聚合结果进行聚合。如果map端没有聚合,则针对未合并的<k,v>进行聚合。
- reduce端排序数据:如果需要对key排序,则进行排序。基于sort的shuffle实现过程中,默认只是按照partitionId排序。在每一个partition内部并没有排序,因此添加了keyOrdering变量,提供是否需要对分区内部的key排序
- 另外需要注意的是SparkSQL中并不会设置ShuffleDependency的排序和聚合,而是通过规则在逻辑树中插入Sort算子实现的。
学完Shuffle Reader下面是一些思考题:
- 为什么在调用getReader时要根据partitionspecs的不同传递不同的参数?主要的作用是什么?
- 远程Fetch和本地Fetch最大的区别是什么?
- InterruptibleIterator 和 CompletionIterator 迭代器的作用是什么?