From 22ea58f3b5a030803d7e69e3f584d10db4052821 Mon Sep 17 00:00:00 2001 From: Haodong Date: Fri, 11 Oct 2019 20:48:16 +0800 Subject: [PATCH 1/5] fix spill error Signed-off-by: Haodong Tang --- core/pom.xml | 2 +- .../pmof/PersistentMemoryMetaHandler.java | 317 +++++++++--------- .../storage/pmof/PmemBlockObjectStream.scala | 29 +- 3 files changed, 178 insertions(+), 170 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 535cf3b7..b400270b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-core_2.11 - 2.3.3 + 2.3.0 provided diff --git a/core/src/main/java/org/apache/spark/storage/pmof/PersistentMemoryMetaHandler.java b/core/src/main/java/org/apache/spark/storage/pmof/PersistentMemoryMetaHandler.java index d4cff41b..a9c4727e 100644 --- a/core/src/main/java/org/apache/spark/storage/pmof/PersistentMemoryMetaHandler.java +++ b/core/src/main/java/org/apache/spark/storage/pmof/PersistentMemoryMetaHandler.java @@ -3,7 +3,6 @@ import java.nio.channels.FileLock; import java.sql.Connection; import org.sqlite.SQLiteConfig; -import java.sql.DatabaseMetaData; import java.sql.DriverManager; import java.sql.SQLException; import java.sql.PreparedStatement; @@ -34,44 +33,46 @@ public void createTable(String root_dir) { + " UNIQUE(shuffleId, device)\n" + ");\n"; - synchronized (file) { - FileOutputStream fos = null; - Connection conn = null; - Statement stmt = null; - try { - fos = new FileOutputStream(file); - fos.getChannel().lock(); - conn = DriverManager.getConnection(url); - stmt = conn.createStatement(); - stmt.execute(sql); + FileOutputStream fos = null; + FileLock fl = null; + Connection conn = null; + Statement stmt = null; + try { + fos = new FileOutputStream(file); + fl = fos.getChannel().lock(); + conn = DriverManager.getConnection(url); + stmt = conn.createStatement(); + stmt.execute(sql); - sql = "CREATE TABLE IF NOT EXISTS devices (\n" - + " device text UNIQUE,\n" - + " mount_count int\n" - + ");"; - stmt.execute(sql); + sql = "CREATE TABLE IF NOT EXISTS devices (\n" + + " device text UNIQUE,\n" + + " mount_count int\n" + + ");"; + stmt.execute(sql); + } catch (SQLException e) { + System.out.println("createTable failed:" + e.getMessage()); + } catch (IOException e) { + e.printStackTrace(); + } finally { + try { + if (stmt != null) stmt.close(); } catch (SQLException e) { - System.out.println("createTable failed:" + e.getMessage()); - } catch (IOException e) { e.printStackTrace(); - } finally { - try { - if (stmt != null) stmt.close(); - } catch (SQLException e) { - e.printStackTrace(); - } - try { - if (conn != null) conn.close(); - } catch (SQLException e) { - e.printStackTrace(); + } + try { + if (conn != null) conn.close(); + } catch (SQLException e) { + e.printStackTrace(); + } + try { + if (fl != null) { + fl.release(); } - try { - if (fos != null) { - fos.close(); - } - } catch (IOException e) { - e.printStackTrace(); + if (fos != null) { + fos.close(); } + } catch (IOException e) { + e.printStackTrace(); } } System.out.println("Metastore DB connected: " + url); @@ -79,41 +80,43 @@ public void createTable(String root_dir) { public void insertRecord(String shuffleId, String device) { String sql = "INSERT OR IGNORE INTO metastore(shuffleId,device) VALUES('" + shuffleId + "','" + device + "')"; - synchronized (file) { - FileOutputStream fos = null; - Connection conn = null; - Statement stmt = null; + FileOutputStream fos = null; + FileLock fl = null; + Connection conn = null; + Statement stmt = null; + try { + fos = new FileOutputStream(file); + fl 
= fos.getChannel().lock(); + SQLiteConfig config = new SQLiteConfig(); + config.setBusyTimeout("30000"); + conn = DriverManager.getConnection(url); + stmt = conn.createStatement(); + stmt.executeUpdate(sql); + } catch (SQLException e) { + e.printStackTrace(); + System.exit(-1); + } catch (IOException e) { + e.printStackTrace(); + } finally { try { - fos = new FileOutputStream(file); - fos.getChannel().lock(); - SQLiteConfig config = new SQLiteConfig(); - config.setBusyTimeout("30000"); - conn = DriverManager.getConnection(url); - stmt = conn.createStatement(); - stmt.executeUpdate(sql); + if (stmt != null) stmt.close(); } catch (SQLException e) { e.printStackTrace(); - System.exit(-1); - } catch (IOException e) { + } + try { + if (conn != null) conn.close(); + } catch (SQLException e) { e.printStackTrace(); - } finally { - try { - if (stmt != null) stmt.close(); - } catch (SQLException e) { - e.printStackTrace(); - } - try { - if (conn != null) conn.close(); - } catch (SQLException e) { - e.printStackTrace(); + } + try { + if (fl != null) { + fl.release(); } - try { - if (fos != null) { - fos.close(); - } - } catch (IOException e) { - e.printStackTrace(); + if (fos != null) { + fos.close(); } + } catch (IOException e) { + e.printStackTrace(); } } } @@ -121,49 +124,51 @@ public void insertRecord(String shuffleId, String device) { public String getShuffleDevice(String shuffleId){ String sql = "SELECT device FROM metastore where shuffleId = ?"; String res = ""; - synchronized (file) { - FileOutputStream fos = null; - Connection conn = null; - PreparedStatement stmt = null; - ResultSet rs = null; + FileOutputStream fos = null; + FileLock fl = null; + Connection conn = null; + PreparedStatement stmt = null; + ResultSet rs = null; + try { + fos = new FileOutputStream(file); + fl = fos.getChannel().lock(); + conn = DriverManager.getConnection(url); + stmt = conn.prepareStatement(sql); + stmt.setString(1, shuffleId); + rs = stmt.executeQuery(); + while (rs.next()) { + res = rs.getString("device"); + } + } catch (SQLException e) { + e.printStackTrace(); + System.exit(-1); + } catch (IOException e) { + e.printStackTrace(); + } finally { try { - fos = new FileOutputStream(file); - fos.getChannel().lock(); - conn = DriverManager.getConnection(url); - stmt = conn.prepareStatement(sql); - stmt.setString(1, shuffleId); - rs = stmt.executeQuery(); - while (rs.next()) { - res = rs.getString("device"); - } + if (rs != null) rs.close(); } catch (SQLException e) { e.printStackTrace(); - System.exit(-1); - } catch (IOException e) { + } + try { + if (stmt != null) stmt.close(); + } catch (SQLException e) { e.printStackTrace(); - } finally { - try { - if (rs != null) rs.close(); - } catch (SQLException e) { - e.printStackTrace(); - } - try { - if (stmt != null) stmt.close(); - } catch (SQLException e) { - e.printStackTrace(); - } - try { - if (conn != null) conn.close(); - } catch (SQLException e) { - e.printStackTrace(); + } + try { + if (conn != null) conn.close(); + } catch (SQLException e) { + e.printStackTrace(); + } + try { + if (fl != null) { + fl.release(); } - try { - if (fos != null) { - fos.close(); - } - } catch (IOException e) { - e.printStackTrace(); + if (fos != null) { + fos.close(); } + } catch (IOException e) { + e.printStackTrace(); } } return res; @@ -175,69 +180,71 @@ public String getUnusedDevice(ArrayList full_device_list){ HashMap device_count = new HashMap(); String device = ""; int count; - synchronized (file) { - FileOutputStream fos = null; - Connection conn = null; - Statement stmt = 
null; - ResultSet rs = null; - try { - fos = new FileOutputStream(file); - fos.getChannel().lock(); - SQLiteConfig config = new SQLiteConfig(); - config.setBusyTimeout("30000"); - conn = DriverManager.getConnection(url); - stmt = conn.createStatement(); - rs = stmt.executeQuery(sql); - while (rs.next()) { - device_list.add(rs.getString("device")); - device_count.put(rs.getString("device"), rs.getInt("mount_count")); - } - full_device_list.removeAll(device_list); - if (full_device_list.size() == 0) { - // reuse old device, picked the device has smallest mount_count - device = getDeviceWithMinCount(device_count); - if (device != null && device.length() == 0) { - throw new SQLException(); - } - count = (Integer) device_count.get(device) + 1; - sql = "UPDATE devices SET mount_count = " + count + " WHERE device = '" + device + "'\n"; - } else { - device = full_device_list.get(0); - count = 1; - sql = "INSERT OR IGNORE INTO devices(device, mount_count) VALUES('" + device + "', " + count + ")\n"; + FileOutputStream fos = null; + FileLock fl = null; + Connection conn = null; + Statement stmt = null; + ResultSet rs = null; + try { + fos = new FileOutputStream(file); + fl = fos.getChannel().lock(); + SQLiteConfig config = new SQLiteConfig(); + config.setBusyTimeout("30000"); + conn = DriverManager.getConnection(url); + stmt = conn.createStatement(); + rs = stmt.executeQuery(sql); + while (rs.next()) { + device_list.add(rs.getString("device")); + device_count.put(rs.getString("device"), rs.getInt("mount_count")); + } + full_device_list.removeAll(device_list); + if (full_device_list.size() == 0) { + // reuse old device, picked the device has smallest mount_count + device = getDeviceWithMinCount(device_count); + if (device != null && device.length() == 0) { + throw new SQLException(); } + count = (Integer) device_count.get(device) + 1; + sql = "UPDATE devices SET mount_count = " + count + " WHERE device = '" + device + "'\n"; + } else { + device = full_device_list.get(0); + count = 1; + sql = "INSERT OR IGNORE INTO devices(device, mount_count) VALUES('" + device + "', " + count + ")\n"; + } - System.out.println(sql); + System.out.println(sql); - stmt.executeUpdate(sql); + stmt.executeUpdate(sql); + } catch (SQLException e) { + e.printStackTrace(); + System.exit(-1); + } catch (IOException e) { + e.printStackTrace(); + } finally { + try { + if (rs != null) rs.close(); } catch (SQLException e) { e.printStackTrace(); - System.exit(-1); - } catch (IOException e) { + } + try { + if (stmt != null) stmt.close(); + } catch (SQLException e) { e.printStackTrace(); - } finally { - try { - if (rs != null) rs.close(); - } catch (SQLException e) { - e.printStackTrace(); - } - try { - if (stmt != null) stmt.close(); - } catch (SQLException e) { - e.printStackTrace(); - } - try { - if (conn != null) conn.close(); - } catch (SQLException e) { - e.printStackTrace(); + } + try { + if (conn != null) conn.close(); + } catch (SQLException e) { + e.printStackTrace(); + } + try { + if (fl != null) { + fl.release(); } - try { - if (fos != null) { - fos.close(); - } - } catch (IOException e) { - e.printStackTrace(); + if (fos != null) { + fos.close(); } + } catch (IOException e) { + e.printStackTrace(); } } System.out.println("Metastore DB: get unused device, should be " + device + "."); diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala index 9ca3a228..e0a6a11c 100644 --- 
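Context for the PersistentMemoryMetaHandler changes above: the patch drops the JVM-local synchronized(file) block, keeps a reference to the acquired FileLock, and releases it in the finally block before closing the stream, so access to the SQLite metastore is serialized across executor processes rather than only across threads of one JVM. A minimal Scala sketch of that acquire/release ordering follows; the helper name and lock-file path are illustrative, not part of the patch.

    import java.io.FileOutputStream
    import java.nio.channels.FileLock

    object FileLockSketch {
      // Acquire an OS-level file lock, run the body, and always release the
      // lock before closing the stream -- the same try/finally ordering the
      // patch introduces around the metastore accesses.
      def withFileLock[T](lockFilePath: String)(body: => T): T = {
        val fos = new FileOutputStream(lockFilePath)
        var lock: FileLock = null
        try {
          lock = fos.getChannel.lock() // blocks until no other process holds the lock
          body
        } finally {
          if (lock != null) lock.release() // release explicitly, before fos.close()
          fos.close()
        }
      }
    }

Releasing the lock explicitly, instead of relying on the channel being closed, makes the unlock point visible and keeps it ordered before the stream is closed.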
a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala @@ -90,20 +90,21 @@ private[spark] class PmemBlockObjectStream( bytesStream.flush() val bufSize = bytesStream.asInstanceOf[PmemOutputStream].size //logInfo(blockId.name + " do spill, size is " + bufSize) - - recordsArray += recordsPerBlock - recordsPerBlock = 0 - size += bufSize - - if (blockId.isShuffle == true) { - val writeMetrics = taskMetrics.shuffleWriteMetrics - writeMetrics.incWriteTime(System.nanoTime() - start) - writeMetrics.incBytesWritten(bufSize) - } else { - taskMetrics.incDiskBytesSpilled(bufSize) - } - bytesStream.asInstanceOf[PmemOutputStream].reset() - spilled = true + if (bufSize > 0) { + recordsArray += recordsPerBlock + recordsPerBlock = 0 + size += bufSize + + if (blockId.isShuffle == true) { + val writeMetrics = taskMetrics.shuffleWriteMetrics + writeMetrics.incWriteTime(System.nanoTime() - start) + writeMetrics.incBytesWritten(bufSize) + } else { + taskMetrics.incDiskBytesSpilled(bufSize) + } + bytesStream.asInstanceOf[PmemOutputStream].reset() + spilled = true + } } } From 9af8e3d8666b43a1739a712a06c9df0e64973edc Mon Sep 17 00:00:00 2001 From: Haodong Tang Date: Fri, 18 Oct 2019 08:52:23 +0800 Subject: [PATCH 2/5] fix pmem buffer boundry issue Signed-off-by: Haodong Tang --- .../apache/spark/storage/pmof/PmemBuffer.java | 10 +++-- .../spark/storage/pmof/PmemInputStream.scala | 6 +-- native/src/PersistentMemoryPool.cpp | 38 +++++++------------ native/src/PersistentMemoryPool.h | 9 ++--- native/src/PmemBuffer.h | 26 +++++++++---- native/src/Request.cpp | 27 ++----------- native/src/Request.h | 16 ++------ 7 files changed, 52 insertions(+), 80 deletions(-) diff --git a/core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java b/core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java index b21988a6..557aff2e 100644 --- a/core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java +++ b/core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java @@ -12,7 +12,8 @@ public class PmemBuffer { private native int nativeGetPmemBufferRemaining(long pmBuffer); private native long nativeGetPmemBufferDataAddr(long pmBuffer); private native long nativeDeletePmemBuffer(long pmBuffer); - + + private boolean closed = false; long pmBuffer; PmemBuffer() { pmBuffer = nativeNewPmemBuffer(); @@ -53,7 +54,10 @@ long getDirectAddr() { return nativeGetPmemBufferDataAddr(pmBuffer); } - void close() { - nativeDeletePmemBuffer(pmBuffer); + synchronized void close() { + if (!closed) { + nativeDeletePmemBuffer(pmBuffer); + closed = true; + } } } diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemInputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemInputStream.scala index 0636662c..c1388e54 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemInputStream.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemInputStream.scala @@ -14,7 +14,6 @@ class PmemInputStream( var remaining: Int = 0 var available_bytes: Int = persistentMemoryHandler.getPartitionSize(blockId).toInt val blockInfo: Array[(Long, Int)] = persistentMemoryHandler.getPartitionBlockInfo(blockId) - var is_closed = false def loadNextStream(): Int = { if (index >= blockInfo.length) @@ -70,10 +69,7 @@ class PmemInputStream( } override def close(): Unit = { - if (!is_closed) { - buf.close() - is_closed = true - } + buf.close() } def deleteBlock(): Unit = { diff --git 
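The PmemBuffer change above moves double-close protection from the Scala stream (the removed is_closed flag in PmemInputStream) into the buffer itself: close() is synchronized and guarded by a closed flag, so the native buffer behind the JNI handle is freed exactly once even if both the stream and another caller close it. A small Scala sketch of the same guard, with the native calls stubbed out as placeholders:

    // Illustrative only: the allocate/free methods stand in for the JNI calls
    // nativeNewPmemBuffer() / nativeDeletePmemBuffer() used by PmemBuffer.
    class ClosableNativeBuffer {
      private[this] val handle: Long = allocateNative()
      private[this] var closed = false

      private def allocateNative(): Long = 1L    // placeholder for the JNI allocation
      private def freeNative(h: Long): Unit = () // placeholder for the JNI delete

      def close(): Unit = synchronized {
        if (!closed) {          // only the first close frees the native handle
          freeNative(handle)
          closed = true
        }
      }
    }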
a/native/src/PersistentMemoryPool.cpp b/native/src/PersistentMemoryPool.cpp index 0054549e..034a8f72 100644 --- a/native/src/PersistentMemoryPool.cpp +++ b/native/src/PersistentMemoryPool.cpp @@ -10,9 +10,7 @@ PMPool::PMPool(const char* dev, int maxStage, int maxMap, long size): maxStage(maxStage), maxMap(maxMap), - stop(false), - dev(dev), - worker(&PMPool::process, this) { + dev(dev) { const char *pool_layout_name = "pmem_spark_shuffle"; // if this is a fsdax device @@ -32,13 +30,6 @@ PMPool::PMPool(const char* dev, int maxStage, int maxMap, long size): } PMPool::~PMPool() { - while(request_queue.size() > 0) { - fprintf(stderr, "%s request queue size is %ld\n", dev, request_queue.size()); - sleep(1); - } - fprintf(stderr, "%s request queue size is %ld\n", dev, request_queue.size()); - stop = true; - worker.join(); pmemobj_close(pmpool); } @@ -46,16 +37,6 @@ long PMPool::getRootAddr() { return (long)pmpool; } -void PMPool::process() { - Request *cur_req; - while(!stop) { - cur_req = (Request*)request_queue.dequeue(); - if (cur_req != nullptr) { - cur_req->exec(); - } - } -} - long PMPool::setMapPartition( int partitionNum, int stageId, @@ -66,7 +47,8 @@ long PMPool::setMapPartition( bool clean, int numMaps) { WriteRequest write_request(this, maxStage, numMaps, partitionNum, stageId, 0, mapId, partitionId, size, data, clean); - request_queue.enqueue((void*)&write_request); + std::lock_guard lk(mtx); + write_request.exec(); return write_request.getResult(); } @@ -79,7 +61,8 @@ long PMPool::setReducePartition( bool clean, int numMaps) { WriteRequest write_request(this, maxStage, 1, partitionNum, stageId, 1, 0, partitionId, size, data, clean); - request_queue.enqueue((void*)&write_request); + std::lock_guard lk(mtx); + write_request.exec(); return write_request.getResult(); } @@ -89,6 +72,7 @@ long PMPool::getMapPartition( int mapId, int partitionId ) { ReadRequest read_request(this, mb, stageId, 0, mapId, partitionId); + std::lock_guard lk(mtx); read_request.exec(); return read_request.getResult(); } @@ -99,6 +83,7 @@ long PMPool::getReducePartition( int mapId, int partitionId ) { ReadRequest read_request(this, mb, stageId, 1, mapId, partitionId); + std::lock_guard lk(mtx); read_request.exec(); read_request.getResult(); return 0; @@ -106,18 +91,21 @@ long PMPool::getReducePartition( long PMPool::getMapPartitionBlockInfo(BlockInfo *blockInfo, int stageId, int mapId, int partitionId) { MetaRequest meta_request(this, blockInfo, stageId, 0, mapId, partitionId); + std::lock_guard lk(mtx); meta_request.exec(); return meta_request.getResult(); } long PMPool::getReducePartitionBlockInfo(BlockInfo *blockInfo, int stageId, int mapId, int partitionId) { MetaRequest meta_request(this, blockInfo, stageId, 1, mapId, partitionId); + std::lock_guard lk(mtx); meta_request.exec(); return meta_request.getResult(); } long PMPool::getMapPartitionSize(int stageId, int mapId, int partitionId) { SizeRequest size_request(this, stageId, 0, mapId, partitionId); + std::lock_guard lk(mtx); size_request.exec(); return size_request.getResult(); } @@ -130,13 +118,15 @@ long PMPool::getReducePartitionSize(int stageId, int mapId, int partitionId) { long PMPool::deleteMapPartition(int stageId, int mapId, int partitionId) { DeleteRequest delete_request(this, stageId, 0, mapId, partitionId); - request_queue.enqueue((void*)&delete_request); + std::lock_guard lk(mtx); + delete_request.exec(); return delete_request.getResult(); } long PMPool::deleteReducePartition(int stageId, int mapId, int partitionId) { DeleteRequest 
delete_request(this, stageId, 1, mapId, partitionId); - request_queue.enqueue((void*)&delete_request); + std::lock_guard lk(mtx); + delete_request.exec(); return delete_request.getResult(); } diff --git a/native/src/PersistentMemoryPool.h b/native/src/PersistentMemoryPool.h index a48e3d75..7a6af087 100644 --- a/native/src/PersistentMemoryPool.h +++ b/native/src/PersistentMemoryPool.h @@ -33,14 +33,13 @@ using namespace std; +#include +#include + class PMPool { public: PMEMobjpool *pmpool; - std::thread worker; - WorkQueue request_queue; - bool stop; - TOID(struct StageArrayRoot) stageArrayRoot; int maxStage; int maxMap; @@ -65,7 +64,7 @@ class PMPool { long deleteMapPartition(int stageId, int mapId, int partitionId); long deleteReducePartition(int stageId, int mapId, int partitionId); private: - void process(); + std::mutex mtx; }; #endif diff --git a/native/src/PmemBuffer.h b/native/src/PmemBuffer.h index 30ca58e4..47d0a1e2 100644 --- a/native/src/PmemBuffer.h +++ b/native/src/PmemBuffer.h @@ -26,25 +26,37 @@ class PmemBuffer { } int load(char* pmem_data_addr, int pmem_data_len) { + if (pmem_data_addr == nullptr || pmem_data_len == 0) + return 0; std::lock_guard lock(buffer_mtx); if (buf_data_capacity == 0 && pmem_data_len > 0) { + buf_data_capacity = pmem_data_len; buf_data = (char*)malloc(sizeof(char) * pmem_data_len); } - buf_data_capacity = remaining + pmem_data_len; - if (remaining > 0 && buf_data_capacity > 0) { + if (remaining > 0) { + buf_data_capacity = remaining + pmem_data_len; char* tmp_buf_data = buf_data; buf_data = (char*)malloc(sizeof(char) * buf_data_capacity); if (buf_data != nullptr && tmp_buf_data != nullptr) { memcpy(buf_data, tmp_buf_data + pos, remaining); } free(tmp_buf_data); - } - - pos = remaining; - if (buf_data != nullptr && pmem_data_addr != nullptr) { + pos = remaining; memcpy(buf_data + pos, pmem_data_addr, pmem_data_len); - } + } else if (remaining == 0) { + if (buf_data_capacity < pmem_data_len) { + free(buf_data); + buf_data_capacity = pmem_data_len; + buf_data = (char*)malloc(sizeof(char) * buf_data_capacity); + } + if (buf_data != nullptr) { + memcpy(buf_data, pmem_data_addr, pmem_data_len); + } + } else { + + } + remaining += pmem_data_len; pos = 0; pos_dirty = pos + remaining; diff --git a/native/src/Request.cpp b/native/src/Request.cpp index fa253cd0..2044dc5e 100644 --- a/native/src/Request.cpp +++ b/native/src/Request.cpp @@ -5,6 +5,7 @@ #include #include #include +#include /****** Request ******/ TOID(struct PartitionArrayItem) Request::getPartitionBlock() { @@ -151,11 +152,6 @@ void WriteRequest::exec() { } long WriteRequest::getResult() { - while (!processed) { - usleep(5); - } - //cv.wait(lck, [&]{return processed;}); - //fprintf(stderr, "get Result for %d_%d_%d\n", stageId, mapId, partitionId); return (long)data_addr; } @@ -201,22 +197,18 @@ void WriteRequest::setPartition() { D_RW(partitionArrayItem)->last_block = *partitionBlock; data_addr = (char*)pmemobj_direct(D_RW(*partitionBlock)->data); - //printf("setPartition data_addr: %p\n", data_addr); pmemobj_tx_add_range_direct((const void *)data_addr, size); memcpy(data_addr, data, size); } TX_ONCOMMIT { committed = true; - block_cv.notify_all(); } TX_ONABORT { fprintf(stderr, "set Partition of %d_%d_%d failed, type is %d, partitionNum is %d, maxStage is %d, maxMap is %d. 
Error: %s\n", stageId, mapId, partitionId, typeId, partitionNum, maxStage, maxMap, pmemobj_errormsg()); exit(-1); } TX_END - - block_cv.wait(block_lck, [&]{return committed;}); - - processed = true; - //cv.notify_all(); + if (!committed) { + assert(0 == "content not committed."); + } } /****** ReadRequest ******/ @@ -240,14 +232,11 @@ void ReadRequest::getPartition() { char* data_addr; while(!TOID_IS_NULL(partitionBlock)) { data_addr = (char*)pmemobj_direct(D_RO(partitionBlock)->data); - //printf("getPartition data_addr: %p\n", data_addr); memcpy(mb->buf + off, data_addr, D_RO(partitionBlock)->data_size); off += D_RO(partitionBlock)->data_size; partitionBlock = D_RO(partitionBlock)->next_block; } - - //printf("getPartition length is %d\n", data_length); } /****** MetaRequest ******/ @@ -299,9 +288,6 @@ void DeleteRequest::exec() { } long DeleteRequest::getResult() { - while (!processed) { - usleep(5); - } return ret; } @@ -326,13 +312,8 @@ void DeleteRequest::deletePartition() { D_RW(partitionArrayItem)->numBlocks = 0; } TX_ONCOMMIT { committed = true; - block_cv.notify_all(); } TX_ONABORT { fprintf(stderr, "delete Partition of %d_%d_%d failed, type is %d. Error: %s\n", stageId, mapId, partitionId, typeId, pmemobj_errormsg()); exit(-1); } TX_END - - block_cv.wait(block_lck, [&]{return committed;}); - - processed = true; } diff --git a/native/src/Request.h b/native/src/Request.h index 17173f86..1576ec6c 100644 --- a/native/src/Request.h +++ b/native/src/Request.h @@ -4,6 +4,8 @@ #include #include #include +#include +#include #define TOID_ARRAY_TYPE(x) TOID(x) #define TOID_ARRAY(x) TOID_ARRAY_TYPE(TOID(x)) @@ -119,25 +121,13 @@ class Request { typeId(typeId), mapId(mapId), partitionId(partitionId), - processed(false), - committed(false), - lck(mtx), block_lck(block_mtx) { - } + committed(false) {} ~Request(){} virtual void exec() = 0; virtual long getResult() = 0; protected: - // add lock to make this request blocked - std::mutex mtx; - std::condition_variable cv; - bool processed; - std::unique_lock lck; - // add lock to make func blocked - std::mutex block_mtx; - std::condition_variable block_cv; bool committed; - std::unique_lock block_lck; int stageId; int typeId; From 8c9c445fd397019581fd5550c05db1af1508fecc Mon Sep 17 00:00:00 2001 From: Haodong Tang Date: Fri, 18 Oct 2019 19:18:57 +0800 Subject: [PATCH 3/5] fix rdma read concurrent issue Signed-off-by: Haodong Tang --- .../storage/pmof/PmemBlockObjectStream.scala | 4 +-- .../PmofShuffleBlockFetcherIterator.scala | 28 ++++++------------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala index e0a6a11c..f7039ce1 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala @@ -61,8 +61,8 @@ private[spark] class PmemBlockObjectStream( var inputStream: InputStream = _ override def write(key: Any, value: Any): Unit = { - objStream.writeObject(key) - objStream.writeObject(value) + objStream.writeKey(key) + objStream.writeValue(value) records += 1 recordsPerBlock += 1 if (blockId.isShuffle == true) { diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmofShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmofShuffleBlockFetcherIterator.scala index e5a52ba0..84345025 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/pmof/PmofShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmofShuffleBlockFetcherIterator.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage.pmof import java.io.{File, IOException, InputStream} import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} import javax.annotation.concurrent.GuardedBy import org.apache.spark.internal.Logging @@ -104,13 +105,10 @@ final class RdmaShuffleBlockFetcherIterator( @volatile private[this] var currentResult: SuccessFetchResult = _ /** Current bytes in flight from our requests */ - private[this] var bytesInFlight = 0L + private[this] var bytesInFlight = new AtomicLong(0) /** Current number of requests in flight */ - private[this] var reqsInFlight = 0 - - /** Current number of blocks in flight per host:port */ - private[this] val numBlocksInFlightPerAddress = new mutable.HashMap[BlockManagerId, Int]() + private[this] var reqsInFlight = new AtomicInteger(0) private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() @@ -170,14 +168,11 @@ final class RdmaShuffleBlockFetcherIterator( def sendRequest(rdmaRequest: RdmaRequest): Unit = { val shuffleBlockInfos = rdmaRequest.shuffleBlockInfos var blockNums= shuffleBlockInfos.size - bytesInFlight += rdmaRequest.reqSize - reqsInFlight += 1 + bytesInFlight.addAndGet(rdmaRequest.reqSize) + reqsInFlight.incrementAndGet val blockManagerId = rdmaRequest.blockManagerId val shuffleBlockIdName = rdmaRequest.shuffleBlockIdName - numBlocksInFlightPerAddress(blockManagerId) = - numBlocksInFlightPerAddress.getOrElse(blockManagerId, 0) + 1 - val pmofTransferService = shuffleClient.asInstanceOf[PmofTransferService] val blockFetchingReadCallback = new ReadCallback { @@ -214,7 +209,7 @@ final class RdmaShuffleBlockFetcherIterator( } def isRemoteBlockFetchable(rdmaRequest: RdmaRequest): Boolean = { - reqsInFlight + 1 <= maxReqsInFlight && bytesInFlight + rdmaRequest.reqSize <= maxBytesInFlight + reqsInFlight.get + 1 <= maxReqsInFlight && bytesInFlight.get + rdmaRequest.reqSize <= maxBytesInFlight } def fetchRemoteBlocks(): Unit = { @@ -362,8 +357,6 @@ final class RdmaShuffleBlockFetcherIterator( result match { case SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { - if (numBlocksInFlightPerAddress.contains(address)) - numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 shuffleMetrics.incRemoteBytesRead(buf.size) if (buf.isInstanceOf[FileSegmentManagedBuffer]) { shuffleMetrics.incRemoteBytesReadToDisk(buf.size) @@ -371,9 +364,9 @@ final class RdmaShuffleBlockFetcherIterator( shuffleMetrics.incRemoteBlocksFetched(1) logDebug("take remote block.") } - bytesInFlight -= size + bytesInFlight.addAndGet(-size) if (isNetworkReqDone) { - reqsInFlight -= 1 + reqsInFlight.decrementAndGet } val in = try { @@ -402,11 +395,6 @@ final class RdmaShuffleBlockFetcherIterator( (currentResult.blockId, new RDMABufferReleasingInputStream(input, this)) } - def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { - numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > - maxBlocksInFlightPerAddress - } - private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => From b982bd76d6c6d4137a838951fd2aac63055e4127 Mon Sep 17 00:00:00 
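In the RdmaShuffleBlockFetcherIterator change above, bytesInFlight and reqsInFlight are updated both by the fetch loop and by RDMA read callbacks running on other threads, so the plain var counters become AtomicLong/AtomicInteger and the per-address block bookkeeping is dropped. A Scala sketch of that accounting; the class and method names are illustrative:

    import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}

    class InFlightTracker(maxBytesInFlight: Long, maxReqsInFlight: Int) {
      private val bytesInFlight = new AtomicLong(0L)
      private val reqsInFlight  = new AtomicInteger(0)

      // Called only from the single fetch loop, so check-then-reserve does not
      // race with other reservations; callback threads only ever call release().
      def tryReserve(reqSize: Long): Boolean = {
        if (reqsInFlight.get + 1 <= maxReqsInFlight &&
            bytesInFlight.get + reqSize <= maxBytesInFlight) {
          reqsInFlight.incrementAndGet()
          bytesInFlight.addAndGet(reqSize)
          true
        } else {
          false
        }
      }

      // Called from RDMA completion callbacks on other threads.
      def release(reqSize: Long): Unit = {
        bytesInFlight.addAndGet(-reqSize)
        reqsInFlight.decrementAndGet()
      }
    }

As in the patch, only one thread reserves capacity, so the check and the increment need not be a single atomic step; the atomics ensure that releases performed on callback threads are visible and never lost.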
2001 From: Haodong Tang Date: Tue, 22 Oct 2019 10:37:31 +0800 Subject: [PATCH 4/5] reset pmem outputstream after flush operation Signed-off-by: Haodong Tang --- .../scala/org/apache/spark/storage/pmof/PmemOutputStream.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala index 9b65033b..1f0a3d72 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala @@ -35,6 +35,7 @@ class PmemOutputStream( override def flush(): Unit = { if (size() > 0) { persistentMemoryWriter.setPartition(numPartitions, blockId, byteBuffer, size(), set_clean, numMaps) + reset() } if (set_clean) { set_clean = false From 511df2957bda4ea97cc6fe72fd4eb4162c655bd8 Mon Sep 17 00:00:00 2001 From: Tang Haodong Date: Wed, 23 Oct 2019 00:39:57 +0800 Subject: [PATCH 5/5] code refactor Signed-off-by: Haodong Tang --- .../shuffle/sort/SerializedShuffleWriter.java | 545 ------------------ .../apache/spark/storage/pmof/PmemBuffer.java | 10 +- .../spark/network/pmof/ClientFactory.scala | 36 +- .../network/pmof/PmofTransferService.scala | 44 +- .../apache/spark/network/pmof/Server.scala | 33 +- ...leReader.scala => BaseShuffleReader.scala} | 32 +- .../shuffle/pmof/BaseShuffleWriter.scala | 41 +- .../spark/shuffle/pmof/MetadataResolver.scala | 132 +++-- .../pmof/PmemShuffleBlockResolver.scala | 4 +- .../shuffle/pmof/PmemShuffleWriter.scala | 102 ++-- .../shuffle/pmof/PmofShuffleManager.scala | 76 +-- .../shuffle/pmof/RdmaShuffleReader.scala | 27 +- .../pmof/PersistentMemoryHandler.scala | 17 +- .../storage/pmof/PmemBlockInputStream.scala | 53 ++ ...ream.scala => PmemBlockOutputStream.scala} | 63 +- .../spark/storage/pmof/PmemInputStream.scala | 8 +- .../storage/pmof/PmemManagedBuffer.scala | 2 +- .../spark/storage/pmof/PmemOutputStream.scala | 19 +- ... 
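The one-line PmemOutputStream change in PATCH 4 above calls reset() immediately after a successful flush, so a subsequent flush or spill does not write the same buffered bytes to persistent memory a second time. A simplified Scala sketch of that behaviour, with a sink function standing in for persistentMemoryWriter.setPartition (the class name here is hypothetical):

    import java.io.{ByteArrayOutputStream, OutputStream}

    class ResettingBufferedStream(sink: Array[Byte] => Unit) extends OutputStream {
      private val buffer = new ByteArrayOutputStream()

      override def write(b: Int): Unit = buffer.write(b)

      override def flush(): Unit = {
        if (buffer.size() > 0) {
          sink(buffer.toByteArray) // hand the current chunk to the persistent-memory writer
          buffer.reset()           // clear it so the next flush never re-sends the same bytes
        }
      }
    }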
=> RdmaShuffleBlockFetcherIterator.scala} | 385 ++++++------- .../collection/pmof/PmemExternalSorter.scala | 128 ++-- .../util/configuration/pmof/PmofConf.scala | 29 + .../shuffle/pmof/PmemShuffleWriterSuite.scala | 10 +- .../pmof/PmemShuffleWriterWithSortSuite.scala | 10 +- .../pmof/PmemBlockObjectStreamSuite.scala | 96 --- 24 files changed, 629 insertions(+), 1273 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/SerializedShuffleWriter.java rename core/src/main/scala/org/apache/spark/shuffle/pmof/{PmemShuffleReader.scala => BaseShuffleReader.scala} (81%) create mode 100644 core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala rename core/src/main/scala/org/apache/spark/storage/pmof/{PmemBlockObjectStream.scala => PmemBlockOutputStream.scala} (65%) rename core/src/main/scala/org/apache/spark/storage/pmof/{PmofShuffleBlockFetcherIterator.scala => RdmaShuffleBlockFetcherIterator.scala} (85%) create mode 100644 core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala delete mode 100644 core/src/test/scala/org/apache/spark/storage/pmof/PmemBlockObjectStreamSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SerializedShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SerializedShuffleWriter.java deleted file mode 100644 index 76ef0cb2..00000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/SerializedShuffleWriter.java +++ /dev/null @@ -1,545 +0,0 @@ -package org.apache.spark.shuffle.sort; - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -import javax.annotation.Nullable; -import java.io.*; -import java.nio.channels.FileChannel; -import java.util.Iterator; - -import org.apache.spark.shuffle.pmof.MetadataResolver; -import org.apache.spark.storage.BlockManagerId; -import org.apache.spark.storage.BlockManagerId$; -import org.apache.spark.network.pmof.PmofTransferService; -import scala.Option; -import scala.Product2; -import scala.collection.JavaConverters; -import scala.reflect.ClassTag; -import scala.reflect.ClassTag$; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.io.ByteStreams; -import com.google.common.io.Closeables; -import com.google.common.io.Files; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.*; -import org.apache.spark.annotation.Private; -import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.io.CompressionCodec; -import org.apache.spark.io.CompressionCodec$; -import org.apache.spark.io.NioBufferedFileInputStream; -import org.apache.commons.io.output.CloseShieldOutputStream; -import org.apache.commons.io.output.CountingOutputStream; -import org.apache.spark.memory.TaskMemoryManager; -import org.apache.spark.network.util.LimitedInputStream; -import org.apache.spark.scheduler.MapStatus; -import org.apache.spark.scheduler.MapStatus$; -import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.ShuffleWriter; -import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.TimeTrackingOutputStream; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.util.Utils; -import org.apache.spark.internal.config.package$; - -@Private -public class SerializedShuffleWriter extends ShuffleWriter { - - private static final Logger logger = LoggerFactory.getLogger(SerializedShuffleWriter.class); - - private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); - - @VisibleForTesting - static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; - static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; - - private final BlockManager blockManager; - private final IndexShuffleBlockResolver shuffleBlockResolver; - private final TaskMemoryManager memoryManager; - private final SerializerInstance serializer; - private final Partitioner partitioner; - private final ShuffleWriteMetrics writeMetrics; - private final int shuffleId; - private final int mapId; - private final TaskContext taskContext; - private final SparkConf sparkConf; - private final boolean transferToEnabled; - private final int initialSortBufferSize; - private final int inputBufferSizeInBytes; - private final int outputBufferSizeInBytes; - private final MetadataResolver metadataResolver; - private final boolean enable_rdma; - - @Nullable private MapStatus mapStatus; - @Nullable private ShuffleExternalSorter sorter; - private long peakMemoryUsedBytes = 0; - - /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ - private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { - MyByteArrayOutputStream(int size) { super(size); } - public byte[] getBuf() { return buf; } - } - - private MyByteArrayOutputStream serBuffer; - private SerializationStream serOutputStream; - - /** - * Are we in the process of stopping? 
Because map tasks can call stop() with success = true - * and then call stop() with success = false if they get an exception, we want to make sure - * we don't try deleting files, etc twice. - */ - private boolean stopping = false; - - private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream { - - CloseAndFlushShieldOutputStream(OutputStream outputStream) { - super(outputStream); - } - - @Override - public void flush() { - // do nothing - } - } - - public SerializedShuffleWriter( - BlockManager blockManager, - IndexShuffleBlockResolver shuffleBlockResolver, - MetadataResolver metadataResolver, - TaskMemoryManager memoryManager, - SerializedShuffleHandle handle, - int mapId, - TaskContext taskContext, - SparkConf sparkConf, - boolean enable_rdma) throws IOException { - final int numPartitions = handle.dependency().partitioner().numPartitions(); - if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { - throw new IllegalArgumentException( - "UnsafeShuffleWriter can only be used for shuffles with at most " + - SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + - " reduce partitions"); - } - this.blockManager = blockManager; - this.shuffleBlockResolver = shuffleBlockResolver; - this.memoryManager = memoryManager; - this.mapId = mapId; - final ShuffleDependency dep = handle.dependency(); - this.shuffleId = dep.shuffleId(); - this.serializer = dep.serializer().newInstance(); - this.partitioner = dep.partitioner(); - this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); - this.taskContext = taskContext; - this.sparkConf = sparkConf; - this.enable_rdma = enable_rdma; - this.metadataResolver = metadataResolver; - this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); - this.initialSortBufferSize = sparkConf.getInt("spark.shuffle.sort.initialBufferSize", - DEFAULT_INITIAL_SORT_BUFFER_SIZE); - this.inputBufferSizeInBytes = - Integer.parseInt(sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()).toString()) * 1024; - this.outputBufferSizeInBytes = - Integer.parseInt(sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()).toString()) * 1024; - open(); - } - - private void updatePeakMemoryUsed() { - // sorter can be null if this writer is closed - if (sorter != null) { - long mem = sorter.getPeakMemoryUsedBytes(); - if (mem > peakMemoryUsedBytes) { - peakMemoryUsedBytes = mem; - } - } - } - - /** - * Return the peak memory used so far, in bytes. - */ - public long getPeakMemoryUsedBytes() { - updatePeakMemoryUsed(); - return peakMemoryUsedBytes; - } - - /** - * This convenience method should only be called in test code. - */ - @VisibleForTesting - public void write(Iterator> records) throws IOException { - write(JavaConverters.asScalaIteratorConverter(records).asScala()); - } - - @Override - public void write(scala.collection.Iterator> records) throws IOException { - // Keep track of success so we know if we encountered an exception - // We do this rather than a standard try/catch/re-throw to handle - // generic throwables. - boolean success = false; - try { - while (records.hasNext()) { - insertRecordIntoSorter(records.next()); - } - closeAndWriteOutput(); - success = true; - } finally { - if (sorter != null) { - try { - sorter.cleanupResources(); - } catch (Exception e) { - // Only throw this error if we won't be masking another - // error. 
- if (success) { - try { - throw e; - } catch (Exception e1) { - e1.printStackTrace(); - } - } else { - logger.error("In addition to a failure during writing, we failed during " + - "cleanup.", e); - } - } - } - } - } - - private void open() { - assert (sorter == null); - sorter = new ShuffleExternalSorter( - memoryManager, - blockManager, - taskContext, - initialSortBufferSize, - partitioner.numPartitions(), - sparkConf, - writeMetrics); - serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); - serOutputStream = serializer.serializeStream(serBuffer); - } - - @VisibleForTesting - void closeAndWriteOutput() throws IOException { - assert(sorter != null); - updatePeakMemoryUsed(); - serBuffer = null; - serOutputStream = null; - final SpillInfo[] spills = sorter.closeAndGetSpills(); - sorter = null; - final long[] partitionLengths; - final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); - final File tmp = Utils.tempFileWith(output); - try { - try { - partitionLengths = mergeSpills(spills, tmp); - } finally { - for (SpillInfo spill : spills) { - if (spill.file.exists()) { - if(!spill.file.delete()) { - logger.error("Error while deleting spill file {}", spill.file.getPath()); - } - } - } - } - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); - if (enable_rdma) - metadataResolver.commitBlockInfo(shuffleId, mapId, partitionLengths); - } finally { - if (tmp.exists() && !tmp.delete()) { - logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); - } - } - BlockManagerId shuffleServerId = blockManager.shuffleServerId(); - if (enable_rdma) { - BlockManagerId blockManagerId = BlockManagerId$.MODULE$.apply(shuffleServerId.executorId(), PmofTransferService.shuffleNodesMap().get(shuffleServerId.host()).get(), - PmofTransferService.getTransferServiceInstance(blockManager, null, false).port(), shuffleServerId.topologyInfo()); - mapStatus = MapStatus$.MODULE$.apply(blockManagerId, partitionLengths); - } else { - mapStatus = MapStatus$.MODULE$.apply(shuffleServerId, partitionLengths); - } - } - - @VisibleForTesting - void insertRecordIntoSorter(Product2 record) throws IOException { - assert(sorter != null); - final K key = record._1(); - final int partitionId = partitioner.getPartition(key); - serBuffer.reset(); - serOutputStream.writeKey(key, OBJECT_CLASS_TAG); - serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); - serOutputStream.flush(); - - final int serializedRecordSize = serBuffer.size(); - assert (serializedRecordSize > 0); - - sorter.insertRecord( - serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); - } - - @VisibleForTesting - void forceSorterToSpill() throws IOException { - assert (sorter != null); - sorter.spill(); - } - - /** - * Merge zero or more spill files together, choosing the fastest merging strategy based on the - * number of spills and the IO compression codec. - * - * @return the partition lengths in the merged file. 
- */ - private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { - final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); - final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); - final boolean fastMergeEnabled = - sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); - final boolean fastMergeIsSupported = !compressionEnabled || - CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); - final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); - try { - if (spills.length == 0) { - new FileOutputStream(outputFile).close(); // Create an empty file - return new long[partitioner.numPartitions()]; - } else if (spills.length == 1) { - // Here, we don't need to perform any metrics updates because the bytes written to this - // output file would have already been counted as shuffle bytes written. - Files.move(spills[0].file, outputFile); - return spills[0].partitionLengths; - } else { - final long[] partitionLengths; - // There are multiple spills to merge, so none of these spill files' lengths were counted - // towards our shuffle write count or shuffle write time. If we use the slow merge path, - // then the final output file's size won't necessarily be equal to the sum of the spill - // files' sizes. To guard against this case, we look at the output file's actual size when - // computing shuffle bytes written. - // - // We allow the individual merge methods to report their own IO times since different merge - // strategies use different IO techniques. We count IO during merge towards the shuffle - // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" - // branch in ExternalSorter. - if (fastMergeEnabled && fastMergeIsSupported) { - // Compression is disabled or we are using an IO compression codec that supports - // decompression of concatenated compressed streams, so we can perform a fast spill merge - // that doesn't need to interpret the spilled bytes. - if (transferToEnabled && !encryptionEnabled) { - logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); - } else { - logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); - } - } else { - logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); - } - // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has - // in-memory records, we write out the in-memory records to a file but do not count that - // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs - // to be counted as shuffle write, but this will lead to double-counting of the final - // SpillInfo's bytes. - writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); - writeMetrics.incBytesWritten(outputFile.length()); - return partitionLengths; - } - } catch (IOException e) { - if (outputFile.exists() && !outputFile.delete()) { - logger.error("Unable to delete output file {}", outputFile.getPath()); - } - throw e; - } - } - - /** - * Merges spill files using Java FileStreams. 
This code path is typically slower than - * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], - * File)}, and it's mostly used in cases where the IO compression codec does not support - * concatenation of compressed data, when encryption is enabled, or when users have - * explicitly disabled use of {@code transferTo} in order to work around kernel bugs. - * This code path might also be faster in cases where individual partition size in a spill - * is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small - * disk ios which is inefficient. In those case, Using large buffers for input and output - * files helps reducing the number of disk ios, making the file merging faster. - * - * @param spills the spills to merge. - * @param outputFile the file to write the merged data to. - * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. - * @return the partition lengths in the merged file. - */ - private long[] mergeSpillsWithFileStream( - SpillInfo[] spills, - File outputFile, - @Nullable CompressionCodec compressionCodec) throws IOException { - assert (spills.length >= 2); - final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; - final InputStream[] spillInputStreams = new InputStream[spills.length]; - - final OutputStream bos = new BufferedOutputStream( - new FileOutputStream(outputFile), - outputBufferSizeInBytes); - // Use a counting output stream to avoid having to close the underlying file and ask - // the file system for its size after each partition is written. - final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); - - boolean threwException = true; - try { - for (int i = 0; i < spills.length; i++) { - spillInputStreams[i] = new NioBufferedFileInputStream( - spills[i].file, - inputBufferSizeInBytes); - } - for (int partition = 0; partition < numPartitions; partition++) { - final long initialFileLength = mergedFileOutputStream.getByteCount(); - // Shield the underlying output stream from close() and flush() calls, so that we can close - // the higher level streams to make sure all data is really flushed and internal state is - // cleaned. 
- OutputStream partitionOutput = new CloseAndFlushShieldOutputStream( - new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); - partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); - if (compressionCodec != null) { - partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); - } - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], - partitionLengthInSpill, false); - try { - partitionInputStream = blockManager.serializerManager().wrapForEncryption( - partitionInputStream); - if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); - } - ByteStreams.copy(partitionInputStream, partitionOutput); - } finally { - partitionInputStream.close(); - } - } - } - partitionOutput.flush(); - partitionOutput.close(); - partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); - } - threwException = false; - } finally { - // To avoid masking exceptions that caused us to prematurely enter the finally block, only - // throw exceptions during cleanup if threwException == false. - for (InputStream stream : spillInputStreams) { - Closeables.close(stream, threwException); - } - Closeables.close(mergedFileOutputStream, threwException); - } - return partitionLengths; - } - - /** - * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. - * This is only safe when the IO compression codec and serializer support concatenation of - * serialized streams. - * - * @return the partition lengths in the merged file. - */ - private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { - assert (spills.length >= 2); - final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; - final FileChannel[] spillInputChannels = new FileChannel[spills.length]; - final long[] spillInputChannelPositions = new long[spills.length]; - FileChannel mergedFileOutputChannel = null; - - boolean threwException = true; - try { - for (int i = 0; i < spills.length; i++) { - spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); - } - // This file needs to opened in append mode in order to work around a Linux kernel bug that - // affects transferTo; see SPARK-3948 for more details. - mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); - - long bytesWrittenToMergedFile = 0; - for (int partition = 0; partition < numPartitions; partition++) { - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - final FileChannel spillInputChannel = spillInputChannels[i]; - final long writeStartTime = System.nanoTime(); - Utils.copyFileStreamNIO( - spillInputChannel, - mergedFileOutputChannel, - spillInputChannelPositions[i], - partitionLengthInSpill); - spillInputChannelPositions[i] += partitionLengthInSpill; - writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); - bytesWrittenToMergedFile += partitionLengthInSpill; - partitionLengths[partition] += partitionLengthInSpill; - } - } - // Check the position after transferTo loop to see if it is in the right position and raise an - // exception if it is incorrect. 
The position will not be increased to the expected length - // after calling transferTo in kernel version 2.6.32. This issue is described at - // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. - if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { - throw new IOException( - "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + - "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + - " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + - "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + - "to disable this NIO feature." - ); - } - threwException = false; - } finally { - // To avoid masking exceptions that caused us to prematurely enter the finally block, only - // throw exceptions during cleanup if threwException == false. - for (int i = 0; i < spills.length; i++) { - assert(spillInputChannelPositions[i] == spills[i].file.length()); - Closeables.close(spillInputChannels[i], threwException); - } - Closeables.close(mergedFileOutputChannel, threwException); - } - return partitionLengths; - } - - @Override - public Option stop(boolean success) { - try { - taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes()); - - if (stopping) { - return null; - } else { - stopping = true; - if (success) { - if (mapStatus == null) { - throw new IllegalStateException("Cannot call stop(true) without having called write()"); - } - return Option.apply(mapStatus); - } else { - return null; - } - } - } finally { - if (sorter != null) { - // If sorter is non-null, then this implies that we called stop() in response to an error, - // so we need to clean up memory and spill files created by the sorter - sorter.cleanupResources(); - } - } - } -} diff --git a/core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java b/core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java index 557aff2e..40d5b0f5 100644 --- a/core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java +++ b/core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java @@ -13,7 +13,7 @@ public class PmemBuffer { private native long nativeGetPmemBufferDataAddr(long pmBuffer); private native long nativeDeletePmemBuffer(long pmBuffer); - private boolean closed = false; + private boolean closed = false; long pmBuffer; PmemBuffer() { pmBuffer = nativeNewPmemBuffer(); @@ -55,9 +55,9 @@ long getDirectAddr() { } synchronized void close() { - if (!closed) { - nativeDeletePmemBuffer(pmBuffer); - closed = true; - } + if (!closed) { + nativeDeletePmemBuffer(pmBuffer); + closed = true; + } } } diff --git a/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala b/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala index 4af1fa10..db2c9161 100644 --- a/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala @@ -5,25 +5,17 @@ import java.nio.ByteBuffer import java.util.concurrent.ConcurrentHashMap import com.intel.hpnl.core._ -import org.apache.spark.SparkConf import org.apache.spark.shuffle.pmof.PmofShuffleManager +import org.apache.spark.util.configuration.pmof.PmofConf -import scala.collection.mutable.ArrayBuffer - -class ClientFactory(conf: SparkConf) { - final val SINGLE_BUFFER_SIZE: Int = PmofTransferService.CHUNKSIZE - final val BUFFER_NUM: Int = conf.getInt("spark.shuffle.pmof.client_buffer_nums", 16) - final val workers = 
conf.getInt("spark.shuffle.pmof.server_pool_size", 1) - - final val eqService = new EqService(workers, BUFFER_NUM, false).init() - final val cqService = new CqService(eqService).init() - - final val conArray: ArrayBuffer[Connection] = ArrayBuffer() - final val clientMap = new ConcurrentHashMap[InetSocketAddress, Client]() - final val conMap = new ConcurrentHashMap[Connection, Client]() +class ClientFactory(pmofConf: PmofConf) { + final val eqService = new EqService(pmofConf.clientWorkerNums, pmofConf.clientBufferNums, false).init() + private[this] final val cqService = new CqService(eqService).init() + private[this] final val clientMap = new ConcurrentHashMap[InetSocketAddress, Client]() + private[this] final val conMap = new ConcurrentHashMap[Connection, Client]() def init(): Unit = { - eqService.initBufferPool(BUFFER_NUM, SINGLE_BUFFER_SIZE, BUFFER_NUM * 2) + eqService.initBufferPool(pmofConf.clientBufferNums, pmofConf.networkBufferSize, pmofConf.clientBufferNums * 2) val clientRecvHandler = new ClientRecvHandler val clientReadHandler = new ClientReadHandler eqService.setRecvCallback(clientRecvHandler) @@ -62,16 +54,16 @@ class ClientFactory(conf: SparkConf) { class ClientRecvHandler() extends Handler { override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = { - val buffer: HpnlBuffer = con.getRecvBuffer(rdmaBufferId) - val rpcMessage: ByteBuffer = buffer.get(blockBufferSize) - val seq = buffer.getSeq - val msgType = buffer.getType + val hpnlBuffer: HpnlBuffer = con.getRecvBuffer(rdmaBufferId) + val byteBuffer: ByteBuffer = hpnlBuffer.get(blockBufferSize) + val seq = hpnlBuffer.getSeq + val msgType = hpnlBuffer.getType val callback = conMap.get(con).outstandingReceiveFetches.get(seq) - if (msgType == 0.toByte) { + if (msgType == 0.toByte) { // get ACK from driver, which means the block info has been saved to driver memory callback.onSuccess(null) - } else { + } else { // get block info from driver, and deserialize the info to scala object val metadataResolver = conMap.get(con).shuffleManager.metadataResolver - val blockInfoArray = metadataResolver.deserializeShuffleBlockInfo(rpcMessage) + val blockInfoArray = metadataResolver.deserializeShuffleBlockInfo(byteBuffer) callback.onSuccess(blockInfoArray) } } diff --git a/core/src/main/scala/org/apache/spark/network/pmof/PmofTransferService.scala b/core/src/main/scala/org/apache/spark/network/pmof/PmofTransferService.scala index c158198e..268c17e2 100644 --- a/core/src/main/scala/org/apache/spark/network/pmof/PmofTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/pmof/PmofTransferService.scala @@ -8,16 +8,16 @@ import org.apache.spark.network.BlockDataManager import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.shuffle.pmof.{MetadataResolver, PmofShuffleManager} import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId} -import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.util.configuration.pmof.PmofConf import scala.collection.mutable -class PmofTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManager, +class PmofTransferService(val pmofConf: PmofConf, val shuffleManager: PmofShuffleManager, val hostname: String, var port: Int) extends TransferService { + private[this] final val metadataResolver: MetadataResolver = this.shuffleManager.metadataResolver final var server: Server = _ - final private var clientFactory: ClientFactory = _ - private var nextReqId: AtomicLong = _ - final val 
metadataResolver: MetadataResolver = this.shuffleManager.metadataResolver + private[this] final var clientFactory: ClientFactory = _ + private[this] var nextReqId: AtomicLong = _ override def fetchBlocks(host: String, port: Int, @@ -33,12 +33,12 @@ class PmofTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage } def fetchBlockInfo(blockIds: Array[BlockId], receivedCallback: ReceivedCallback): Unit = { - val shuffleBlockIds = blockIds.map(blockId=>blockId.asInstanceOf[ShuffleBlockId]) + val shuffleBlockIds = blockIds.map(blockId => blockId.asInstanceOf[ShuffleBlockId]) metadataResolver.fetchBlockInfo(shuffleBlockIds, receivedCallback) } - def syncBlocksInfo(host: String, port: Int, byteBuffer: ByteBuffer, msgType: Byte, - callback: ReceivedCallback): Unit = { + def pushBlockInfo(host: String, port: Int, byteBuffer: ByteBuffer, msgType: Byte, + callback: ReceivedCallback): Unit = { clientFactory.createClient(shuffleManager, host, port). send(byteBuffer, nextReqId.getAndIncrement(), msgType, callback, isDeferred = false) } @@ -59,8 +59,8 @@ class PmofTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage } def init(): Unit = { - this.server = new Server(conf, shuffleManager, hostname, port) - this.clientFactory = new ClientFactory(conf) + this.server = new Server(pmofConf, shuffleManager, hostname, port) + this.clientFactory = new ClientFactory(pmofConf) this.server.init() this.server.start() this.clientFactory.init() @@ -73,30 +73,24 @@ class PmofTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage } object PmofTransferService { - final val env: SparkEnv = SparkEnv.get - final val conf: SparkConf = env.conf - final val CHUNKSIZE: Int = conf.getInt("spark.shuffle.pmof.chunk_size", 4096*3) - final val driverHost: String = conf.get("spark.driver.rhost", defaultValue = "172.168.0.43") - final val driverPort: Int = conf.getInt("spark.driver.rport", defaultValue = 61000) - final val shuffleNodes: Array[Array[String]] = - conf.get("spark.shuffle.pmof.node", defaultValue = "").split(",").map(_.split("-")) final val shuffleNodesMap: mutable.Map[String, String] = new mutable.HashMap[String, String]() - for (array <- shuffleNodes) { - shuffleNodesMap.put(array(0), array(1)) - } - private val initialized = new AtomicBoolean(false) - private var transferService: PmofTransferService = _ - def getTransferServiceInstance(blockManager: BlockManager, shuffleManager: PmofShuffleManager = null, + private[this] final val initialized = new AtomicBoolean(false) + private[this] var transferService: PmofTransferService = _ + + def getTransferServiceInstance(pmofConf: PmofConf, blockManager: BlockManager, shuffleManager: PmofShuffleManager = null, isDriver: Boolean = false): PmofTransferService = { if (!initialized.get()) { PmofTransferService.this.synchronized { if (initialized.get()) return transferService if (isDriver) { transferService = - new PmofTransferService(conf, shuffleManager, driverHost, driverPort) + new PmofTransferService(pmofConf, shuffleManager, pmofConf.driverHost, pmofConf.driverPort) } else { + for (array <- pmofConf.shuffleNodes) { + shuffleNodesMap.put(array(0), array(1)) + } transferService = - new PmofTransferService(conf, shuffleManager, shuffleNodesMap(blockManager.shuffleServerId.host), 0) + new PmofTransferService(pmofConf, shuffleManager, shuffleNodesMap(blockManager.shuffleServerId.host), 0) } transferService.init() initialized.set(true) diff --git a/core/src/main/scala/org/apache/spark/network/pmof/Server.scala 
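[editor's note] getTransferServiceInstance above gates its one-time setup on an AtomicBoolean with a second check taken under the lock. A reduced, hypothetical version of that double-checked pattern (the "service" is just a String stand-in; the @volatile is an addition here for safe publication and is not in the patch):

    // Reduced sketch of the initialization pattern used by getTransferServiceInstance.
    import java.util.concurrent.atomic.AtomicBoolean

    object LazyServiceSketch {
      private val initialized = new AtomicBoolean(false)
      @volatile private var instance: String = _

      def getInstance(create: () => String): String = {
        if (!initialized.get()) {            // fast path: already built, skip the lock
          this.synchronized {
            if (!initialized.get()) {        // re-check under the lock before building
              instance = create()
              initialized.set(true)
            }
          }
        }
        instance
      }
    }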
b/core/src/main/scala/org/apache/spark/network/pmof/Server.scala index 056f6593..5aa6933f 100644 --- a/core/src/main/scala/org/apache/spark/network/pmof/Server.scala +++ b/core/src/main/scala/org/apache/spark/network/pmof/Server.scala @@ -4,25 +4,22 @@ import java.nio.ByteBuffer import java.util import com.intel.hpnl.core._ -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.shuffle.pmof.PmofShuffleManager +import org.apache.spark.util.configuration.pmof.PmofConf -class Server(conf: SparkConf, val shuffleManager: PmofShuffleManager, address: String, var port: Int) { +class Server(pmofConf: PmofConf, val shuffleManager: PmofShuffleManager, address: String, var port: Int) { if (port == 0) { port = Utils.getPort } - final val SINGLE_BUFFER_SIZE: Int = PmofTransferService.CHUNKSIZE - final val BUFFER_NUM: Int = conf.getInt("spark.shuffle.pmof.server_buffer_nums", 256) - final val workers = conf.getInt("spark.shuffle.pmof.server_pool_size", 1) - final val eqService = new EqService(workers, BUFFER_NUM, true).init() - final val cqService = new CqService(eqService).init() + private[this] final val eqService = new EqService(pmofConf.serverWorkerNums, pmofConf.serverBufferNums, true).init() + private[this] final val cqService = new CqService(eqService).init() - val conList = new util.ArrayList[Connection]() + private[this] final val conList = new util.ArrayList[Connection]() def init(): Unit = { - eqService.initBufferPool(BUFFER_NUM, SINGLE_BUFFER_SIZE, BUFFER_NUM * 2) + eqService.initBufferPool(pmofConf.serverBufferNums, pmofConf.networkBufferSize, pmofConf.serverBufferNums * 2) val recvHandler = new ServerRecvHandler(this) val connectedHandler = new ServerConnectedHandler(this) eqService.setConnectedCallback(connectedHandler) @@ -62,17 +59,17 @@ class ServerRecvHandler(server: Server) extends Handler with Logging { } override def handle(con: Connection, bufferId: Int, blockBufferSize: Int): Unit = synchronized { - val buffer: HpnlBuffer = con.getRecvBuffer(bufferId) - val message: ByteBuffer = buffer.get(blockBufferSize) - val seq = buffer.getSeq - val msgType = buffer.getType + val hpnlBuffer: HpnlBuffer = con.getRecvBuffer(bufferId) + val byteBuffer: ByteBuffer = hpnlBuffer.get(blockBufferSize) + val seq = hpnlBuffer.getSeq + val msgType = hpnlBuffer.getType val metadataResolver = server.shuffleManager.metadataResolver - if (msgType == 0.toByte) { - metadataResolver.addShuffleBlockInfo(message) + if (msgType == 0.toByte) { // get block info message from executor, then save the info to memory + metadataResolver.saveShuffleBlockInfo(byteBuffer) sendMetadata(con, byteBufferTmp, 0.toByte, seq, isDeferred = false) - } else { - val bufferArray = metadataResolver.serializeShuffleBlockInfo(message) - for (buffer <- bufferArray) { + } else { // lookup block info from memory, then send the info to executor + val blockInfoArray = metadataResolver.serializeShuffleBlockInfo(byteBuffer) + for (buffer <- blockInfoArray) { sendMetadata(con, buffer, 1.toByte, seq, isDeferred = false) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleReader.scala similarity index 81% rename from core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleReader.scala rename to core/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleReader.scala index f9291a49..33304f41 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleReader.scala +++ 
b/core/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleReader.scala @@ -18,27 +18,28 @@ package org.apache.spark.shuffle import org.apache.spark._ -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.{Logging, config} import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.pmof.PmemExternalSorter +import org.apache.spark.util.configuration.pmof.PmofConf /** - * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by - * requesting them from other nodes' block stores. - */ -private[spark] class PmemShuffleReader[K, C]( - handle: BaseShuffleHandle[K, _, C], - startPartition: Int, - endPartition: Int, - context: TaskContext, - serializerManager: SerializerManager = SparkEnv.get.serializerManager, - blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by + * requesting them from other nodes' block stores. + */ +private[spark] class BaseShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext, + pmofConf: PmofConf, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] with Logging { - private val dep = handle.dependency + private[this] val dep = handle.dependency /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { @@ -97,10 +98,11 @@ private[spark] class PmemShuffleReader[K, C]( // Sort the output if there is a sort ordering defined. val resultIter = dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => + assert(pmofConf.enablePmem == true) // Create an ExternalSorter to sort the data. val sorter = - new PmemExternalSorter[K, C, C](context, handle, ordering = Some(keyOrd), serializer = dep.serializer) - logDebug("call PmemExternalSorter.insertAll for shuffle_0_" + handle.shuffleId + "_[" + startPartition + "," + endPartition + "]" ) + new PmemExternalSorter[K, C, C](context, handle, pmofConf, ordering = Some(keyOrd), serializer = dep.serializer) + logDebug("call PmemExternalSorter.insertAll for shuffle_0_" + handle.shuffleId + "_[" + startPartition + "," + endPartition + "]") sorter.insertAll(aggregatedIter) // Use completion callback to stop sorter if task was finished/cancelled. 
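[editor's note] The keyOrdering match in this reader is the core decision: only when the dependency declares an ordering does the aggregated iterator get handed to PmemExternalSorter. A stripped-down sketch of that decision, with an in-memory sort standing in for the external sorter:

    // Sketch only: plain Scala collections stand in for PmemExternalSorter.insertAll/iterator.
    object MaybeSortSketch {
      def maybeSort[K, C](records: Iterator[(K, C)],
                          keyOrdering: Option[Ordering[K]]): Iterator[(K, C)] =
        keyOrdering match {
          case Some(ord) =>
            // the real code spills to persistent memory instead of sorting in memory
            records.toSeq.sortBy(_._1)(ord).iterator
          case None =>
            records // no ordering requested: pass the aggregated records through
        }
    }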
context.addTaskCompletionListener(_ => { diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleWriter.scala index b264f442..b45a1090 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleWriter.scala @@ -25,31 +25,27 @@ import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, S import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.configuration.pmof.PmofConf -private[spark] class BaseShuffleWriter[K, V, C]( - shuffleBlockResolver: IndexShuffleBlockResolver, - metadataResolver: MetadataResolver, - handle: BaseShuffleHandle[K, V, C], - mapId: Int, - context: TaskContext, - enable_rdma: Boolean) +private[spark] class BaseShuffleWriter[K, V, C](shuffleBlockResolver: IndexShuffleBlockResolver, + metadataResolver: MetadataResolver, + handle: BaseShuffleHandle[K, V, C], + mapId: Int, + context: TaskContext, + pmofConf: PmofConf) extends ShuffleWriter[K, V] with Logging { private val dep = handle.dependency private val blockManager = SparkEnv.get.blockManager - + private val writeMetrics = context.taskMetrics().shuffleWriteMetrics private var sorter: ExternalSorter[K, V, _] = _ - // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure // we don't try deleting files, etc twice. private var stopping = false - private var mapStatus: MapStatus = _ - private val writeMetrics = context.taskMetrics().shuffleWriteMetrics - /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { @@ -76,11 +72,11 @@ private[spark] class BaseShuffleWriter[K, V, C]( shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) val shuffleServerId = blockManager.shuffleServerId - if (enable_rdma) { - metadataResolver.commitBlockInfo(dep.shuffleId, mapId, partitionLengths) + if (pmofConf.enableRdma) { + metadataResolver.pushFileBlockInfo(dep.shuffleId, mapId, partitionLengths) val blockManagerId: BlockManagerId = - BlockManagerId(shuffleServerId.executorId, PmofTransferService.shuffleNodesMap(shuffleServerId.host), - PmofTransferService.getTransferServiceInstance(blockManager).port, shuffleServerId.topologyInfo) + BlockManagerId(shuffleServerId.executorId, PmofTransferService.shuffleNodesMap(shuffleServerId.host), + PmofTransferService.getTransferServiceInstance(pmofConf, blockManager).port, shuffleServerId.topologyInfo) mapStatus = MapStatus(blockManagerId, partitionLengths) } else { mapStatus = MapStatus(shuffleServerId, partitionLengths) @@ -115,16 +111,3 @@ private[spark] class BaseShuffleWriter[K, V, C]( } } } - -private[spark] object BaseShuffleWriter { - def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { - // We cannot bypass sorting if we need to do map-side aggregation. 
- if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") - false - } else { - val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - dep.partitioner.numPartitions <= bypassMergeThreshold - } - } -} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala index ab371b9f..e40e575b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala @@ -8,51 +8,42 @@ import java.util.zip.{Deflater, DeflaterOutputStream, Inflater, InflaterInputStr import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{ByteBufferInputStream, ByteBufferOutputStream, Input, Output} -import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.SparkEnv import org.apache.spark.network.pmof._ import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId} +import org.apache.spark.util.configuration.pmof.PmofConf import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.control.{Breaks, ControlThrowable} -class MetadataResolver(conf: SparkConf) { - private lazy val blockManager = SparkEnv.get.blockManager - private lazy val blockMap: ConcurrentHashMap[String, ShuffleBuffer] = new ConcurrentHashMap[String, ShuffleBuffer]() - val driverHost: String = conf.get("spark.driver.rhost", defaultValue = "172.168.0.43") - val driverPort: Int = conf.getInt("spark.driver.rport", defaultValue = 61000) +/** + * This class is to handle shuffle block metadata + * It can be used by driver to store or lookup metadata + * and can be used by executor to send metadata to driver + * @param pmofConf + */ +class MetadataResolver(pmofConf: PmofConf) { + private[this] val blockManager = SparkEnv.get.blockManager + private[this] val blockMap: ConcurrentHashMap[String, ShuffleBuffer] = new ConcurrentHashMap[String, ShuffleBuffer]() + private[this] val info_serialize_stream = new Kryo() + private[this] val shuffleBlockInfoSerializer = new ShuffleBlockInfoSerializer + private[this] val shuffleBlockMap = new ConcurrentHashMap[String, ArrayBuffer[ShuffleBlockInfo]]() - var map_serializer_buffer_size = 0L - if (conf == null) { - map_serializer_buffer_size = 16 * 1024L - } - else { - map_serializer_buffer_size = conf.getLong("spark.shuffle.pmof.map_serializer_buffer_size", 16 * 1024) - } - - var reduce_serializer_buffer_size = 0L - if (conf == null) { - reduce_serializer_buffer_size = 16 * 1024L - } - else { - reduce_serializer_buffer_size = conf.getLong("spark.shuffle.pmof.reduce_serializer_buffer_size", 16 * 1024) - } - - val metadataCompress: Boolean = conf.getBoolean("spark.shuffle.pmof.metadata_compress", defaultValue = false) - - val shuffleBlockSize: Int = conf.getInt("spark.shuffle.pmof.shuffle_block_size", defaultValue = 2048) - - val info_serialize_stream = new Kryo() - val shuffleBlockInfoSerializer = new ShuffleBlockInfoSerializer info_serialize_stream.register(classOf[ShuffleBlockInfo], shuffleBlockInfoSerializer) - val shuffleBlockMap = new ConcurrentHashMap[String, ArrayBuffer[ShuffleBlockInfo]]() - - def commitPmemBlockInfo(shuffleId: Int, mapId: Int, dataAddressMap: 
mutable.HashMap[Int, Array[(Long, Int)]], rkey: Long): Unit = { - val buffer: Array[Byte] = new Array[Byte](reduce_serializer_buffer_size.toInt) + /** + * called by executor, send shuffle block metadata to driver when using persistent memory as shuffle device + * @param shuffleId + * @param mapId + * @param dataAddressMap + * @param rkey + */ + def pushPmemBlockInfo(shuffleId: Int, mapId: Int, dataAddressMap: mutable.HashMap[Int, Array[(Long, Int)]], rkey: Long): Unit = { + val buffer: Array[Byte] = new Array[Byte](pmofConf.reduce_serializer_buffer_size.toInt) var output = new Output(buffer) val bufferArray = new ArrayBuffer[ByteBuffer]() @@ -62,13 +53,13 @@ class MetadataResolver(conf: SparkConf) { val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, iterator._1).name info_serialize_stream.writeObject(output, new ShuffleBlockInfo(shuffleBlockId, address, length.toInt, rkey)) output.flush() - if (output.position() >= map_serializer_buffer_size * 0.9) { + if (output.position() >= pmofConf.map_serializer_buffer_size * 0.9) { val blockBuffer = ByteBuffer.wrap(output.getBuffer) blockBuffer.position(output.position()) blockBuffer.flip() bufferArray += blockBuffer output.close() - val new_buffer = new Array[Byte](reduce_serializer_buffer_size.toInt) + val new_buffer = new Array[Byte](pmofConf.reduce_serializer_buffer_size.toInt) output = new Output(new_buffer) } } @@ -93,12 +84,18 @@ class MetadataResolver(conf: SparkConf) { } for (buffer <- bufferArray) { PmofTransferService.getTransferServiceInstance(null, null). - syncBlocksInfo(driverHost, driverPort, buffer, 0.toByte, receivedCallback) + pushBlockInfo(pmofConf.driverHost, pmofConf.driverPort, buffer, 0.toByte, receivedCallback) } latch.await() } - def commitBlockInfo(shuffleId: Int, mapId: Int, partitionLengths: Array[Long]): Unit = { + /** + * called by executor, send shuffle block metadata to driver when not using persistent memory as shuffle device + * @param shuffleId + * @param mapId + * @param partitionLengths + */ + def pushFileBlockInfo(shuffleId: Int, mapId: Int, partitionLengths: Array[Long]): Unit = { var offset: Long = 0L val file = blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) val channel: FileChannel = new RandomAccessFile(file, "rw").getChannel @@ -109,7 +106,7 @@ class MetadataResolver(conf: SparkConf) { totalLength = totalLength + currentLength } - val eqService = PmofTransferService.getTransferServiceInstance(blockManager).server.getEqService + val eqService = PmofTransferService.getTransferServiceInstance(pmofConf, blockManager).server.getEqService val shuffleBuffer = new ShuffleBuffer(0, totalLength, channel, eqService) val startedAddress = shuffleBuffer.getAddress val rdmaBuffer = eqService.regRmaBufferByAddress(shuffleBuffer.nioByteBuffer(), startedAddress, totalLength.toInt) @@ -118,10 +115,10 @@ class MetadataResolver(conf: SparkConf) { val blockId = ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) blockMap.put(blockId.name, shuffleBuffer) - val byteBuffer = ByteBuffer.allocate(map_serializer_buffer_size.toInt) + val byteBuffer = ByteBuffer.allocate(pmofConf.map_serializer_buffer_size.toInt) val bos = new ByteBufferOutputStream(byteBuffer) var output: Output = null - if (metadataCompress) { + if (pmofConf.metadataCompress) { val dos = new DeflaterOutputStream(bos, new Deflater(9, true)) output = new Output(dos) } else { @@ -131,14 +128,14 @@ class MetadataResolver(conf: SparkConf) { for (executorId <- partitionLengths.indices) { val 
shuffleBlockId = ShuffleBlockId(shuffleId, mapId, executorId) val currentLength: Int = partitionLengths(executorId).toInt - val blockNums = currentLength / shuffleBlockSize + (if (currentLength % shuffleBlockSize == 0) 0 else 1) + val blockNums = currentLength / pmofConf.shuffleBlockSize + (if (currentLength % pmofConf.shuffleBlockSize == 0) 0 else 1) for (i <- 0 until blockNums) { if (i != blockNums - 1) { - info_serialize_stream.writeObject(output, new ShuffleBlockInfo(shuffleBlockId.name, startedAddress + offset, shuffleBlockSize, rdmaBuffer.getRKey)) - offset += shuffleBlockSize + info_serialize_stream.writeObject(output, new ShuffleBlockInfo(shuffleBlockId.name, startedAddress + offset, pmofConf.shuffleBlockSize, rdmaBuffer.getRKey)) + offset += pmofConf.shuffleBlockSize } else { - info_serialize_stream.writeObject(output, new ShuffleBlockInfo(shuffleBlockId.name, startedAddress + offset, currentLength - (i * shuffleBlockSize), rdmaBuffer.getRKey)) - offset += (currentLength - (i * shuffleBlockSize)) + info_serialize_stream.writeObject(output, new ShuffleBlockInfo(shuffleBlockId.name, startedAddress + offset, currentLength - (i * pmofConf.shuffleBlockSize), rdmaBuffer.getRKey)) + offset += (currentLength - (i * pmofConf.shuffleBlockSize)) } } } @@ -161,19 +158,18 @@ class MetadataResolver(conf: SparkConf) { } PmofTransferService.getTransferServiceInstance(null, null). - syncBlocksInfo(driverHost, driverPort, byteBuffer, 0.toByte, receivedCallback) + pushBlockInfo(pmofConf.driverHost, pmofConf.driverPort, byteBuffer, 0.toByte, receivedCallback) latch.await() } - def closeBlocks(): Unit = { - for ((_, v) <- blockMap.asScala) { - v.close() - } - } - + /** + * called by executor, fetch shuffle block metadata from driver + * @param blockIds + * @param receivedCallback + */ def fetchBlockInfo(blockIds: Array[ShuffleBlockId], receivedCallback: ReceivedCallback): Unit = { val nums = blockIds.length - val byteBufferTmp = ByteBuffer.allocate(4+12*nums) + val byteBufferTmp = ByteBuffer.allocate(4 + 12 * nums) byteBufferTmp.putInt(nums) for (i <- 0 until nums) { byteBufferTmp.putInt(blockIds(i).shuffleId) @@ -182,13 +178,17 @@ class MetadataResolver(conf: SparkConf) { } byteBufferTmp.flip() PmofTransferService.getTransferServiceInstance(null, null). 
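[editor's note] The blockNums arithmetic here is how pushFileBlockInfo cuts one partition of currentLength bytes into fixed-size RDMA-addressable pieces. A self-contained sketch of just that arithmetic (addresses and sizes are made up; the real code also attaches the block name and rkey):

    // Sketch only: the offset/length arithmetic, detached from ShuffleBlockInfo and RDMA.
    object BlockSplitSketch {
      def splitPartition(baseAddress: Long, currentLength: Int, blockSize: Int): Seq[(Long, Int)] = {
        val blockNums = currentLength / blockSize + (if (currentLength % blockSize == 0) 0 else 1)
        (0 until blockNums).map { i =>
          val len = if (i != blockNums - 1) blockSize else currentLength - i * blockSize
          (baseAddress + i.toLong * blockSize, len)   // (remote address, block length)
        }
      }
      // e.g. splitPartition(0L, 5000, 2048) == Seq((0, 2048), (2048, 2048), (4096, 904))
    }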
- syncBlocksInfo(driverHost, driverPort, byteBufferTmp, 1.toByte, receivedCallback) + pushBlockInfo(pmofConf.driverHost, pmofConf.driverPort, byteBufferTmp, 1.toByte, receivedCallback) } - def addShuffleBlockInfo(byteBuffer: ByteBuffer): Unit = { + /** + * called by driver, save shuffle block metadata to memory + * @param byteBuffer + */ + def saveShuffleBlockInfo(byteBuffer: ByteBuffer): Unit = { val bis = new ByteBufferInputStream(byteBuffer) var input: Input = null - if (metadataCompress) { + if (pmofConf.metadataCompress) { val iis = new InflaterInputStream(bis, new Inflater(true)) input = new Input(iis) } else { @@ -214,8 +214,13 @@ class MetadataResolver(conf: SparkConf) { } } + /** + * called by driver, serialize shuffle block metadata object to bytebuffer, then send to executor through RDMA network + * @param byteBuffer + * @return + */ def serializeShuffleBlockInfo(byteBuffer: ByteBuffer): ArrayBuffer[ByteBuffer] = { - val buffer: Array[Byte] = new Array[Byte](reduce_serializer_buffer_size.toInt) + val buffer: Array[Byte] = new Array[Byte](pmofConf.reduce_serializer_buffer_size.toInt) var output = new Output(buffer) val bufferArray = new ArrayBuffer[ByteBuffer]() @@ -241,14 +246,14 @@ class MetadataResolver(conf: SparkConf) { loop.breakable { for (i <- blockInfoArray.indices) { info_serialize_stream.writeObject(output, blockInfoArray(i)) - if (output.position() >= reduce_serializer_buffer_size * 0.9) { + if (output.position() >= pmofConf.reduce_serializer_buffer_size * 0.9) { output.setPosition(startPos) val blockBuffer = ByteBuffer.wrap(output.getBuffer) blockBuffer.position(output.position()) blockBuffer.flip() bufferArray += blockBuffer output.close() - val new_buffer = new Array[Byte](reduce_serializer_buffer_size.toInt) + val new_buffer = new Array[Byte](pmofConf.reduce_serializer_buffer_size.toInt) output = new Output(new_buffer) cur -= 1 loop.break() @@ -273,6 +278,11 @@ class MetadataResolver(conf: SparkConf) { bufferArray } + /** + * called by executor, deserialize bytebuffer to shuffle block metadata object + * @param byteBuffer + * @return + */ def deserializeShuffleBlockInfo(byteBuffer: ByteBuffer): ArrayBuffer[ShuffleBlockInfo] = { val blockInfoArray: ArrayBuffer[ShuffleBlockInfo] = ArrayBuffer[ShuffleBlockInfo]() val bais = new ByteBufferInputStream(byteBuffer) @@ -290,4 +300,10 @@ class MetadataResolver(conf: SparkConf) { } null } + + def closeBlocks(): Unit = { + for ((_, v) <- blockMap.asScala) { + v.close() + } + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleBlockResolver.scala index f7fdd103..98845ff9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleBlockResolver.scala @@ -4,7 +4,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage.{BlockManager, ShuffleBlockId} -import org.apache.spark.storage.pmof.{PmemBlockObjectStream, PersistentMemoryHandler} +import org.apache.spark.storage.pmof.{PmemBlockOutputStream, PersistentMemoryHandler} import org.apache.spark.network.buffer.ManagedBuffer private[spark] class PmemShuffleBlockResolver( @@ -12,7 +12,7 @@ private[spark] class PmemShuffleBlockResolver( _blockManager: BlockManager = null) extends IndexShuffleBlockResolver(conf, _blockManager) with Logging { // create ShuffleHandler here, so 
multiple executors can share - var partitionBufferArray: Array[PmemBlockObjectStream] = _ + var partitionBufferArray: Array[PmemBlockOutputStream] = _ override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { // return BlockId corresponding ManagedBuffer diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala index ea1d8b5d..0de2f336 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala @@ -24,38 +24,30 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} import org.apache.spark.storage._ import org.apache.spark.util.collection.pmof.PmemExternalSorter -import org.apache.spark.serializer.SerializerInstance import org.apache.spark.storage.pmof._ +import org.apache.spark.util.configuration.pmof.PmofConf import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -private[spark] class PmemShuffleWriter[K, V, C]( - shuffleBlockResolver: PmemShuffleBlockResolver, - metadataResolver: MetadataResolver, - handle: BaseShuffleHandle[K, V, C], - mapId: Int, - context: TaskContext, - conf: SparkConf - ) +private[spark] class PmemShuffleWriter[K, V, C](shuffleBlockResolver: PmemShuffleBlockResolver, + metadataResolver: MetadataResolver, + handle: BaseShuffleHandle[K, V, C], + mapId: Int, + context: TaskContext, + conf: SparkConf, + pmofConf: PmofConf) extends ShuffleWriter[K, V] with Logging { - private val dep = handle.dependency - private val blockManager = SparkEnv.get.blockManager - private var mapStatus: MapStatus = _ - private val stageId = dep.shuffleId - private val partitioner = dep.partitioner - private val numPartitions = partitioner.numPartitions - private val serInstance: SerializerInstance = dep.serializer.newInstance() - private val numMaps = handle.numMaps - private val writeMetrics = context.taskMetrics().shuffleWriteMetrics - logDebug("This stage has "+ numMaps + " maps") - - val enable_rdma: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_rdma", defaultValue = true) - val enable_pmem: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true) - - val partitionLengths: Array[Long] = Array.fill[Long](numPartitions)(0) - var set_clean: Boolean = true - private var sorter: PmemExternalSorter[K, V, _] = _ + private[this] val dep = handle.dependency + private[this] val blockManager = SparkEnv.get.blockManager + private[this] var mapStatus: MapStatus = _ + private[this] val stageId = dep.shuffleId + private[this] val partitioner = dep.partitioner + private[this] val numPartitions = partitioner.numPartitions + private[this] val numMaps = handle.numMaps + private[this] val writeMetrics = context.taskMetrics().shuffleWriteMetrics + private[this] val partitionLengths: Array[Long] = Array.fill[Long](numPartitions)(0) + private[this] var sorter: PmemExternalSorter[K, V, _] = _ /** * Are we in the process of stopping? Because map tasks can call stop() with success = true @@ -64,77 +56,77 @@ private[spark] class PmemShuffleWriter[K, V, C]( */ private var stopping = false - - /** + /** * Call PMDK to write data to persistent memory * Original Spark writer will do write and mergesort in this function, * while by using pmdk, we can do that once since pmdk supports transaction. 
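[editor's note] The write() body that follows creates one PmemBlockOutputStream per reduce partition, routes each record via partitioner.getPartition, and force-spills every stream at the end. A hypothetical in-memory reduction of that flow (SimpleSink is a stand-in, not the real stream; the spill threshold here is invented):

    // Sketch only: SimpleSink is a hypothetical stand-in for PmemBlockOutputStream.
    object PartitionedWriteSketch {
      class SimpleSink(partitionId: Int) {
        private val buffered = scala.collection.mutable.ArrayBuffer[(Any, Any)]()
        def write(k: Any, v: Any): Unit = buffered += ((k, v))
        def maybeSpill(force: Boolean = false): Unit =
          if (force || buffered.size >= 1024) {   // made-up record threshold; the real one is byte-based
            println(s"partition $partitionId spills ${buffered.size} records")
            buffered.clear()
          }
      }

      def writePartitioned(records: Iterator[(Any, Any)],
                           numPartitions: Int,
                           getPartition: Any => Int): Unit = {
        val sinks = Array.tabulate(numPartitions)(new SimpleSink(_))
        records.foreach { case (k, v) => sinks(getPartition(k)).write(k, v) }
        sinks.foreach(_.maybeSpill(force = true)) // flush whatever is left in each partition
      }
    }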
*/ override def write(records: Iterator[Product2[K, V]]): Unit = { - // TODO: keep checking if data need to spill to disk when PM capacity is not enough. - // TODO: currently, we apply processed records to PM. - - val partitionBufferArray = (0 until numPartitions).toArray.map( partitionId => - new PmemBlockObjectStream( - blockManager.serializerManager, - serInstance, + val PmemBlockOutputStreamArray = (0 until numPartitions).toArray.map(partitionId => + new PmemBlockOutputStream( context.taskMetrics(), ShuffleBlockId(stageId, mapId, partitionId), + SparkEnv.get.serializerManager, + dep.serializer, conf, + pmofConf, numMaps, numPartitions)) - if (dep.mapSideCombine) { // do aggragation + if (dep.mapSideCombine) { // do aggregation if (dep.aggregator.isDefined) { - sorter = new PmemExternalSorter[K, V, C](context, handle, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.setPartitionByteBufferArray(partitionBufferArray) + sorter = new PmemExternalSorter[K, V, C](context, handle, pmofConf, dep.aggregator, Some(dep.partitioner), + dep.keyOrdering, dep.serializer) + sorter.setPartitionByteBufferArray(PmemBlockOutputStreamArray) sorter.insertAll(records) sorter.forceSpillToPmem() } else { throw new IllegalStateException("Aggregator is empty for map-side combine") } } else { // no aggregation - while (records.hasNext) { + while (records.hasNext) { // since we need to write same partition (key, value) together, do a partition index here val elem = records.next() val partitionId: Int = partitioner.getPartition(elem._1) - partitionBufferArray(partitionId).write(elem._1, elem._2) + PmemBlockOutputStreamArray(partitionId).write(elem._1, elem._2) } for (partitionId <- 0 until numPartitions) { - partitionBufferArray(partitionId).maybeSpill(force = true) + PmemBlockOutputStreamArray(partitionId).maybeSpill(force = true) } } var spilledPartition = 0 - val partitionSpilled: ArrayBuffer[Int] = ArrayBuffer[Int]() + val spillPartitionArray: ArrayBuffer[Int] = ArrayBuffer[Int]() while (spilledPartition < numPartitions) { - if (partitionBufferArray(spilledPartition).ifSpilled()) { - partitionSpilled.append(spilledPartition) + if (PmemBlockOutputStreamArray(spilledPartition).ifSpilled()) { + spillPartitionArray.append(spilledPartition) } spilledPartition += 1 } - val data_addr_map = mutable.HashMap.empty[Int, Array[(Long, Int)]] + val pmemBlockInfoMap = mutable.HashMap.empty[Int, Array[(Long, Int)]] var output_str : String = "" - for (i <- partitionSpilled) { - if (enable_rdma) - data_addr_map(i) = partitionBufferArray(i).getPartitionMeta().map{ info => (info._1, info._2)} - partitionLengths(i) = partitionBufferArray(i).size - output_str += "\tPartition " + i + ": " + partitionLengths(i) + ", records: " + partitionBufferArray(i).records + "\n" + for (i <- spillPartitionArray) { + if (pmofConf.enableRdma) { + pmemBlockInfoMap(i) = PmemBlockOutputStreamArray(i).getPartitionMeta().map { info => (info._1, info._2) } + } + partitionLengths(i) = PmemBlockOutputStreamArray(i).size + output_str += "\tPartition " + i + ": " + partitionLengths(i) + ", records: " + PmemBlockOutputStreamArray(i).records + "\n" } + for (i <- 0 until numPartitions) { - partitionBufferArray(i).close() + PmemBlockOutputStreamArray(i).close() } logDebug("shuffle_" + dep.shuffleId + "_" + mapId + ": \n" + output_str) val shuffleServerId = blockManager.shuffleServerId - if (enable_rdma) { - val rkey = partitionBufferArray(0).getRkey() - metadataResolver.commitPmemBlockInfo(stageId, mapId, data_addr_map, rkey) + if 
(pmofConf.enableRdma) { + val rkey = PmemBlockOutputStreamArray(0).getRkey() + metadataResolver.pushPmemBlockInfo(stageId, mapId, pmemBlockInfoMap, rkey) val blockManagerId: BlockManagerId = BlockManagerId(shuffleServerId.executorId, PmofTransferService.shuffleNodesMap(shuffleServerId.host), - PmofTransferService.getTransferServiceInstance(blockManager).port, shuffleServerId.topologyInfo) + PmofTransferService.getTransferServiceInstance(pmofConf, blockManager).port, shuffleServerId.topologyInfo) mapStatus = MapStatus(blockManagerId, partitionLengths) } else { mapStatus = MapStatus(shuffleServerId, partitionLengths) diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala index eae7df43..e517234a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala @@ -5,73 +5,61 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.internal.Logging import org.apache.spark.network.pmof.PmofTransferService import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, SerializedShuffleWriter, SortShuffleManager} +import org.apache.spark.util.configuration.pmof.PmofConf import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} private[spark] class PmofShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { logInfo("Initialize RdmaShuffleManager") - if (!conf.getBoolean("spark.shuffle.spill", defaultValue = true)) logWarning("spark.shuffle.spill was set to false") - val enable_rdma: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_rdma", defaultValue = true) - val enable_pmem: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true) - - val metadataResolver: MetadataResolver = new MetadataResolver(conf) + if (!conf.getBoolean("spark.shuffle.spill", defaultValue = true)) { + logWarning("spark.shuffle.spill was set to false") + } private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + private[this] val pmofConf = new PmofConf(conf) + var metadataResolver: MetadataResolver = _ override def registerShuffle[K, V, C](shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { val env: SparkEnv = SparkEnv.get - if (enable_rdma) { - PmofTransferService.getTransferServiceInstance(env.blockManager, this, isDriver = true) - } - if (enable_pmem) { - new BaseShuffleHandle(shuffleId, numMaps, dependency) - } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { - // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: - new SerializedShuffleHandle[K, V]( - shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else { - // Otherwise, buffer map outputs in a deserialized form: - new BaseShuffleHandle(shuffleId, numMaps, dependency) + + metadataResolver = new MetadataResolver(pmofConf) + + if (pmofConf.enableRdma) { + PmofTransferService.getTransferServiceInstance(pmofConf: PmofConf, env.blockManager, this, isDriver = true) } + + new BaseShuffleHandle(shuffleId, numMaps, dependency) } override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] = { + assert(handle.isInstanceOf[BaseShuffleHandle[_, _, _]]) + val env: SparkEnv = SparkEnv.get val numMaps = handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps + + metadataResolver = 
new MetadataResolver(pmofConf) numMapsForShuffle.putIfAbsent(handle.shuffleId, numMaps) - if (enable_rdma) { - PmofTransferService.getTransferServiceInstance(env.blockManager, this) + + if (pmofConf.enableRdma) { + PmofTransferService.getTransferServiceInstance(pmofConf, env.blockManager, this) } - handle match { - case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => - new SerializedShuffleWriter( - env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], - metadataResolver, - context.taskMemoryManager(), - unsafeShuffleHandle, - mapId, - context, - env.conf, - enable_rdma) - case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - if (enable_pmem) { - new PmemShuffleWriter(shuffleBlockResolver.asInstanceOf[PmemShuffleBlockResolver], metadataResolver, - handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context, env.conf) - } else { - new BaseShuffleWriter(shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], metadataResolver, other, mapId, context, enable_rdma) - } + + if (pmofConf.enablePmem) { + new PmemShuffleWriter(shuffleBlockResolver.asInstanceOf[PmemShuffleBlockResolver], metadataResolver, + handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context, env.conf, pmofConf) + } else { + new BaseShuffleWriter(shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], metadataResolver, + handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context, pmofConf) } } override def getReader[K, C](handle: _root_.org.apache.spark.shuffle.ShuffleHandle, startPartition: Int, endPartition: Int, context: _root_.org.apache.spark.TaskContext): _root_.org.apache.spark.shuffle.ShuffleReader[K, C] = { - if (enable_rdma) { + if (pmofConf.enableRdma) { new RdmaShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, endPartition, context) + startPartition, endPartition, context, pmofConf) } else { - new PmemShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + new BaseShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context, pmofConf) } } @@ -89,7 +77,7 @@ private[spark] class PmofShuffleManager(conf: SparkConf) extends ShuffleManager } override val shuffleBlockResolver: ShuffleBlockResolver = { - if (enable_pmem) + if (pmofConf.enablePmem) new PmemShuffleBlockResolver(conf) else new IndexShuffleBlockResolver(conf) diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala index 653b9679..92af5417 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala @@ -10,30 +10,31 @@ import org.apache.spark.storage.pmof._ import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.pmof.PmemExternalSorter +import org.apache.spark.util.configuration.pmof.PmofConf /** * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by * requesting them from other nodes' block stores. 
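[editor's note] After this change the handle is always a BaseShuffleHandle, and the only decisions left in getWriter/getReader are the two PmofConf flags. Reduced to hypothetical marker values, the selection is simply:

    // Sketch only: marker strings stand in for the four writer/reader classes.
    object ShuffleSelectionSketch {
      def chooseWriter(enablePmem: Boolean): String =
        if (enablePmem) "PmemShuffleWriter" else "BaseShuffleWriter"

      def chooseReader(enableRdma: Boolean): String =
        if (enableRdma) "RdmaShuffleReader" else "BaseShuffleReader"
    }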
*/ -private[spark] class RdmaShuffleReader[K, C]( - handle: BaseShuffleHandle[K, _, C], - startPartition: Int, - endPartition: Int, - context: TaskContext, - serializerManager: SerializerManager = SparkEnv.get.serializerManager, - blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) +private[spark] class RdmaShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext, + pmofConf: PmofConf, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] with Logging { - private val dep = handle.dependency - val serializerInstance: SerializerInstance = dep.serializer.newInstance() - val enable_pmem: Boolean = SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true) + private[this] val dep = handle.dependency + private[this] val serializerInstance: SerializerInstance = dep.serializer.newInstance() + private[this] val enable_pmem: Boolean = SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true) /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val wrappedStreams: RdmaShuffleBlockFetcherIterator = new RdmaShuffleBlockFetcherIterator( context, - PmofTransferService.getTransferServiceInstance(blockManager), + PmofTransferService.getTransferServiceInstance(pmofConf, blockManager), blockManager, mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), serializerManager.wrapStream, @@ -86,7 +87,7 @@ private[spark] class RdmaShuffleReader[K, C]( dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => if (enable_pmem) { - val sorter = new PmemExternalSorter[K, C, C](context, handle, ordering = Some(keyOrd), serializer = dep.serializer) + val sorter = new PmemExternalSorter[K, C, C](context, handle, pmofConf, ordering = Some(keyOrd), serializer = dep.serializer) sorter.insertAll(aggregatedIter) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) } else { diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala index c448d8db..5b0d39be 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala @@ -4,7 +4,7 @@ import java.nio.ByteBuffer import org.apache.spark.internal.Logging import org.apache.spark.network.pmof.PmofTransferService -import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.SparkEnv import scala.collection.JavaConverters._ import java.nio.file.{Files, Paths} @@ -12,6 +12,8 @@ import java.util.UUID import java.lang.management.ManagementFactory import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.util.configuration.pmof.PmofConf + import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global @@ -137,21 +139,18 @@ private[spark] class PersistentMemoryHandler( object PersistentMemoryHandler { private var persistentMemoryHandler: PersistentMemoryHandler = _ private var stopped: Boolean = false - def getPersistentMemoryHandler(conf: SparkConf, root_dir: String, path_arg: List[String], shuffleBlockId: String, 
pmPoolSize: Long, maxStages: Int, maxMaps: Int): PersistentMemoryHandler = synchronized { + def getPersistentMemoryHandler(pmofConf: PmofConf, root_dir: String, path_arg: List[String], shuffleBlockId: String, pmPoolSize: Long, maxStages: Int, maxMaps: Int): PersistentMemoryHandler = synchronized { if (persistentMemoryHandler == null) { persistentMemoryHandler = new PersistentMemoryHandler(root_dir, path_arg, shuffleBlockId, maxStages, maxMaps, pmPoolSize) persistentMemoryHandler.log("Use persistentMemoryHandler Object: " + this) - val enable_rdma: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_rdma", defaultValue = true) - if (enable_rdma) { - val pmem_capacity: Long = conf.getLong("spark.shuffle.pmof.pmem_capacity", defaultValue = 264239054848L) + if (pmofConf.enableRdma) { val blockManager = SparkEnv.get.blockManager - val eqService = PmofTransferService.getTransferServiceInstance(blockManager).server.getEqService + val eqService = PmofTransferService.getTransferServiceInstance(pmofConf, blockManager).server.getEqService val offset: Long = persistentMemoryHandler.getRootAddr - val rdmaBuffer = eqService.regRmaBufferByAddress(null, offset, pmem_capacity) + val rdmaBuffer = eqService.regRmaBufferByAddress(null, offset, pmofConf.pmemCapacity) persistentMemoryHandler.rkey = rdmaBuffer.getRKey() } - val dev_core_map = conf.get("spark.shuffle.pmof.dev_core_set").split(";").map(_.trim).map(_.split(":")).map(arr => arr(0) -> arr(1)).toMap - val core_set = dev_core_map.get(persistentMemoryHandler.getDevice()) + val core_set = pmofConf.pmemCoreMap.get(persistentMemoryHandler.getDevice()) core_set match { case Some(s) => Future {nativeTaskset(s)} case None => {} diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala new file mode 100644 index 00000000..75e9fb5d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala @@ -0,0 +1,53 @@ +package org.apache.spark.storage.pmof + +import java.io.InputStream + +import com.esotericsoftware.kryo.KryoException +import org.apache.spark.SparkEnv +import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerInstance, SerializerManager} +import org.apache.spark.storage.BlockId + +class PmemBlockInputStream[K, C](pmemBlockOutputStream: PmemBlockOutputStream, serializer: Serializer) { + val blockId: BlockId = pmemBlockOutputStream.getBlockId() + val serializerManager: SerializerManager = SparkEnv.get.serializerManager + val serInstance: SerializerInstance = serializer.newInstance() + val persistentMemoryWriter: PersistentMemoryHandler = PersistentMemoryHandler.getPersistentMemoryHandler + var pmemInputStream: PmemInputStream = new PmemInputStream(persistentMemoryWriter, blockId.name) + var wrappedStream: InputStream = serializerManager.wrapStream(blockId, pmemInputStream) + var inObjStream: DeserializationStream = serInstance.deserializeStream(wrappedStream) + + var total_records: Long = 0 + var indexInBatch: Int = 0 + var closing: Boolean = false + + def loadStream(): Unit = { + total_records = pmemBlockOutputStream.getTotalRecords() + indexInBatch = 0 + } + + def readNextItem(): (K, C) = { + if (closing == true) { + close() + return null + } + try{ + val k = inObjStream.readObject().asInstanceOf[K] + val c = inObjStream.readObject().asInstanceOf[C] + indexInBatch += 1 + if (indexInBatch == total_records) { + closing = true + } + (k, c) + } catch { + case ex: KryoException => { + } + 
sys.exit(0) + } + } + + def close(): Unit = { + pmemInputStream.close + wrappedStream = null + inObjStream = null + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala similarity index 65% rename from core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala rename to core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala index f7039ce1..4916f725 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockObjectStream.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala @@ -4,11 +4,11 @@ import org.apache.spark.storage._ import org.apache.spark.serializer._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.util.Utils -import java.io._ -import java.io.{InputStream, OutputStream} +import java.io.{File, OutputStream} + +import org.apache.spark.util.configuration.pmof.PmofConf import scala.collection.mutable.ArrayBuffer @@ -26,39 +26,36 @@ object PmemBlockId { } } -private[spark] class PmemBlockObjectStream( - serializerManager: SerializerManager, - serializerInstance: SerializerInstance, +private[spark] class PmemBlockOutputStream( taskMetrics: TaskMetrics, blockId: BlockId, + serializerManager: SerializerManager, + serializer: Serializer, conf: SparkConf, + pmofConf: PmofConf, numMaps: Int = 0, numPartitions: Int = 1 ) extends DiskBlockObjectWriter(new File(Utils.getConfiguredLocalDirs(conf).toList(0) + "/null"), null, null, 0, true, null, null) with Logging { var size: Int = 0 var records: Int = 0 - var recordsPerBlock: Int = 0 val recordsArray: ArrayBuffer[Int] = ArrayBuffer() var spilled: Boolean = false var partitionMeta: Array[(Long, Int, Int)] = _ - val root_dir = Utils.getConfiguredLocalDirs(conf).toList(0) - val path_list = conf.get("spark.shuffle.pmof.pmem_list").split(",").map(_.trim).toList - val maxPoolSize: Long = conf.getLong("spark.shuffle.pmof.pmpool_size", defaultValue = 1073741824) - val maxStages: Int = conf.getInt("spark.shuffle.pmof.max_stage_num", defaultValue = 1000) - val persistentMemoryWriter: PersistentMemoryHandler = PersistentMemoryHandler.getPersistentMemoryHandler(conf, root_dir, path_list, blockId.name, maxPoolSize, maxStages, numMaps) - val spill_throttle = 4194304 + + val persistentMemoryWriter: PersistentMemoryHandler = PersistentMemoryHandler.getPersistentMemoryHandler(pmofConf, + root_dir, pmofConf.path_list, blockId.name, pmofConf.maxPoolSize, pmofConf.maxStages, numMaps) + //disable metadata updating by default //persistentMemoryWriter.updateShuffleMeta(blockId.name) - logDebug(blockId.name) - val bytesStream: OutputStream = new PmemOutputStream( + val pmemOutputStream: PmemOutputStream = new PmemOutputStream( persistentMemoryWriter, numPartitions, blockId.name, numMaps) - val wrappedStream: OutputStream = serializerManager.wrapStream(blockId, bytesStream) - val objStream: SerializationStream = serializerInstance.serializeStream(wrappedStream) - var inputStream: InputStream = _ + val serInstance = serializer.newInstance() + var wrappedStream: OutputStream = serializerManager.wrapStream(blockId, pmemOutputStream) + var objStream: SerializationStream = serInstance.serializeStream(wrappedStream) override def write(key: Any, value: Any): Unit = { objStream.writeKey(key) @@ -72,11 +69,9 @@ private[spark] 
class PmemBlockObjectStream( } override def close() { - bytesStream.close() - logDebug("Serialize stream closed.") - if (inputStream != null) - inputStream.close() - logDebug("PersistentMemoryHandlerPartition: stream closed.") + pmemOutputStream.close() + wrappedStream = null + objStream = null } override def flush() { @@ -84,11 +79,11 @@ private[spark] class PmemBlockObjectStream( } def maybeSpill(force: Boolean = false): Unit = { - if ((spill_throttle != -1 && bytesStream.asInstanceOf[PmemOutputStream].size >= spill_throttle) || force == true) { + if ((pmofConf.spill_throttle != -1 && pmemOutputStream.asInstanceOf[PmemOutputStream].size >= pmofConf.spill_throttle) || force == true) { val start = System.nanoTime() objStream.flush() - bytesStream.flush() - val bufSize = bytesStream.asInstanceOf[PmemOutputStream].size + pmemOutputStream.flush() + val bufSize = pmemOutputStream.size //logInfo(blockId.name + " do spill, size is " + bufSize) if (bufSize > 0) { recordsArray += recordsPerBlock @@ -102,7 +97,7 @@ private[spark] class PmemBlockObjectStream( } else { taskMetrics.incDiskBytesSpilled(bufSize) } - bytesStream.asInstanceOf[PmemOutputStream].reset() + pmemOutputStream.reset() spilled = true } } @@ -128,10 +123,6 @@ private[spark] class PmemBlockObjectStream( persistentMemoryWriter.rkey } - /*def getAllBytes(): Array[Byte] = { - persistentMemoryWriter.getPartition(blockId.name) - }*/ - def getTotalRecords(): Long = { records } @@ -140,11 +131,7 @@ private[spark] class PmemBlockObjectStream( size } - def getInputStream(): InputStream = { - if (inputStream == null) { - inputStream = new PmemInputStream(persistentMemoryWriter, blockId.name) - } - inputStream + def getPersistentMemoryHandler: PersistentMemoryHandler = { + persistentMemoryWriter } - } diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemInputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemInputStream.scala index c1388e54..dee11675 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemInputStream.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemInputStream.scala @@ -1,14 +1,12 @@ package org.apache.spark.storage.pmof import java.io.InputStream -import org.apache.spark.storage.pmof.PmemBuffer import org.apache.spark.internal.Logging import scala.util.control.Breaks._ class PmemInputStream( persistentMemoryHandler: PersistentMemoryHandler, - blockId: String - ) extends InputStream with Logging { + blockId: String) extends InputStream with Logging { val buf = new PmemBuffer() var index: Int = 0 var remaining: Int = 0 @@ -60,8 +58,8 @@ class PmemInputStream( } } } - def getByteBufferDirectAddr(): Long = { - buf.getDirectAddr() + def getByteBufferDirectAddr: Long = { + buf.getDirectAddr } override def available(): Int = { diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemManagedBuffer.scala index 13da3111..391b2aca 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemManagedBuffer.scala @@ -70,6 +70,6 @@ class PmemManagedBuffer(pmHandler: PersistentMemoryHandler, blockId: String) ext val data_length = size().toInt val in = createInputStream() in.asInstanceOf[PmemInputStream].load(data_length) - Unpooled.wrappedBuffer(in.asInstanceOf[PmemInputStream].getByteBufferDirectAddr(), data_length, false) + Unpooled.wrappedBuffer(in.asInstanceOf[PmemInputStream].getByteBufferDirectAddr, 
data_length, false) } } diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala index 1f0a3d72..215bac01 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala @@ -16,7 +16,8 @@ class PmemOutputStream( var is_closed = false val length: Int = 1024*1024*6 - var total: Int = 0 + var flushedSize: Int = 0 + var remainingSize: Int = 0 val buf: ByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(length, length) val byteBuffer: ByteBuffer = buf.nioBuffer(0, length) @@ -24,18 +25,19 @@ class PmemOutputStream( override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { byteBuffer.put(bytes, off, len) - total += len + remainingSize += len } override def write(byte: Int): Unit = { byteBuffer.putInt(byte) - total += 4 + remainingSize += 4 } override def flush(): Unit = { - if (size() > 0) { - persistentMemoryWriter.setPartition(numPartitions, blockId, byteBuffer, size(), set_clean, numMaps) - reset() + if (remainingSize > 0) { + persistentMemoryWriter.setPartition(numPartitions, blockId, byteBuffer, remainingSize, set_clean, numMaps) + flushedSize += remainingSize + remainingSize = 0 } if (set_clean) { set_clean = false @@ -43,11 +45,12 @@ class PmemOutputStream( } def size(): Int = { - total + flushedSize } def reset(): Unit = { - total = 0 + remainingSize = 0 + flushedSize = 0 byteBuffer.clear() } diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmofShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala similarity index 85% rename from core/src/main/scala/org/apache/spark/storage/pmof/PmofShuffleBlockFetcherIterator.scala rename to core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala index 84345025..c563e222 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PmofShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala @@ -24,16 +24,16 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} import javax.annotation.concurrent.GuardedBy import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{ShuffleClient, TempFileManager} import org.apache.spark.network.pmof._ +import org.apache.spark.network.shuffle.{ShuffleClient, TempFileManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage._ import org.apache.spark.{SparkException, TaskContext} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block @@ -45,36 +45,50 @@ import scala.concurrent.ExecutionContext.Implicits.global * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid * using too much memory. * - * @param context [[TaskContext]], used for metrics update - * @param shuffleClient [[ShuffleClient]] for fetching remote blocks - * @param blockManager [[BlockManager]] for reading local blocks - * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. 
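[editor's note] The PmemOutputStream change just above splits the old single byte counter into remainingSize (buffered, not yet written) and flushedSize (already handed to setPartition), and size() now reports only the flushed part. A stand-alone sketch of that accounting, with a plain callback in place of the persistent-memory write:

    // Sketch only: `sink` stands in for persistentMemoryWriter.setPartition(...).
    class CountingOutputSketch(sink: Int => Unit) {
      private var flushedSize = 0     // bytes already handed to the sink
      private var remainingSize = 0   // bytes buffered since the last flush

      def write(len: Int): Unit = remainingSize += len

      def flush(): Unit =
        if (remainingSize > 0) {
          sink(remainingSize)
          flushedSize += remainingSize
          remainingSize = 0
        }

      def size: Int = flushedSize     // only bytes that actually reached the sink count

      def reset(): Unit = { flushedSize = 0; remainingSize = 0 }
    }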
- * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. - * @param streamWrapper A function to wrap the returned input stream. - * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. - * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param context [[TaskContext]], used for metrics update + * @param shuffleClient [[ShuffleClient]] for fetching remote blocks + * @param blockManager [[BlockManager]] for reading local blocks + * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. + * For each block we also require the size (in bytes as a long field) in + * order to throttle the memory usage. + * @param streamWrapper A function to wrap the returned input stream. + * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point * for a given remote host:port. - * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. - * @param detectCorrupt whether to detect any corruption in fetched blocks. + * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. + * @param detectCorrupt whether to detect any corruption in fetched blocks. */ private[spark] -final class RdmaShuffleBlockFetcherIterator( - context: TaskContext, - shuffleClient: ShuffleClient, - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - streamWrapper: (BlockId, InputStream) => InputStream, - maxBytesInFlight: Long, - maxReqsInFlight: Int, - maxBlocksInFlightPerAddress: Int, - maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean) +final class RdmaShuffleBlockFetcherIterator(context: TaskContext, + shuffleClient: ShuffleClient, + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + maxReqSizeShuffleToMem: Long, + detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging { import RdmaShuffleBlockFetcherIterator._ + /** Local blocks to fetch, excluding zero-sized blocks. */ + private[this] val localBlocks = new ArrayBuffer[BlockId]() + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[File]() + private[this] val remoteRdmaRequestQueue = new LinkedBlockingQueue[RdmaRequest]() /** * Total number of blocks to fetch. This can be smaller than the total number of blocks * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. @@ -82,36 +96,20 @@ final class RdmaShuffleBlockFetcherIterator( * This should equal localBlocks.size + remoteBlocks.size. 
*/ private[this] var numBlocksToFetch = 0 - /** * The number of blocks processed by the caller. The iterator is exhausted when * [[numBlocksProcessed]] == [[numBlocksToFetch]]. */ private[this] var numBlocksProcessed = 0 - - /** Local blocks to fetch, excluding zero-sized blocks. */ - private[this] val localBlocks = new ArrayBuffer[BlockId]() - - /** - * A queue to hold our results. This turns the asynchronous model provided by - * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). - */ - private[this] val results = new LinkedBlockingQueue[FetchResult] - /** * Current [[FetchResult]] being processed. We track this so we can release the current buffer * in case of a runtime exception when processing the current buffer. */ @volatile private[this] var currentResult: SuccessFetchResult = _ - /** Current bytes in flight from our requests */ private[this] var bytesInFlight = new AtomicLong(0) - /** Current number of requests in flight */ private[this] var reqsInFlight = new AtomicInteger(0) - - private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() - /** * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. @@ -119,16 +117,28 @@ final class RdmaShuffleBlockFetcherIterator( @GuardedBy("this") private[this] var isZombie = false - /** - * A set to store the files used for shuffling remote huge blocks. Files in this set will be - * deleted when cleanup. This is a layer of defensiveness against disk file leaks. - */ - @GuardedBy("this") - private[this] val shuffleFilesSet = mutable.HashSet[File]() + initialize() - private[this] val remoteRdmaRequestQueue = new LinkedBlockingQueue[RdmaRequest]() + def initialize(): Unit = { + context.addTaskCompletionListener(_ => cleanup()) - initialize() + val remoteBlocksByAddress = blocksByAddress.filter(_._1.executorId != blockManager.blockManagerId.executorId) + for ((address, blockInfos) <- blocksByAddress) { + if (address.executorId == blockManager.blockManagerId.executorId) { + localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + numBlocksToFetch += localBlocks.size + } + } + + startFetch(remoteBlocksByAddress) + } + + def startFetch(remoteBlocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + for ((blockManagerId, blockInfos) <- remoteBlocksByAddress) { + startFetchMetadata(blockManagerId, blockInfos.filter(_._2 != 0).map(_._1).toArray) + } + fetchLocalBlocks() + } def startFetchMetadata(blockManagerId: BlockManagerId, blockIds: Array[BlockId]): Unit = { if (blockIds.length == 0) return @@ -145,13 +155,17 @@ final class RdmaShuffleBlockFetcherIterator( remoteRdmaRequestQueue.put(new RdmaRequest(blockManagerId, last.getShuffleBlockId, blockInfoArray.slice(startIndex, i), reqSize)) startIndex = i reqSize = 0 - Future { fetchRemoteBlocks() } + Future { + fetchRemoteBlocks() + } } last = current reqSize += current.getLength } remoteRdmaRequestQueue.put(new RdmaRequest(blockManagerId, last.getShuffleBlockId, blockInfoArray.slice(startIndex, num), reqSize)) - Future { fetchRemoteBlocks() } + Future { + fetchRemoteBlocks() + } } override def onFailure(e: Throwable): Unit = { @@ -165,107 +179,29 @@ final class RdmaShuffleBlockFetcherIterator( rdmaTransferService.fetchBlockInfo(blockIds, receivedCallback) } - def sendRequest(rdmaRequest: RdmaRequest): Unit = { - val shuffleBlockInfos = rdmaRequest.shuffleBlockInfos - var blockNums= shuffleBlockInfos.size - 
bytesInFlight.addAndGet(rdmaRequest.reqSize) - reqsInFlight.incrementAndGet - val blockManagerId = rdmaRequest.blockManagerId - val shuffleBlockIdName = rdmaRequest.shuffleBlockIdName - - val pmofTransferService = shuffleClient.asInstanceOf[PmofTransferService] - - val blockFetchingReadCallback = new ReadCallback { - def onSuccess(shuffleBuffer: ShuffleBuffer, f: Int => Unit): Unit = { - if (!isZombie) { - RdmaShuffleBlockFetcherIterator.this.synchronized { - blockNums -= 1 - if (blockNums == 0) { - results.put(SuccessFetchResult(BlockId(shuffleBlockIdName), blockManagerId, rdmaRequest.reqSize, shuffleBuffer, isNetworkReqDone = true)) - f(shuffleBuffer.getRdmaBufferId) - } - } - } - } - - override def onFailure(e: Throwable): Unit = { - results.put(FailureFetchResult(BlockId(shuffleBlockIdName), blockManagerId, e)) - } - } - - val client = pmofTransferService.getClient(blockManagerId.host, blockManagerId.port) - val shuffleBuffer = new ShuffleBuffer(rdmaRequest.reqSize, client.getEqService, true) - val rdmaBuffer = client.getEqService.regRmaBufferByAddress(shuffleBuffer.nioByteBuffer(), - shuffleBuffer.getAddress, shuffleBuffer.getLength.toInt) - shuffleBuffer.setRdmaBufferId(rdmaBuffer.getBufferId) - - var offset = 0 - for (i <- 0 until blockNums) { - pmofTransferService.fetchBlock(blockManagerId.host, blockManagerId.port, - shuffleBlockInfos(i).getAddress, shuffleBlockInfos(i).getLength, - shuffleBlockInfos(i).getRkey, offset, shuffleBuffer, client, blockFetchingReadCallback) - offset += shuffleBlockInfos(i).getLength - } - } - - def isRemoteBlockFetchable(rdmaRequest: RdmaRequest): Boolean = { - reqsInFlight.get + 1 <= maxReqsInFlight && bytesInFlight.get + rdmaRequest.reqSize <= maxBytesInFlight - } - - def fetchRemoteBlocks(): Unit = { - val rdmaRequest = remoteRdmaRequestQueue.poll() - if (rdmaRequest == null) { - return - } - if (!isRemoteBlockFetchable(rdmaRequest)) { - remoteRdmaRequestQueue.put(rdmaRequest) - } else { - sendRequest(rdmaRequest) - } - } - - def startFetch(remoteBlocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { - for ((blockManagerId, blockInfos) <- remoteBlocksByAddress) { - startFetchMetadata(blockManagerId, blockInfos.filter(_._2 != 0).map(_._1).toArray) - } - fetchLocalBlocks() - } - - def initialize(): Unit = { - context.addTaskCompletionListener(_ => cleanup()) - - val remoteBlocksByAddress = blocksByAddress.filter(_._1.executorId != blockManager.blockManagerId.executorId) - for ((address, blockInfos) <- blocksByAddress) { - if (address.executorId == blockManager.blockManagerId.executorId) { - localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) - numBlocksToFetch += localBlocks.size + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchLocalBlocks() { + val iter = localBlocks.iterator + while (iter.hasNext) { + val blockId = iter.next() + try { + val buf = blockManager.getBlockData(blockId) + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put(SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, isNetworkReqDone = false)) + } catch { + case e: Exception => + // If we see an exception, stop immediately. 
+ logError(s"Error occurred while fetching local blocks", e) + results.put(FailureFetchResult(blockId, blockManager.blockManagerId, e)) + return } } - - startFetch(remoteBlocksByAddress) - } - - // Decrements the buffer reference count. - // The currentResult is set to null to prevent releasing the buffer again on cleanup() - private[storage] def releaseCurrentResultBuffer(): Unit = { - // Release the current buffer if necessary - if (currentResult != null) { - currentResult.buf.release() - } - currentResult = null - } - - override def createTempFile(): File = { - blockManager.diskBlockManager.createTempLocalBlock()._2 - } - - override def registerTempFileToClean(file: File): Boolean = synchronized { - if (isZombie) { - false - } else { - shuffleFilesSet += file - true - } } /** @@ -300,32 +236,28 @@ final class RdmaShuffleBlockFetcherIterator( } } - /** - * Fetch the local blocks while we are fetching remote blocks. This is ok because - * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we - * track in-memory are the ManagedBuffer references themselves. - */ - private[this] def fetchLocalBlocks() { - val iter = localBlocks.iterator - while (iter.hasNext) { - val blockId = iter.next() - try { - val buf = blockManager.getBlockData(blockId) - shuffleMetrics.incLocalBlocksFetched(1) - shuffleMetrics.incLocalBytesRead(buf.size) - buf.retain() - results.put(SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, isNetworkReqDone = false)) - } catch { - case e: Exception => - // If we see an exception, stop immediately. - logError(s"Error occurred while fetching local blocks", e) - results.put(FailureFetchResult(blockId, blockManager.blockManagerId, e)) - return - } + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { + // Release the current buffer if necessary + if (currentResult != null) { + currentResult.buf.release() } + currentResult = null } - override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + override def createTempFile(): File = { + blockManager.diskBlockManager.createTempLocalBlock()._2 + } + + override def registerTempFileToClean(file: File): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } /** * Fetches the next (BlockId, InputStream). 
If a task fails, the ManagedBuffers @@ -381,20 +313,83 @@ final class RdmaShuffleBlockFetcherIterator( } input = streamWrapper(blockId, in) - // Only copy the stream if it's wrapped by compression or encryption, also the size of - // block is small (the decompressed block is smaller than maxBytesInFlight) + // Only copy the stream if it's wrapped by compression or encryption, also the size of + // block is small (the decompressed block is smaller than maxBytesInFlight) case FailureFetchResult(blockId, address, e) => throwFetchFailedException(blockId, address, e) } // Send fetch requests up to maxBytesInFlight - Future { fetchRemoteBlocks() } + Future { + fetchRemoteBlocks() + } } currentResult = result.asInstanceOf[SuccessFetchResult] (currentResult.blockId, new RDMABufferReleasingInputStream(input, this)) } + def fetchRemoteBlocks(): Unit = { + val rdmaRequest = remoteRdmaRequestQueue.poll() + if (rdmaRequest == null) { + return + } + if (!isRemoteBlockFetchable(rdmaRequest)) { + remoteRdmaRequestQueue.put(rdmaRequest) + } else { + sendRequest(rdmaRequest) + } + } + + def sendRequest(rdmaRequest: RdmaRequest): Unit = { + val shuffleBlockInfos = rdmaRequest.shuffleBlockInfos + var blockNums = shuffleBlockInfos.size + bytesInFlight.addAndGet(rdmaRequest.reqSize) + reqsInFlight.incrementAndGet + val blockManagerId = rdmaRequest.blockManagerId + val shuffleBlockIdName = rdmaRequest.shuffleBlockIdName + + val pmofTransferService = shuffleClient.asInstanceOf[PmofTransferService] + + val blockFetchingReadCallback = new ReadCallback { + def onSuccess(shuffleBuffer: ShuffleBuffer, f: Int => Unit): Unit = { + if (!isZombie) { + RdmaShuffleBlockFetcherIterator.this.synchronized { + blockNums -= 1 + if (blockNums == 0) { + results.put(SuccessFetchResult(BlockId(shuffleBlockIdName), blockManagerId, rdmaRequest.reqSize, shuffleBuffer, isNetworkReqDone = true)) + f(shuffleBuffer.getRdmaBufferId) + } + } + } + } + + override def onFailure(e: Throwable): Unit = { + results.put(FailureFetchResult(BlockId(shuffleBlockIdName), blockManagerId, e)) + } + } + + val client = pmofTransferService.getClient(blockManagerId.host, blockManagerId.port) + val shuffleBuffer = new ShuffleBuffer(rdmaRequest.reqSize, client.getEqService, true) + val rdmaBuffer = client.getEqService.regRmaBufferByAddress(shuffleBuffer.nioByteBuffer(), + shuffleBuffer.getAddress, shuffleBuffer.getLength.toInt) + shuffleBuffer.setRdmaBufferId(rdmaBuffer.getBufferId) + + var offset = 0 + for (i <- 0 until blockNums) { + pmofTransferService.fetchBlock(blockManagerId.host, blockManagerId.port, + shuffleBlockInfos(i).getAddress, shuffleBlockInfos(i).getLength, + shuffleBlockInfos(i).getRkey, offset, shuffleBuffer, client, blockFetchingReadCallback) + offset += shuffleBlockInfos(i).getLength + } + } + + def isRemoteBlockFetchable(rdmaRequest: RdmaRequest): Boolean = { + reqsInFlight.get + 1 <= maxReqsInFlight && bytesInFlight.get + rdmaRequest.reqSize <= maxBytesInFlight + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => @@ -412,8 +407,8 @@ private class RdmaRequest(val blockManagerId: BlockManagerId, val shuffleBlockId * Helper class that ensures a ManagedBuffer is released upon InputStream.close() */ private class RDMABufferReleasingInputStream( - private val delegate: InputStream, - private val iterator: RdmaShuffleBlockFetcherIterator) + private val 
delegate: InputStream, + private val iterator: RdmaShuffleBlockFetcherIterator) extends InputStream { private[this] var closed = false @@ -445,16 +440,6 @@ private class RDMABufferReleasingInputStream( private[storage] object RdmaShuffleBlockFetcherIterator { - /** - * A request to fetch blocks from a remote BlockManager. - * @param address remote BlockManager to fetch from. - * @param blocks Sequence of tuple, where the first element is the block id, - * and the second element is the estimated size, used to calculate bytesInFlight. - */ - case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) { - val size: Long = blocks.map(_._2).sum - } - /** * Result of a fetch from a remote block. */ @@ -463,13 +448,25 @@ object RdmaShuffleBlockFetcherIterator { val address: BlockManagerId } + /** + * A request to fetch blocks from a remote BlockManager. + * + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of tuple, where the first element is the block id, + * and the second element is the estimated size, used to calculate bytesInFlight. + */ + case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) { + val size: Long = blocks.map(_._2).sum + } + /** * Result of a fetch from a remote block successfully. - * @param blockId block id - * @param address BlockManager that the block was fetched from. - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param buf `ManagedBuffer` for the content. + * + * @param blockId block id + * @param address BlockManager that the block was fetched from. + * @param size estimated size of the block, used to calculate bytesInFlight. + * Note that this is NOT the exact bytes. + * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ private[storage] case class SuccessFetchResult( @@ -484,13 +481,15 @@ object RdmaShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block unsuccessfully. 
+ * * @param blockId block id * @param address BlockManager that the block was attempted to be fetched from - * @param e the failure exception + * @param e the failure exception */ private[storage] case class FailureFetchResult( blockId: BlockId, address: BlockManagerId, e: Throwable) extends FetchResult + } diff --git a/core/src/main/scala/org/apache/spark/util/collection/pmof/PmemExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/pmof/PmemExternalSorter.scala index ae79d429..20259284 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/pmof/PmemExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/pmof/PmemExternalSorter.scala @@ -1,6 +1,5 @@ package org.apache.spark.util.collection.pmof -import java.io.InputStream import java.util.Comparator import scala.collection.mutable @@ -11,30 +10,26 @@ import org.apache.spark.serializer._ import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.util.collection._ import org.apache.spark.storage.pmof._ -import com.esotericsoftware.kryo.KryoException -import org.apache.commons.lang3.exception.ExceptionUtils +import org.apache.spark.util.configuration.pmof.PmofConf private[spark] class PmemExternalSorter[K, V, C]( context: TaskContext, handle: BaseShuffleHandle[K, _, C], + pmofConf: PmofConf, aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, serializer: Serializer = SparkEnv.get.serializer) - extends ExternalSorter[K, V, C](context, aggregator, partitioner, ordering, serializer) - with Logging { - var partitionBufferArray = ArrayBuffer[PmemBlockObjectStream]() - var mapSideCombine = false - private val dep = handle.dependency - private val serializerManager = SparkEnv.get.serializerManager - private val serInstance = serializer.newInstance() - private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) - private val shouldPartition = numPartitions > 1 + extends ExternalSorter[K, V, C](context, aggregator, partitioner, ordering, serializer) with Logging { + private[this] val pmemBlockOutputStreamArray: ArrayBuffer[PmemBlockOutputStream] = ArrayBuffer[PmemBlockOutputStream]() + private[this] var mapStage = false + private[this] val dep = handle.dependency + private[this] val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) + private[this] val shouldPartition = numPartitions > 1 + private def getPartition(key: K): Int = { if (shouldPartition) partitioner.get.getPartition(key) else 0 } - private val inMemoryCollectionSizeThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.pmof.MemoryThreshold", 5 * 1024 * 1024) private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] { override def compare(a: K, b: K): Int = { @@ -52,25 +47,27 @@ private[spark] class PmemExternalSorter[K, V, C]( } } - def setPartitionByteBufferArray(writerArray: Array[PmemBlockObjectStream] = null): Unit = { - for (i <- 0 until writerArray.length) { - partitionBufferArray += writerArray(i) + def setPartitionByteBufferArray(writerArray: Array[PmemBlockOutputStream] = null): Unit = { + for (i <- writerArray.indices) { + pmemBlockOutputStreamArray += writerArray(i) } - mapSideCombine = true + mapStage = true } - def getPartitionByteBufferArray(stageId: Int, partitionId: Int): PmemBlockObjectStream = { - if (mapSideCombine) { - partitionBufferArray(partitionId) + def getPartitionByteBufferArray(stageId: Int, partitionId: Int): PmemBlockOutputStream = { + if (mapStage) { + 
pmemBlockOutputStreamArray(partitionId) } else { - partitionBufferArray += new PmemBlockObjectStream(serializerManager, - serInstance, + pmemBlockOutputStreamArray += new PmemBlockOutputStream( context.taskMetrics(), PmemBlockId.getTempBlockId(stageId), + SparkEnv.get.serializerManager, + serializer, SparkEnv.get.conf, + pmofConf, 1, numPartitions) - partitionBufferArray(partitionBufferArray.length - 1) + pmemBlockOutputStreamArray(pmemBlockOutputStreamArray.length - 1) } } @@ -84,11 +81,9 @@ private[spark] class PmemExternalSorter[K, V, C]( override protected[this] def maybeSpill(collection: WritablePartitionedPairCollection[K, C], currentMemory: Long): Boolean = { var shouldSpill = false - if (elementsRead % 32 == 0 && currentMemory >= inMemoryCollectionSizeThreshold) { - shouldSpill = currentMemory >= inMemoryCollectionSizeThreshold - //logInfo("maybeSpill") + if (elementsRead % 32 == 0 && currentMemory >= pmofConf.inMemoryCollectionSizeThreshold) { + shouldSpill = currentMemory >= pmofConf.inMemoryCollectionSizeThreshold } - // Actually spill if (shouldSpill) { spill(collection) } @@ -101,10 +96,10 @@ private[spark] class PmemExternalSorter[K, V, C]( } private[this] def spillMemoryIteratorToPmem(inMemoryIterator: WritablePartitionedIterator): Unit = { - var buffer: PmemBlockObjectStream = null + var buffer: PmemBlockOutputStream = null var cur_partitionId = -1 while (inMemoryIterator.hasNext) { - var partitionId = inMemoryIterator.nextPartition() + val partitionId = inMemoryIterator.nextPartition() if (cur_partitionId != partitionId) { if (cur_partitionId != -1) { buffer.maybeSpill(true) @@ -118,11 +113,15 @@ private[spark] class PmemExternalSorter[K, V, C]( val elem = if (inMemoryIterator.hasNext) inMemoryIterator.writeNext(buffer) else null //elementsPerPartition(partitionId) += 1 } - buffer.maybeSpill(true) + if (buffer != null) { + buffer.maybeSpill(true) + } } override def stop(): Unit = { - partitionBufferArray.foreach(_.close()) + if (mapStage) { + pmemBlockOutputStreamArray.foreach(_.close()) + } } /** @@ -164,21 +163,22 @@ private[spark] class PmemExternalSorter[K, V, C]( def getCollection(variableName: String): WritablePartitionedPairCollection[K, C] = { import java.lang.reflect._ // use reflection to get private map or buffer - var privateField: Field = this.getClass().getSuperclass().getDeclaredField(variableName) + val privateField: Field = this.getClass().getSuperclass().getDeclaredField(variableName) privateField.setAccessible(true) - var fieldValue = privateField.get(this) + val fieldValue = privateField.get(this) fieldValue.asInstanceOf[WritablePartitionedPairCollection[K, C]] } override def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) getCollection("map") else getCollection("buffer") - if (partitionBufferArray.isEmpty) { + if (pmemBlockOutputStreamArray.isEmpty) { // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps // we don't even need to sort by anything other than partition ID - if (!ordering.isDefined) { + if (ordering.isEmpty) { // The user hasn't requested sorted keys, so only sort by partition ID, not key - groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None))) + groupByPartition(destructiveIterator( + collection.partitionedDestructiveSortedIterator(None))) } else { // We do need to sort by both partition ID and key groupByPartition(destructiveIterator( @@ 
-194,7 +194,7 @@ private[spark] class PmemExternalSorter[K, V, C]( def merge(inMemory: Iterator[((Int, K), C)]): Iterator[(Int, Iterator[Product2[K, C]])] = { // this function is used to merge spilled data with inMemory records val inMemBuffered = inMemory.buffered - val readers = partitionBufferArray.map(partitionBuffer => {new SpillReader(partitionBuffer)}) + val readers: ArrayBuffer[SpillReader] = pmemBlockOutputStreamArray.map(pmemBlockOutputStream => {new SpillReader(pmemBlockOutputStream)}) (0 until numPartitions).iterator.map { partitionId => val inMemIterator = new IteratorForPartition(partitionId, inMemBuffered) val iterators = readers.map(_.readPartitionIter(partitionId)) ++ Seq(inMemIterator) @@ -226,7 +226,7 @@ private[spark] class PmemExternalSorter[K, V, C]( }) heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true new Iterator[Product2[K, C]] { - override def hasNext: Boolean = !heap.isEmpty + override def hasNext: Boolean = heap.nonEmpty override def next(): Product2[K, C] = { if (!hasNext) { @@ -303,7 +303,7 @@ private[spark] class PmemExternalSorter[K, V, C]( } else { // We have a total ordering, so the objects with the same key are sequential. new Iterator[Product2[K, C]] { - val sorted = mergeSort(iterators, comparator).buffered + val sorted: BufferedIterator[Product2[K, C]] = mergeSort(iterators, comparator).buffered override def hasNext: Boolean = sorted.hasNext @@ -324,28 +324,21 @@ private[spark] class PmemExternalSorter[K, V, C]( } } - class SpillReader(writeBuffer: PmemBlockObjectStream) { + class SpillReader(pmemBlockOutputStream: PmemBlockOutputStream) { // Each spill reader is relate to one partition // which is different from spark original codes (relate to one spill file) - val blockId = writeBuffer.getBlockId() - var indexInBatch: Int = 0 - var partitionMetaIndex: Int = 0 - - var inStream: InputStream = _ - var inObjStream: DeserializationStream = _ + val pmemBlockInputStream = new PmemBlockInputStream[K, C](pmemBlockOutputStream, serializer) var nextItem: (K, C) = _ - var total_records: Long = 0 - loadStream() def readPartitionIter(partitionId: Int): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] { override def hasNext: Boolean = { if (nextItem == null) { - nextItem = readNextItem() + nextItem = pmemBlockInputStream.readNextItem() if (nextItem == null) { return false } } - return getPartition(nextItem._1) == partitionId + getPartition(nextItem._1) == partitionId } override def next(): Product2[K, C] = { @@ -357,38 +350,5 @@ private[spark] class PmemExternalSorter[K, V, C]( item } } - - private def readNextItem(): (K, C) = { - if (inObjStream == null) { - if (inStream != null) { - inStream.close() - inStream.asInstanceOf[PmemInputStream].deleteBlock() - } - return null - } - try{ - val k = inObjStream.readObject().asInstanceOf[K] - val c = inObjStream.readObject().asInstanceOf[C] - indexInBatch += 1 - if (indexInBatch == total_records) { - inObjStream = null - } - (k, c) - } catch { - case ex: KryoException => { - logError("Kyro deserialization failed, loaded records count is " + indexInBatch + ", error backtrace: " + ExceptionUtils.getStackTrace(ex)) - } - logError(ExceptionUtils.getStackTrace(ex)) - sys.exit(0) - } - } - - def loadStream(): Unit = { - total_records = writeBuffer.getTotalRecords() - inStream = writeBuffer.getInputStream() - val wrappedStream = serializerManager.wrapStream(blockId, inStream) - inObjStream = serInstance.deserializeStream(wrappedStream) - indexInBatch = 0 - } } } diff --git 
a/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala b/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala new file mode 100644 index 00000000..67bf1ca9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala @@ -0,0 +1,29 @@ +package org.apache.spark.util.configuration.pmof + +import org.apache.spark.SparkConf + +class PmofConf(conf: SparkConf) { + val enableRdma: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_rdma", defaultValue = true) + val enablePmem: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true) + val path_list: List[String] = conf.get("spark.shuffle.pmof.pmem_list").split(",").map(_.trim).toList + val maxPoolSize: Long = conf.getLong("spark.shuffle.pmof.pmpool_size", defaultValue = 1073741824) + val maxStages: Int = conf.getInt("spark.shuffle.pmof.max_stage_num", defaultValue = 1000) + val spill_throttle: Long = conf.getLong("spark.shuffle.pmof.spill_throttle", defaultValue = 4194304) + val inMemoryCollectionSizeThreshold: Long = + conf.getLong("spark.shuffle.spill.pmof.MemoryThreshold", 5 * 1024 * 1024) + val networkBufferSize: Int = conf.getInt("spark.shuffle.pmof.network_buffer_size", 4096 * 3) + val driverHost: String = conf.get("spark.driver.rhost", defaultValue = "172.168.0.43") + val driverPort: Int = conf.getInt("spark.driver.rport", defaultValue = 61000) + val serverBufferNums: Int = conf.getInt("spark.shuffle.pmof.server_buffer_nums", 256) + val serverWorkerNums = conf.getInt("spark.shuffle.pmof.server_pool_size", 1) + val clientBufferNums: Int = conf.getInt("spark.shuffle.pmof.client_buffer_nums", 16) + val clientWorkerNums = conf.getInt("spark.shuffle.pmof.server_pool_size", 1) + val shuffleNodes: Array[Array[String]] = + conf.get("spark.shuffle.pmof.node", defaultValue = "").split(",").map(_.split("-")) + val map_serializer_buffer_size = conf.getLong("spark.shuffle.pmof.map_serializer_buffer_size", 16 * 1024) + val reduce_serializer_buffer_size = conf.getLong("spark.shuffle.pmof.reduce_serializer_buffer_size", 16 * 1024) + val metadataCompress: Boolean = conf.getBoolean("spark.shuffle.pmof.metadata_compress", defaultValue = false) + val shuffleBlockSize: Int = conf.getInt("spark.shuffle.pmof.shuffle_block_size", defaultValue = 2048) + val pmemCapacity: Long = conf.getLong("spark.shuffle.pmof.pmem_capacity", defaultValue = 264239054848L) + val pmemCoreMap = conf.get("spark.shuffle.pmof.dev_core_set").split(";").map(_.trim).map(_.split(":")).map(arr => arr(0) -> arr(1)).toMap +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala index d6658593..16ab0ceb 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.shuffle.pmof import scala.collection.mutable.ArrayBuffer - import org.mockito.Mockito._ import org.mockito.MockitoAnnotations import org.scalatest.Matchers import org.scalatest.BeforeAndAfterEach - import org.apache.spark._ -import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer._ import org.apache.spark.util.Utils @@ -33,6 +31,7 @@ import org.apache.spark.storage.pmof._ import 
org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver} import org.apache.spark.util.Utils import org.apache.spark.util.collection.pmof.PmemExternalSorter +import org.apache.spark.util.configuration.pmof.PmofConf class PmemShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with Matchers { @@ -43,6 +42,7 @@ class PmemShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with private var shuffleBlockResolver: PmemShuffleBlockResolver = _ private var serializer: KryoSerializer = _ private var serializerManager: SerializerManager = _ + private var pmofConf: PmofConf = _ private var taskMetrics: TaskMetrics = _ private var partitioner: Partitioner = _ @@ -55,6 +55,7 @@ class PmemShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with shuffleBlockResolver = new PmemShuffleBlockResolver(conf) serializer = new KryoSerializer(conf) serializerManager = new SerializerManager(serializer, conf) + pmofConf = new PmofConf(conf) taskMetrics = new TaskMetrics() partitioner = new Partitioner() { def numPartitions = 1 @@ -102,7 +103,8 @@ class PmemShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with shuffleHandle, mapId = 2, context, - conf) + conf, + pmofConf) writer.write(records.toIterator) writer.stop(success = true) val buf = shuffleBlockResolver.getBlockData(blockId).asInstanceOf[PmemManagedBuffer] diff --git a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala index 039d18e8..b1dff9e7 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.shuffle.pmof import scala.collection.mutable.ArrayBuffer - import org.mockito.Mockito._ import org.mockito.MockitoAnnotations import org.scalatest.Matchers import org.scalatest.BeforeAndAfterEach - import org.apache.spark._ -import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer._ import org.apache.spark.util.Utils @@ -33,6 +31,7 @@ import org.apache.spark.storage.pmof._ import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver} import org.apache.spark.util.Utils import org.apache.spark.util.collection.pmof.PmemExternalSorter +import org.apache.spark.util.configuration.pmof.PmofConf class PmemShuffleWriterWithSortSuite extends SparkFunSuite with SharedSparkContext with Matchers { @@ -43,6 +42,7 @@ class PmemShuffleWriterWithSortSuite extends SparkFunSuite with SharedSparkConte private var shuffleBlockResolver: PmemShuffleBlockResolver = _ private var serializer: KryoSerializer = _ private var serializerManager: SerializerManager = _ + private var pmofConf: PmofConf = _ private var taskMetrics: TaskMetrics = _ private var partitioner: Partitioner = _ @@ -55,6 +55,7 @@ class PmemShuffleWriterWithSortSuite extends SparkFunSuite with SharedSparkConte shuffleBlockResolver = new PmemShuffleBlockResolver(conf) serializer = new KryoSerializer(conf) serializerManager = new SerializerManager(serializer, conf) + pmofConf = new PmofConf(conf) taskMetrics = new TaskMetrics() partitioner = new Partitioner() { def numPartitions = 1 @@ -105,7 +106,8 @@ class PmemShuffleWriterWithSortSuite extends SparkFunSuite with 
SharedSparkConte shuffleHandle, mapId = 2, context, - conf) + conf, + pmofConf) writer.write(records.toIterator) writer.stop(success = true) val buf = shuffleBlockResolver.getBlockData(blockId).asInstanceOf[PmemManagedBuffer] diff --git a/core/src/test/scala/org/apache/spark/storage/pmof/PmemBlockObjectStreamSuite.scala b/core/src/test/scala/org/apache/spark/storage/pmof/PmemBlockObjectStreamSuite.scala deleted file mode 100644 index d83bf95d..00000000 --- a/core/src/test/scala/org/apache/spark/storage/pmof/PmemBlockObjectStreamSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -package org.apache.spark.storage.pmof - -import java.io.File - -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} -import org.apache.spark.serializer._ -import org.apache.spark.util.Utils -import org.apache.spark.storage._ -import org.apache.spark.storage.pmof._ - -class PmemBlockObjectStreamSuite extends SparkFunSuite with BeforeAndAfterEach { - - val conf = new SparkConf() - conf.set("spark.shuffle.pmof.enable_rdma", "false") - conf.set("spark.shuffle.pmof.enable_pmem", "true") - //val serializer = new JavaSerializer(conf) - val serializer = new KryoSerializer(conf) - val serializerManager = new SerializerManager(serializer, conf) - val taskMetrics = new TaskMetrics() - - override def beforeEach(): Unit = { - super.beforeEach() - } - - override def afterEach(): Unit = { - super.afterEach() - } - - private def createWriter(blockId: ShuffleBlockId): (PmemBlockObjectStream, ShuffleWriteMetrics) = { - val writeMetrics = taskMetrics.shuffleWriteMetrics - val writer = new PmemBlockObjectStream( - serializerManager, serializer.newInstance(), taskMetrics, blockId, conf, 100, 100) - (writer, writeMetrics) - } - - test("verify ShuffleWrite of Shuffle_0_0_0, then check read") { - val blockId = ShuffleBlockId(0, 0, 0) - val (writer, writeMetrics) = createWriter(blockId) - val key: String = "key" - val value: String = "value" - writer.write(key, value) - // Record metrics update on every write - assert(writeMetrics.recordsWritten === 1) - // Metrics don't update on every write - assert(writeMetrics.bytesWritten == 0) - // write then flush, metrics should update - writer.write(key, value) - writer.flush() - assert(writeMetrics.recordsWritten === 2) - writer.close() - - val inStream = writer.getInputStream() - val wrappedStream = serializerManager.wrapStream(blockId, inStream) - val inObjStream = serializer.newInstance().deserializeStream(wrappedStream) - val k = inObjStream.readObject().asInstanceOf[String] - val v = inObjStream.readObject().asInstanceOf[String] - assert(k.equals(key)) - assert(v.equals(value)) - inObjStream.close() - } - - test("verify ShuffleRead of Shuffle_0_0_0") { - val blockId = ShuffleBlockId(0, 0, 0) - val persistentMemoryHandler = PersistentMemoryHandler.getPersistentMemoryHandler - val buf = persistentMemoryHandler.getPartitionManagedBuffer(blockId.name) - val inStream = buf.createInputStream() - val wrappedStream = serializerManager.wrapStream(blockId, inStream) - val inObjStream = serializer.newInstance().deserializeStream(wrappedStream) - val k = inObjStream.readObject().asInstanceOf[String] - val v = inObjStream.readObject().asInstanceOf[String] - assert(k.equals("key")) - assert(v.equals("value")) - inObjStream.close() - } - - test("verify ShuffleRead of none exists Shuffle_0_0_1") { - val blockId = ShuffleBlockId(0, 0, 1) - val persistentMemoryHandler = 
PersistentMemoryHandler.getPersistentMemoryHandler - val buf = persistentMemoryHandler.getPartitionManagedBuffer(blockId.name) - - val inStream = buf.createInputStream() - val wrappedStream = serializerManager.wrapStream(blockId, inStream) - val inObjStream = serializer.newInstance().deserializeStream(wrappedStream) - try{ - val k = inObjStream.readObject().asInstanceOf[String] - val v = inObjStream.readObject().asInstanceOf[String] - } catch { - case ex: java.io.EOFException => - logInfo(s"Expected Error: $ex") - } - inObjStream.close() - } -}
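
The heart of the spill fix above is the split of PmemOutputStream's single total counter into flushedSize (bytes already pushed to persistent memory) and remainingSize (bytes buffered since the last flush). The Scala sketch below shows only that split-counter idea under simplifying assumptions: an in-heap ByteArrayOutputStream stands in for the direct ByteBuf, and a caller-supplied handOff function stands in for the setPartition call; the names SplitCountingStream and handOff are illustrative and not part of this patch.

import java.io.ByteArrayOutputStream

// Sketch of the flushedSize/remainingSize accounting; not the patched PmemOutputStream.
class SplitCountingStream(handOff: Array[Byte] => Unit) {
  private val buffer = new ByteArrayOutputStream()
  private var flushedSize: Int = 0    // bytes already handed off by flush()
  private var remainingSize: Int = 0  // bytes written since the last flush()

  def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
    buffer.write(bytes, off, len)
    remainingSize += len
  }

  def flush(): Unit = {
    if (remainingSize > 0) {
      handOff(buffer.toByteArray)  // the real class writes the buffer to persistent memory here
      flushedSize += remainingSize
      remainingSize = 0
      buffer.reset()
    }
  }

  // Mirrors the patched size(): reports what has been flushed, not what is still pending.
  def size: Int = flushedSize

  def reset(): Unit = {
    flushedSize = 0
    remainingSize = 0
    buffer.reset()
  }
}

Keeping the two counters separate lets a caller flush, read the size of the batch it just wrote, and then reset() for the next batch, which is the sequence maybeSpill follows in PmemBlockObjectStream.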
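
The new PmofConf collects the shuffle tunables that were previously read ad hoc, including the spill throttle used by maybeSpill and the in-memory collection threshold used by PmemExternalSorter. The snippet below is a minimal usage sketch, assuming only the configuration keys and defaults visible in PmofConf.scala above; the SpillThrottleSketch object and shouldSpill helper are illustrative and not part of this patch.

import org.apache.spark.SparkConf

object SpillThrottleSketch {
  def main(args: Array[String]): Unit = {
    // Build a standalone SparkConf (not loading system properties) with the same keys
    // and defaults that PmofConf reads.
    val conf = new SparkConf(false)
      .set("spark.shuffle.pmof.spill_throttle", "4194304")        // 4 MiB, the PmofConf default
      .set("spark.shuffle.spill.pmof.MemoryThreshold", "5242880") // 5 MiB, the PmofConf default

    val spillThrottle = conf.getLong("spark.shuffle.pmof.spill_throttle", 4194304L)
    val memoryThreshold = conf.getLong("spark.shuffle.spill.pmof.MemoryThreshold", 5L * 1024 * 1024)

    // Same shape as the check in PmemBlockObjectStream.maybeSpill: spill when the buffered
    // bytes reach the throttle (a throttle of -1 disables it), or when the caller forces it.
    def shouldSpill(bufferedBytes: Long, force: Boolean): Boolean =
      force || (spillThrottle != -1 && bufferedBytes >= spillThrottle)

    println(shouldSpill(bufferedBytes = 6L * 1024 * 1024, force = false)) // true
    println(shouldSpill(bufferedBytes = 1024L, force = false))            // false
    println(s"in-memory collection threshold: $memoryThreshold bytes")
  }
}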