This repository has been archived by the owner on Apr 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rebase SSO to the newly updated HPNL
Signed-off-by: Haodong Tang <[email protected]>
- Loading branch information
Tang Haodong
committed
Apr 17, 2019
1 parent
3a6ab31
commit 7d3e290
Showing
15 changed files
with
270 additions
and
379 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package org.apache.spark.network.pmof | ||
|
||
import java.nio.ByteBuffer | ||
import java.util.concurrent.ConcurrentHashMap | ||
|
||
import com.intel.hpnl.core.{Connection, EqService} | ||
import org.apache.spark.shuffle.pmof.PmofShuffleManager | ||
|
||
/**
 * Wraps one HPNL [[Connection]] and keeps track of the requests issued over it.
 *
 * Requests that cannot be posted right away are parked on the owning factory's
 * deferred queues so that they can be retried later.
 *
 * @param clientFactory  factory that owns the shared services and the deferred queues
 * @param shuffleManager shuffle manager exposed to code that handles responses
 * @param con            connection used for every read and send issued by this client
 */
class Client(clientFactory: ClientFactory, val shuffleManager: PmofShuffleManager, con: Connection) {
  /** Callbacks registered by `send`, keyed by the message sequence number. */
  final val outstandingReceiveFetches: ConcurrentHashMap[Long, ReceivedCallback] =
    new ConcurrentHashMap[Long, ReceivedCallback]()

  /** Callbacks registered by `read`, keyed by the RDMA buffer id of the target buffer. */
  final val outstandingReadFetches: ConcurrentHashMap[Int, ReadCallback] =
    new ConcurrentHashMap[Int, ReadCallback]()

  /** Target buffers of reads registered by `read`, keyed by their RDMA buffer id. */
  final val shuffleBufferMap: ConcurrentHashMap[Int, ShuffleBuffer] =
    new ConcurrentHashMap[Int, ShuffleBuffer]()

  /** @return the event-queue service held by the owning factory */
  def getEqService: EqService = clientFactory.eqService

  /**
   * Posts a remote read into `shuffleBuffer`.
   *
   * On a first attempt (`isDeferred == false`) the callback and the buffer are registered
   * under the buffer's RDMA id before the read is posted. If the connection answers with
   * -11 the request is queued on the factory's deferred-read list: at the front when this
   * call was itself a retry, at the back otherwise.
   *
   * @param shuffleBuffer destination buffer; its RDMA buffer id identifies the request
   * @param reqSize       size argument forwarded to the connection
   * @param rmaAddress    remote address forwarded to the connection
   * @param rmaRkey       remote key forwarded to the connection
   * @param localAddress  local address/offset forwarded to the connection
   * @param callback      completion callback; only used when `isDeferred` is false
   * @param isDeferred    true when this call retries a previously deferred read
   */
  def read(shuffleBuffer: ShuffleBuffer, reqSize: Int,
           rmaAddress: Long, rmaRkey: Long, localAddress: Int,
           callback: ReadCallback, isDeferred: Boolean = false): Unit = {
    val firstAttempt = !isDeferred
    if (firstAttempt) {
      outstandingReadFetches.putIfAbsent(shuffleBuffer.getRdmaBufferId, callback)
      shuffleBufferMap.putIfAbsent(shuffleBuffer.getRdmaBufferId, shuffleBuffer)
    }

    val status = con.read(shuffleBuffer.getRdmaBufferId, localAddress, reqSize, rmaAddress, rmaRkey)
    // NOTE(review): -11 is presumably the negated EAGAIN ("try again") code - confirm against HPNL.
    if (status == -11) {
      val pending =
        new ClientDeferredRead(this, shuffleBuffer, reqSize, rmaAddress, rmaRkey, localAddress)
      val queue = clientFactory.deferredReadList
      // A retried request keeps its place at the head of the queue; a new one goes to the tail.
      if (isDeferred) queue.addFirst(pending) else queue.addLast(pending)
    }
  }

  /**
   * Sends `byteBuffer` tagged with `msgType` and `seq`.
   *
   * A non-null `callback` is registered under `seq` before anything else. When no send
   * buffer can be taken from the connection, the request is queued on the factory's
   * deferred-send list (front for a retry, back for a new request) and nothing is sent.
   *
   * @param byteBuffer payload to copy into the send buffer
   * @param seq        sequence number written with the payload and used as the callback key
   * @param msgType    message type written with the payload
   * @param callback   response callback, may be null
   * @param isDeferred true when this call retries a previously deferred send
   */
  def send(byteBuffer: ByteBuffer, seq: Long, msgType: Byte,
           callback: ReceivedCallback, isDeferred: Boolean): Unit = {
    assert(con != null)
    val hasCallback = callback != null
    if (hasCallback) {
      outstandingReceiveFetches.putIfAbsent(seq, callback)
    }

    val sendBuffer = this.con.takeSendBuffer(false)
    if (sendBuffer != null) {
      sendBuffer.put(byteBuffer, msgType, seq)
      con.send(sendBuffer.remaining(), sendBuffer.getBufferId)
    } else {
      val pending = new ClientDeferredSend(this, byteBuffer, seq, msgType, callback)
      val queue = clientFactory.deferredSendList
      // A retried request keeps its place at the head of the queue; a new one goes to the tail.
      if (isDeferred) queue.addFirst(pending) else queue.addLast(pending)
    }
  }
}
122 changes: 122 additions & 0 deletions
122
src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package org.apache.spark.network.pmof | ||
|
||
import java.net.InetSocketAddress | ||
import java.nio.ByteBuffer | ||
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingDeque} | ||
|
||
import com.intel.hpnl.core._ | ||
import org.apache.spark.SparkConf | ||
import org.apache.spark.shuffle.pmof.PmofShuffleManager | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
|
||
/**
 * Creates and caches [[Client]] instances, one per remote address, on top of a shared
 * HPNL `EqService` / `CqService` pair, and retries requests that clients had to defer.
 *
 * @param conf Spark configuration; `spark.shuffle.pmof.client_buffer_nums` (default 16) and
 *             `spark.shuffle.pmof.server_pool_size` (default 1) are read from it
 */
class ClientFactory(conf: SparkConf) {
  final val SINGLE_BUFFER_SIZE: Int = PmofTransferService.CHUNKSIZE
  final val BUFFER_NUM: Int = conf.getInt("spark.shuffle.pmof.client_buffer_nums", 16)
  final val workers = conf.getInt("spark.shuffle.pmof.server_pool_size", 1)

  final val eqService = new EqService(workers, BUFFER_NUM, false).init()
  final val cqService = new CqService(eqService).init()

  final val conArray: ArrayBuffer[Connection] = ArrayBuffer()
  /** Sends that could not obtain a send buffer; drained by `handleDeferredSend`. */
  final val deferredSendList = new LinkedBlockingDeque[ClientDeferredSend]()
  /** Reads the connection refused to post; drained by `handleDeferredRead`. */
  final val deferredReadList = new LinkedBlockingDeque[ClientDeferredRead]()
  /** Clients keyed by the (unresolved) remote address they were created for. */
  final val clientMap = new ConcurrentHashMap[InetSocketAddress, Client]()
  /** Reverse lookup used by the event handlers to find the client owning a connection. */
  final val conMap = new ConcurrentHashMap[Connection, Client]()

  /**
   * Initialises the buffer pool, registers the deferred-request handler and the
   * receive/read callbacks, then starts the completion-queue service.
   */
  def init(): Unit = {
    eqService.initBufferPool(BUFFER_NUM, SINGLE_BUFFER_SIZE, BUFFER_NUM * 2)
    cqService.addExternalEvent(new ExternalHandler {
      // Each invocation retries at most one deferred send and one deferred read.
      override def handle(): Unit = {
        handleDeferredSend()
        handleDeferredRead()
      }
    })
    val clientRecvHandler = new ClientRecvHandler
    val clientReadHandler = new ClientReadHandler
    eqService.setRecvCallback(clientRecvHandler)
    eqService.setReadCallback(clientReadHandler)
    cqService.start()
  }

  /**
   * Returns the client for `address:port`, connecting and caching it on first use.
   *
   * The cache is checked once without the lock and once more while holding it, so at
   * most one connection is opened per address by this factory.
   *
   * @param shuffleManager shuffle manager handed to a newly created client
   * @param address        remote host
   * @param port           remote port
   * @return the cached or newly created client
   */
  def createClient(shuffleManager: PmofShuffleManager, address: String, port: Int): Client = {
    val socketAddress: InetSocketAddress = InetSocketAddress.createUnresolved(address, port)
    var client = clientMap.get(socketAddress)
    if (client == null) {
      ClientFactory.this.synchronized {
        client = clientMap.get(socketAddress)
        if (client == null) {
          // NOTE(review): a null result from connect() would make conMap.put throw - confirm
          // whether HPNL can return null here.
          val con = eqService.connect(address, port.toString, 0)
          client = new Client(this, shuffleManager, con)
          clientMap.put(socketAddress, client)
          conMap.put(con, client)
        }
      }
    }
    client
  }

  /** Asks the completion-queue service to shut down. */
  def stop(): Unit = {
    cqService.shutdown()
  }

  /** Waits for the completion-queue service, then shuts down and joins the event-queue service. */
  def waitToStop(): Unit = {
    cqService.join()
    eqService.shutdown()
    eqService.join()
  }

  /** @return the event-queue service shared by all clients of this factory */
  def getEqService: EqService = eqService

  /** Dispatches a received message to the callback registered for its sequence number. */
  class ClientRecvHandler() extends Handler {
    override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = {
      val buffer: HpnlBuffer = con.getRecvBuffer(rdmaBufferId)
      val rpcMessage: ByteBuffer = buffer.get(blockBufferSize)
      val seq = buffer.getSeq
      val msgType = buffer.getType
      val client = conMap.get(con)
      // Client.send accepts a null callback and registers nothing in that case, so a reply
      // may legitimately have no callback; skip it instead of throwing a NullPointerException.
      // NOTE(review): the entry is not removed from outstandingReceiveFetches here - confirm
      // whether it is cleaned up elsewhere.
      Option(client.outstandingReceiveFetches.get(seq)).foreach { callback =>
        if (msgType == 0.toByte) {
          callback.onSuccess(null)
        } else {
          val metadataResolver = client.shuffleManager.metadataResolver
          val blockInfoArray = metadataResolver.deserializeShuffleBlockInfo(rpcMessage)
          callback.onSuccess(blockInfoArray)
        }
      }
    }
  }

  /** Dispatches a completed read to the callback registered for its RDMA buffer id. */
  class ClientReadHandler() extends Handler {
    override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = {
      val client = conMap.get(con)

      // Handed to the callback so that it can drop the bookkeeping for a buffer id.
      def fun(v1: Int): Unit = {
        client.shuffleBufferMap.remove(v1)
        client.outstandingReadFetches.remove(v1)
      }

      val callback = client.outstandingReadFetches.get(rdmaBufferId)
      val shuffleBuffer = client.shuffleBufferMap.get(rdmaBufferId)
      callback.onSuccess(shuffleBuffer, fun)
    }
  }

  /** Retries the oldest deferred send, if there is one. */
  def handleDeferredSend(): Unit = {
    // pollFirst() returns null when the deque is empty. Poll once and test the result rather
    // than calling isEmpty first: another thread could take the element between the two calls.
    Option(deferredSendList.pollFirst()).foreach { deferredSend =>
      deferredSend.client.send(deferredSend.byteBuffer, deferredSend.seq,
        deferredSend.msgType, deferredSend.callback, isDeferred = true)
    }
  }

  /** Retries the oldest deferred read, if there is one. */
  def handleDeferredRead(): Unit = {
    // Same single-poll pattern as handleDeferredSend. The callback is passed as null because
    // Client.read only registers a callback on the first (non-deferred) attempt.
    Option(deferredReadList.pollFirst()).foreach { deferredRead =>
      deferredRead.client.read(deferredRead.shuffleBuffer, deferredRead.reqSize,
        deferredRead.rmaAddress, deferredRead.rmaRkey, deferredRead.localAddress, null, isDeferred = true)
    }
  }
}
|
||
/**
 * Arguments of a `Client.send` call that could not be carried out and must be retried.
 *
 * @param client     client on which the send is retried
 * @param byteBuffer payload of the message
 * @param seq        sequence number of the message
 * @param msgType    message type
 * @param callback   response callback supplied with the original call, may be null
 */
class ClientDeferredSend(
    val client: Client,
    val byteBuffer: ByteBuffer,
    val seq: Long,
    val msgType: Byte,
    val callback: ReceivedCallback
)
|
||
/**
 * Arguments of a `Client.read` call that could not be posted and must be retried.
 *
 * @param client        client on which the read is retried
 * @param shuffleBuffer destination buffer of the read
 * @param reqSize       requested size
 * @param rmaAddress    remote address
 * @param rmaRkey       remote key
 * @param localAddress  local address/offset passed to the connection
 */
class ClientDeferredRead(
    val client: Client,
    val shuffleBuffer: ShuffleBuffer,
    val reqSize: Int,
    val rmaAddress: Long,
    val rmaRkey: Long,
    val localAddress: Int
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.