This repository has been archived by the owner on Apr 17, 2024. It is now read-only.

rebase SSO onto the newly updated HPNL
Signed-off-by: Haodong Tang <[email protected]>
Tang Haodong committed Apr 17, 2019
1 parent 3a6ab31 commit 7d3e290
Showing 15 changed files with 270 additions and 379 deletions.
@@ -25,7 +25,7 @@
 import org.apache.spark.shuffle.pmof.MetadataResolver;
 import org.apache.spark.storage.BlockManagerId;
 import org.apache.spark.storage.BlockManagerId$;
-import org.apache.spark.network.pmof.RdmaTransferService;
+import org.apache.spark.network.pmof.PmofTransferService;
 import scala.Option;
 import scala.Product2;
 import scala.collection.JavaConverters;
@@ -266,8 +266,8 @@ void closeAndWriteOutput() throws IOException {
     }
     BlockManagerId shuffleServerId = blockManager.shuffleServerId();
     if (enable_rdma) {
-      BlockManagerId blockManagerId = BlockManagerId$.MODULE$.apply(shuffleServerId.executorId(), RdmaTransferService.shuffleNodesMap().get(shuffleServerId.host()).get(),
-          RdmaTransferService.getTransferServiceInstance(blockManager, null, false).port(), shuffleServerId.topologyInfo());
+      BlockManagerId blockManagerId = BlockManagerId$.MODULE$.apply(shuffleServerId.executorId(), PmofTransferService.shuffleNodesMap().get(shuffleServerId.host()).get(),
+          PmofTransferService.getTransferServiceInstance(blockManager, null, false).port(), shuffleServerId.topologyInfo());
       mapStatus = MapStatus$.MODULE$.apply(blockManagerId, partitionLengths);
     } else {
       mapStatus = MapStatus$.MODULE$.apply(shuffleServerId, partitionLengths);
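In the hunk above, when enable_rdma is set the writer re-publishes its shuffle location: the new BlockManagerId keeps the executor id and topology info, but its host is remapped through PmofTransferService.shuffleNodesMap (hostname to RDMA-reachable address) and its port is taken from the live transfer service. A minimal Scala sketch of that remapping; the map contents and port are invented examples, not values from this commit:

// Hedged sketch of the host -> RDMA-address remapping done in the hunk above.
import scala.collection.mutable

val shuffleNodesMap = mutable.HashMap("node1" -> "192.168.2.1") // hypothetical entry
val host = "node1"                    // stands in for shuffleServerId.host()
val rdmaHost = shuffleNodesMap(host)  // address reducers dial over the RDMA fabric
val rdmaPort = 61000                  // stands in for the running service's port()
// The rebuilt BlockManagerId advertises (rdmaHost, rdmaPort), so map output
// registered with the driver points fetchers at the RDMA endpoint.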
61 changes: 61 additions & 0 deletions src/main/scala/org/apache/spark/network/pmof/Client.scala
@@ -0,0 +1,61 @@
package org.apache.spark.network.pmof

import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentHashMap

import com.intel.hpnl.core.{Connection, EqService}
import org.apache.spark.shuffle.pmof.PmofShuffleManager

class Client(clientFactory: ClientFactory, val shuffleManager: PmofShuffleManager, con: Connection) {
  final val outstandingReceiveFetches: ConcurrentHashMap[Long, ReceivedCallback] =
    new ConcurrentHashMap[Long, ReceivedCallback]()
  final val outstandingReadFetches: ConcurrentHashMap[Int, ReadCallback] =
    new ConcurrentHashMap[Int, ReadCallback]()
  final val shuffleBufferMap: ConcurrentHashMap[Int, ShuffleBuffer] = new ConcurrentHashMap[Int, ShuffleBuffer]()

  def getEqService: EqService = clientFactory.eqService

  def read(shuffleBuffer: ShuffleBuffer, reqSize: Int,
           rmaAddress: Long, rmaRkey: Long, localAddress: Int,
           callback: ReadCallback, isDeferred: Boolean = false): Unit = {
    if (!isDeferred) {
      outstandingReadFetches.putIfAbsent(shuffleBuffer.getRdmaBufferId, callback)
      shuffleBufferMap.putIfAbsent(shuffleBuffer.getRdmaBufferId, shuffleBuffer)
    }
    val ret = con.read(shuffleBuffer.getRdmaBufferId, localAddress, reqSize, rmaAddress, rmaRkey)
    // -11 presumably mirrors -FI_EAGAIN from libfabric: the read could not be
    // posted right now, so queue it for retry; deferred retries re-enter at the
    // front of the deque to preserve their order.
    if (ret == -11) {
      if (isDeferred) {
        clientFactory.deferredReadList.addFirst(
          new ClientDeferredRead(this, shuffleBuffer, reqSize, rmaAddress, rmaRkey, localAddress)
        )
      } else {
        clientFactory.deferredReadList.addLast(
          new ClientDeferredRead(this, shuffleBuffer, reqSize, rmaAddress, rmaRkey, localAddress)
        )
      }
    }
  }

  def send(byteBuffer: ByteBuffer, seq: Long, msgType: Byte,
           callback: ReceivedCallback, isDeferred: Boolean): Unit = {
    assert(con != null)
    if (callback != null) {
      outstandingReceiveFetches.putIfAbsent(seq, callback)
    }
    val sendBuffer = this.con.takeSendBuffer(false)
    if (sendBuffer == null) {
      // No free HPNL send buffer: park the request so the CqService
      // external-event handler can replay it later.
      if (isDeferred) {
        clientFactory.deferredSendList.addFirst(
          new ClientDeferredSend(this, byteBuffer, seq, msgType, callback)
        )
      } else {
        clientFactory.deferredSendList.addLast(
          new ClientDeferredSend(this, byteBuffer, seq, msgType, callback)
        )
      }
      return
    }
    sendBuffer.put(byteBuffer, msgType, seq)
    con.send(sendBuffer.remaining(), sendBuffer.getBufferId)
  }
}
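Client.send above never blocks on buffer exhaustion: when takeSendBuffer returns null, the request is parked on the factory's deferredSendList and replayed from the completion-queue service's external-event hook. A hedged usage sketch; `factory`, `manager`, the endpoint, and the seq/msgType values are illustrative placeholders, not names or values from this commit:

// Assumes an initialized ClientFactory `factory`, a PmofShuffleManager
// `manager`, and a reachable peer at a hypothetical endpoint.
import java.nio.ByteBuffer

val client = factory.createClient(manager, "192.168.2.1", 61000)
val payload = ByteBuffer.wrap(Array[Byte](1, 2, 3))
// First attempt; if no HPNL send buffer is free, Client defers the request
// and the CqService external handler retries it with isDeferred = true.
client.send(payload, 0L, 1.toByte, null, isDeferred = false)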
122 changes: 122 additions & 0 deletions src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala
@@ -0,0 +1,122 @@
package org.apache.spark.network.pmof

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingDeque}

import com.intel.hpnl.core._
import org.apache.spark.SparkConf
import org.apache.spark.shuffle.pmof.PmofShuffleManager

import scala.collection.mutable.ArrayBuffer

class ClientFactory(conf: SparkConf) {
  final val SINGLE_BUFFER_SIZE: Int = PmofTransferService.CHUNKSIZE
  final val BUFFER_NUM: Int = conf.getInt("spark.shuffle.pmof.client_buffer_nums", 16)
  final val workers = conf.getInt("spark.shuffle.pmof.server_pool_size", 1)

  final val eqService = new EqService(workers, BUFFER_NUM, false).init()
  final val cqService = new CqService(eqService).init()

  final val conArray: ArrayBuffer[Connection] = ArrayBuffer()
  final val deferredSendList = new LinkedBlockingDeque[ClientDeferredSend]()
  final val deferredReadList = new LinkedBlockingDeque[ClientDeferredRead]()
  final val clientMap = new ConcurrentHashMap[InetSocketAddress, Client]()
  final val conMap = new ConcurrentHashMap[Connection, Client]()

  def init(): Unit = {
    eqService.initBufferPool(BUFFER_NUM, SINGLE_BUFFER_SIZE, BUFFER_NUM * 2)
    cqService.addExternalEvent(new ExternalHandler {
      override def handle(): Unit = {
        handleDeferredSend()
        handleDeferredRead()
      }
    })
    val clientRecvHandler = new ClientRecvHandler
    val clientReadHandler = new ClientReadHandler
    eqService.setRecvCallback(clientRecvHandler)
    eqService.setReadCallback(clientReadHandler)
    cqService.start()
  }

  def createClient(shuffleManager: PmofShuffleManager, address: String, port: Int): Client = {
    val socketAddress: InetSocketAddress = InetSocketAddress.createUnresolved(address, port)
    var client = clientMap.get(socketAddress)
    if (client == null) {
      ClientFactory.this.synchronized {
        client = clientMap.get(socketAddress)
        if (client == null) {
          val con = eqService.connect(address, port.toString, 0)
          client = new Client(this, shuffleManager, con)
          clientMap.put(socketAddress, client)
          conMap.put(con, client)
        }
      }
    }
    client
  }

  def stop(): Unit = {
    cqService.shutdown()
  }

  def waitToStop(): Unit = {
    cqService.join()
    eqService.shutdown()
    eqService.join()
  }

  def getEqService: EqService = eqService

  class ClientRecvHandler() extends Handler {
    override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = {
      val buffer: HpnlBuffer = con.getRecvBuffer(rdmaBufferId)
      val rpcMessage: ByteBuffer = buffer.get(blockBufferSize)
      val seq = buffer.getSeq
      val msgType = buffer.getType
      val callback = conMap.get(con).outstandingReceiveFetches.get(seq)
      if (msgType == 0.toByte) {
        callback.onSuccess(null)
      } else {
        val metadataResolver = conMap.get(con).shuffleManager.metadataResolver
        val blockInfoArray = metadataResolver.deserializeShuffleBlockInfo(rpcMessage)
        callback.onSuccess(blockInfoArray)
      }
    }
  }

  class ClientReadHandler() extends Handler {
    override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = {
      def fun(v1: Int): Unit = {
        conMap.get(con).shuffleBufferMap.remove(v1)
        conMap.get(con).outstandingReadFetches.remove(v1)
      }

      val callback = conMap.get(con).outstandingReadFetches.get(rdmaBufferId)
      val shuffleBuffer = conMap.get(con).shuffleBufferMap.get(rdmaBufferId)
      callback.onSuccess(shuffleBuffer, fun)
    }
  }

  def handleDeferredSend(): Unit = {
    if (!deferredSendList.isEmpty) {
      val deferredSend = deferredSendList.pollFirst()
      deferredSend.client.send(deferredSend.byteBuffer, deferredSend.seq,
        deferredSend.msgType, deferredSend.callback, isDeferred = true)
    }
  }

  def handleDeferredRead(): Unit = {
    if (!deferredReadList.isEmpty) {
      val deferredRead = deferredReadList.pollFirst()
      deferredRead.client.read(deferredRead.shuffleBuffer, deferredRead.reqSize,
        deferredRead.rmaAddress, deferredRead.rmaRkey, deferredRead.localAddress, null, isDeferred = true)
    }
  }
}

class ClientDeferredSend(val client: Client, val byteBuffer: ByteBuffer, val seq: Long, val msgType: Byte,
                         val callback: ReceivedCallback) {}

class ClientDeferredRead(val client: Client, val shuffleBuffer: ShuffleBuffer, val reqSize: Int,
                         val rmaAddress: Long, val rmaRkey: Long, val localAddress: Int) {}
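createClient caches one Client per InetSocketAddress with double-checked locking, so concurrent fetchers targeting the same peer share a single HPNL Connection. A hedged lifecycle sketch; the endpoint is invented, `manager` stands for an existing PmofShuffleManager, and this assumes it runs inside an initialized SparkEnv (PmofTransferService.CHUNKSIZE reads SparkEnv.get at class-load time):

// Minimal lifecycle sketch under the assumptions stated above.
import org.apache.spark.SparkConf

val factory = new ClientFactory(new SparkConf())
factory.init() // buffer pool, recv/read callbacks, deferred-queue hook, CQ polling
val c1 = factory.createClient(manager, "192.168.2.1", 61000)
val c2 = factory.createClient(manager, "192.168.2.1", 61000)
assert(c1 eq c2)     // same cached connection for the same unresolved address
factory.stop()       // shuts down CQ polling
factory.waitToStop() // then joins the CQ and EQ services

Note that init() must run before createClient: the buffer pool and callbacks are registered only there, and handleDeferredSend/handleDeferredRead drain at most one deferred request per completion-queue event.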
src/main/scala/org/apache/spark/network/pmof/RdmaTransferService.scala → src/main/scala/org/apache/spark/network/pmof/PmofTransferService.scala
@@ -6,25 +6,19 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}

 import org.apache.spark.network.BlockDataManager
 import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager}
-import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.shuffle.pmof.{MetadataResolver, PmofShuffleManager}
 import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId}
 import org.apache.spark.{SparkConf, SparkEnv}
 
 import scala.collection.mutable
 
-class RdmaTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManager, val hostname: String,
-                          var port: Int, val supportRma: Boolean) extends TransferService {
-  final var server: RdmaServer = _
-  final private var recvHandler: ServerRecvHandler = _
-  final private var connectHandler: ServerConnectHandler = _
-  final private var clientFactory: RdmaClientFactory = _
-  private var appId: String = _
+class PmofTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManager,
+                          val hostname: String, var port: Int) extends TransferService {
+  final var server: Server = _
+  final private var clientFactory: ClientFactory = _
   private var nextReqId: AtomicLong = _
   final val metadataResolver: MetadataResolver = this.shuffleManager.metadataResolver
 
-  private val serializer = new JavaSerializer(conf)
-
   override def fetchBlocks(host: String,
                            port: Int,
                            executId: String,
@@ -34,8 +28,8 @@ class RdmaTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage

   def fetchBlock(reqHost: String, reqPort: Int, rmaAddress: Long, rmaLength: Int,
                  rmaRkey: Long, localAddress: Int, shuffleBuffer: ShuffleBuffer,
-                 rdmaClient: RdmaClient, callback: ReadCallback): Unit = {
-    rdmaClient.read(shuffleBuffer, rmaLength, rmaAddress, rmaRkey, localAddress, callback)
+                 client: Client, callback: ReadCallback): Unit = {
+    client.read(shuffleBuffer, rmaLength, rmaAddress, rmaRkey, localAddress, callback)
   }
 
   def fetchBlockInfo(blockIds: Array[BlockId], receivedCallback: ReceivedCallback): Unit = {
@@ -45,12 +39,12 @@ class RdmaTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage

   def syncBlocksInfo(host: String, port: Int, byteBuffer: ByteBuffer, msgType: Byte,
                      callback: ReceivedCallback): Unit = {
-    clientFactory.createClient(shuffleManager, host, port, supportRma = false).
+    clientFactory.createClient(shuffleManager, host, port).
       send(byteBuffer, nextReqId.getAndIncrement(), msgType, callback, isDeferred = false)
   }
 
-  def getClient(reqHost: String, reqPort: Int): RdmaClient = {
-    clientFactory.createClient(shuffleManager, reqHost, reqPort, supportRma = true)
+  def getClient(reqHost: String, reqPort: Int): Client = {
+    clientFactory.createClient(shuffleManager, reqHost, reqPort)
   }
 
   override def close(): Unit = {
@@ -65,15 +59,11 @@ class RdmaTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage
   }
 
   def init(): Unit = {
-    this.server = new RdmaServer(conf, shuffleManager, hostname, port, supportRma)
-    this.appId = conf.getAppId
-    this.recvHandler = new ServerRecvHandler(server, appId, serializer)
-    this.connectHandler = new ServerConnectHandler(server)
-    this.server.setRecvHandler(this.recvHandler)
-    this.server.setConnectHandler(this.connectHandler)
-    this.clientFactory = new RdmaClientFactory(conf)
+    this.server = new Server(conf, shuffleManager, hostname, port)
+    this.clientFactory = new ClientFactory(conf)
     this.server.init()
     this.server.start()
+    this.clientFactory.init()
     this.port = server.port
     val random = new Random().nextInt(Integer.MAX_VALUE)
     this.nextReqId = new AtomicLong(random)
@@ -82,32 +72,31 @@ class RdmaTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage
   override def init(blockDataManager: BlockDataManager): Unit = {}
 }
 
-object RdmaTransferService {
+object PmofTransferService {
   final val env: SparkEnv = SparkEnv.get
   final val conf: SparkConf = env.conf
   final val CHUNKSIZE: Int = conf.getInt("spark.shuffle.pmof.chunk_size", 4096*3)
   final val driverHost: String = conf.get("spark.driver.rhost", defaultValue = "172.168.0.43")
   final val driverPort: Int = conf.getInt("spark.driver.rport", defaultValue = 61000)
-  val shuffleNodes: Array[Array[String]] =
+  final val shuffleNodes: Array[Array[String]] =
     conf.get("spark.shuffle.pmof.node", defaultValue = "").split(",").map(_.split("-"))
-  val shuffleNodesMap: mutable.Map[String, String] = new mutable.HashMap[String, String]()
+  final val shuffleNodesMap: mutable.Map[String, String] = new mutable.HashMap[String, String]()
   for (array <- shuffleNodes) {
     shuffleNodesMap.put(array(0), array(1))
   }
   private val initialized = new AtomicBoolean(false)
-  private var transferService: RdmaTransferService = _
+  private var transferService: PmofTransferService = _
   def getTransferServiceInstance(blockManager: BlockManager, shuffleManager: PmofShuffleManager = null,
-                                 isDriver: Boolean = false): RdmaTransferService = {
+                                 isDriver: Boolean = false): PmofTransferService = {
     if (!initialized.get()) {
-      RdmaTransferService.this.synchronized {
+      PmofTransferService.this.synchronized {
         if (initialized.get()) return transferService
         if (isDriver) {
           transferService =
-            new RdmaTransferService(conf, shuffleManager, driverHost, driverPort, false)
+            new PmofTransferService(conf, shuffleManager, driverHost, driverPort)
         } else {
           transferService =
-            new RdmaTransferService(conf, shuffleManager, shuffleNodesMap(blockManager.shuffleServerId.host),
-              0, false)
+            new PmofTransferService(conf, shuffleManager, shuffleNodesMap(blockManager.shuffleServerId.host), 0)
         }
         transferService.init()
         initialized.set(true)
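shuffleNodesMap above is populated from spark.shuffle.pmof.node, which the split(",").map(_.split("-")) parse implies is a comma-separated list of hostname-address pairs. A small sketch of that parse; the host/address values are invented examples:

// Illustrative parse of spark.shuffle.pmof.node with made-up entries.
val raw = "node1-192.168.2.1,node2-192.168.2.2"
val pairs: Array[Array[String]] = raw.split(",").map(_.split("-"))
val shuffleNodesMap = pairs.map(a => a(0) -> a(1)).toMap
// shuffleNodesMap("node1") == "192.168.2.1"

This format assumes neither hostnames nor addresses contain '-'; a hyphenated hostname would split into extra fields.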