Skip to content

Commit

Permalink
[Improve][Zeta] Split the classloader of task group (#7580)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hisoka-X authored Sep 27, 2024
1 parent dc7f695 commit 3be0d1c
Show file tree
Hide file tree
Showing 27 changed files with 414 additions and 155 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/backend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ jobs:
echo "engine-e2e=$true_or_false" >> $GITHUB_OUTPUT
echo "engine-e2e_files=$file_list" >> $GITHUB_OUTPUT
api_files=`python tools/update_modules_check/check_file_updates.py ua $workspace apache/dev origin/$current_branch "seatunnel-api/**" "seatunnel-common/**" "seatunnel-config/**" "seatunnel-connectors/**" "seatunnel-core/**" "seatunnel-e2e/seatunnel-e2e-common/**" "seatunnel-formats/**" "seatunnel-plugin-discovery/**" "seatunnel-transforms-v2/**" "seatunnel-translation/**" "seatunnel-e2e/seatunnel-transforms-v2-e2e/**" "seatunnel-connectors/**" "pom.xml" "**/workflows/**" "tools/**" "seatunnel-dist/**"`
api_files=`python tools/update_modules_check/check_file_updates.py ua $workspace apache/dev origin/$current_branch "seatunnel-api/**" "seatunnel-common/**" "seatunnel-config/**" "seatunnel-engine/**" "seatunnel-core/**" "seatunnel-e2e/seatunnel-e2e-common/**" "seatunnel-formats/**" "seatunnel-plugin-discovery/**" "seatunnel-transforms-v2/**" "seatunnel-translation/**" "seatunnel-e2e/seatunnel-transforms-v2-e2e/**" "pom.xml" "**/workflows/**" "tools/**" "seatunnel-dist/**"`
true_or_false=${api_files%%$'\n'*}
file_list=${api_files#*$'\n'}
if [[ $repository_owner == 'apache' ]];then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ public class HiveSourceConfig implements Serializable {

private static final long serialVersionUID = 1L;

private final Table table;
private final CatalogTable catalogTable;
private final FileFormat fileFormat;
private final ReadStrategy readStrategy;
Expand All @@ -81,7 +80,7 @@ public HiveSourceConfig(ReadonlyConfig readonlyConfig) {
readonlyConfig
.getOptional(HdfsSourceConfigOptions.READ_PARTITIONS)
.ifPresent(this::validatePartitions);
this.table = HiveTableUtils.getTableInfo(readonlyConfig);
Table table = HiveTableUtils.getTableInfo(readonlyConfig);
this.hadoopConf = parseHiveHadoopConfig(readonlyConfig, table);
this.fileFormat = HiveTableUtils.parseFileFormat(table);
this.readStrategy = parseReadStrategy(table, readonlyConfig, fileFormat, hadoopConf);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,7 @@ private void startMaster() {

private void startWorker() {
taskExecutionService =
new TaskExecutionService(
classLoaderService, nodeEngine, nodeEngine.getProperties(), eventService);
new TaskExecutionService(classLoaderService, nodeEngine, eventService);
nodeEngine.getMetricsRegistry().registerDynamicMetricsProvider(taskExecutionService);
taskExecutionService.start();
getSlotService();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.apache.seatunnel.engine.server.execution.TaskGroup;
import org.apache.seatunnel.engine.server.execution.TaskGroupContext;
import org.apache.seatunnel.engine.server.execution.TaskGroupLocation;
import org.apache.seatunnel.engine.server.execution.TaskGroupUtils;
import org.apache.seatunnel.engine.server.execution.TaskLocation;
import org.apache.seatunnel.engine.server.execution.TaskTracker;
import org.apache.seatunnel.engine.server.metrics.SeaTunnelMetricsContext;
Expand All @@ -65,13 +66,12 @@
import com.hazelcast.map.IMap;
import com.hazelcast.spi.impl.NodeEngineImpl;
import com.hazelcast.spi.impl.operationservice.impl.InvocationFuture;
import com.hazelcast.spi.properties.HazelcastProperties;
import lombok.Getter;
import lombok.NonNull;
import lombok.SneakyThrows;

import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -149,7 +149,6 @@ public class TaskExecutionService implements DynamicMetricsProvider {
public TaskExecutionService(
ClassLoaderService classLoaderService,
NodeEngineImpl nodeEngine,
HazelcastProperties properties,
EventService eventService) {
seaTunnelConfig = ConfigProvider.locateAndGetSeaTunnelConfig();
this.hzInstanceName = nodeEngine.getHazelcastInstance().getName();
Expand Down Expand Up @@ -282,33 +281,50 @@ public TaskDeployState deployTask(@NonNull TaskGroupImmutableInformation taskImm
taskImmutableInfo.getExecutionId()));
TaskGroup taskGroup = null;
try {
Set<ConnectorJarIdentifier> connectorJarIdentifiers =
List<Set<ConnectorJarIdentifier>> connectorJarIdentifiersList =
taskImmutableInfo.getConnectorJarIdentifiers();
Set<URL> jars = new HashSet<>();
ClassLoader classLoader;
if (!CollectionUtils.isEmpty(connectorJarIdentifiers)) {
// Prioritize obtaining the jar package file required for the current task execution
// from the local, if it does not exist locally, it will be downloaded from the
// master node.
jars =
serverConnectorPackageClient.getConnectorJarFromLocal(
connectorJarIdentifiers);
} else if (!CollectionUtils.isEmpty(taskImmutableInfo.getJars())) {
jars = taskImmutableInfo.getJars();
}
classLoader =
classLoaderService.getClassLoader(
taskImmutableInfo.getJobId(), Lists.newArrayList(jars));
if (jars.isEmpty()) {
taskGroup =
nodeEngine.getSerializationService().toObject(taskImmutableInfo.getGroup());
} else {
taskGroup =
CustomClassLoadedObject.deserializeWithCustomClassLoader(
nodeEngine.getSerializationService(),
classLoader,
taskImmutableInfo.getGroup());
List<Data> taskData = taskImmutableInfo.getTasksData();
ConcurrentHashMap<Long, ClassLoader> classLoaders = new ConcurrentHashMap<>();
List<Task> tasks = new ArrayList<>();
ConcurrentHashMap<Long, Collection<URL>> taskJars = new ConcurrentHashMap<>();
for (int i = 0; i < taskData.size(); i++) {
Set<URL> jars = new HashSet<>();
Set<ConnectorJarIdentifier> connectorJarIdentifiers =
connectorJarIdentifiersList.get(i);
if (!CollectionUtils.isEmpty(connectorJarIdentifiers)) {
// Prioritize obtaining the jar package file required for the current task
// execution
// from the local, if it does not exist locally, it will be downloaded from the
// master node.
jars =
serverConnectorPackageClient.getConnectorJarFromLocal(
connectorJarIdentifiers);
} else if (!CollectionUtils.isEmpty(taskImmutableInfo.getJars().get(i))) {
jars = taskImmutableInfo.getJars().get(i);
}
ClassLoader classLoader =
classLoaderService.getClassLoader(
taskImmutableInfo.getJobId(), Lists.newArrayList(jars));
Task task;
if (jars.isEmpty()) {
task = nodeEngine.getSerializationService().toObject(taskData.get(i));
} else {
task =
CustomClassLoadedObject.deserializeWithCustomClassLoader(
nodeEngine.getSerializationService(),
classLoader,
taskData.get(i));
}
tasks.add(task);
classLoaders.put(task.getTaskID(), classLoader);
taskJars.put(task.getTaskID(), jars);
}
taskGroup =
TaskGroupUtils.createTaskGroup(
taskImmutableInfo.getTaskGroupType(),
taskImmutableInfo.getTaskGroupLocation(),
taskImmutableInfo.getTaskGroupName(),
tasks);

logger.info(
String.format(
Expand All @@ -322,7 +338,7 @@ public TaskDeployState deployTask(@NonNull TaskGroupImmutableInformation taskImm
"TaskGroupLocation: %s already exists",
taskGroup.getTaskGroupLocation()));
}
deployLocalTask(taskGroup, classLoader, jars);
deployLocalTask(taskGroup, classLoaders, taskJars);
return TaskDeployState.success();
}
} catch (Throwable t) {
Expand All @@ -337,15 +353,10 @@ public TaskDeployState deployTask(@NonNull TaskGroupImmutableInformation taskImm
}
}

@Deprecated
public PassiveCompletableFuture<TaskExecutionState> deployLocalTask(
@NonNull TaskGroup taskGroup) {
return deployLocalTask(
taskGroup, Thread.currentThread().getContextClassLoader(), emptyList());
}

public PassiveCompletableFuture<TaskExecutionState> deployLocalTask(
@NonNull TaskGroup taskGroup, @NonNull ClassLoader classLoader, Collection<URL> jars) {
@NonNull TaskGroup taskGroup,
@NonNull ConcurrentHashMap<Long, ClassLoader> classLoaders,
ConcurrentHashMap<Long, Collection<URL>> jars) {
CompletableFuture<TaskExecutionState> resultFuture = new CompletableFuture<>();
try {
taskGroup.init();
Expand Down Expand Up @@ -389,7 +400,7 @@ public PassiveCompletableFuture<TaskExecutionState> deployLocalTask(
}));
executionContexts.put(
taskGroup.getTaskGroupLocation(),
new TaskGroupContext(taskGroup, classLoader, jars));
new TaskGroupContext(taskGroup, classLoaders, jars));
cancellationFutures.put(taskGroup.getTaskGroupLocation(), cancellationFuture);
submitThreadShareTask(executionTracker, byCooperation.get(true));
submitBlockingTask(executionTracker, byCooperation.get(false));
Expand Down Expand Up @@ -591,7 +602,7 @@ private void updateMetricsContextInImap() {
}
});
});
if (localMap.size() > 0) {
if (!localMap.isEmpty()) {
boolean lockedIMap = false;
try {
lockedIMap =
Expand Down Expand Up @@ -669,7 +680,8 @@ public void run() {
ClassLoader classLoader =
executionContexts
.get(taskGroupExecutionTracker.taskGroup.getTaskGroupLocation())
.getClassLoader();
.getClassLoaders()
.get(tracker.task.getTaskID());
ClassLoader oldClassLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(classLoader);
final Task t = tracker.task;
Expand Down Expand Up @@ -728,16 +740,16 @@ public final class CooperativeTaskWorker implements Runnable {
public AtomicReference<TaskTracker> exclusiveTaskTracker = new AtomicReference<>();
final TaskCallTimer timer;
private Thread myThread;
public LinkedBlockingDeque<TaskTracker> taskqueue;
public LinkedBlockingDeque<TaskTracker> taskQueue;
private Future<?> thisTaskFuture;
private BlockingQueue<Future<?>> futureBlockingQueue;

public CooperativeTaskWorker(
LinkedBlockingDeque<TaskTracker> taskqueue,
LinkedBlockingDeque<TaskTracker> taskQueue,
RunBusWorkSupplier runBusWorkSupplier,
BlockingQueue<Future<?>> futureBlockingQueue) {
logger.info(String.format("Created new BusWork : %s", this.hashCode()));
this.taskqueue = taskqueue;
this.taskQueue = taskQueue;
this.timer = new TaskCallTimer(50, keep, runBusWorkSupplier, this);
this.futureBlockingQueue = futureBlockingQueue;
}
Expand All @@ -752,7 +764,7 @@ public void run() {
TaskTracker taskTracker =
null != exclusiveTaskTracker.get()
? exclusiveTaskTracker.get()
: taskqueue.takeFirst();
: taskQueue.takeFirst();
TaskGroupExecutionTracker taskGroupExecutionTracker =
taskTracker.taskGroupExecutionTracker;
if (taskGroupExecutionTracker.executionCompletedExceptionally()) {
Expand All @@ -777,7 +789,8 @@ public void run() {
myThread.setContextClassLoader(
executionContexts
.get(taskGroupExecutionTracker.taskGroup.getTaskGroupLocation())
.getClassLoader());
.getClassLoaders()
.get(taskTracker.task.getTaskID()));
call = taskTracker.task.call();
synchronized (timer) {
timer.timerStop();
Expand Down Expand Up @@ -819,7 +832,7 @@ public void run() {
// Task is not completed. Put task to the end of the queue
// If the current work has an exclusive tracker, it will not be put back
if (null == exclusiveTaskTracker.get()) {
taskqueue.offer(taskTracker);
taskQueue.offer(taskTracker);
}
}
}
Expand All @@ -840,7 +853,7 @@ public RunBusWorkSupplier(
}

public boolean runNewBusWork(boolean checkTaskQueue) {
if (!checkTaskQueue || taskQueue.size() > 0) {
if (!checkTaskQueue || !taskQueue.isEmpty()) {
BlockingQueue<Future<?>> futureBlockingQueue = new LinkedBlockingQueue<>();
CooperativeTaskWorker cooperativeTaskWorker =
new CooperativeTaskWorker(taskQueue, this, futureBlockingQueue);
Expand All @@ -867,7 +880,7 @@ public final class TaskGroupExecutionTracker {

private final AtomicBoolean isCancel = new AtomicBoolean(false);

@Getter private Map<Long, Future<?>> currRunningTaskFuture = new ConcurrentHashMap<>();
private final Map<Long, Future<?>> currRunningTaskFuture = new ConcurrentHashMap<>();

TaskGroupExecutionTracker(
@NonNull CompletableFuture<Void> cancellationFuture,
Expand Down Expand Up @@ -972,8 +985,10 @@ void taskDone(Task task) {

private void recycleClassLoader(TaskGroupLocation taskGroupLocation) {
TaskGroupContext context = executionContexts.get(taskGroupLocation);
executionContexts.get(taskGroupLocation).setClassLoader(null);
classLoaderService.releaseClassLoader(taskGroupLocation.getJobId(), context.getJars());
executionContexts.get(taskGroupLocation).setClassLoaders(null);
for (Collection<URL> jars : context.getJars().values()) {
classLoaderService.releaseClassLoader(taskGroupLocation.getJobId(), jars);
}
}

boolean executionCompletedExceptionally() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ public void runInternal() throws Exception {
.getExecutionContext(taskLocation.getTaskGroupLocation());
Task task = groupContext.getTaskGroup().getTask(taskLocation.getTaskID());
ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(groupContext.getClassLoader());
Thread.currentThread()
.setContextClassLoader(
groupContext.getClassLoader(taskLocation.getTaskID()));

task.notifyCheckpointEnd(checkpointId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ public void runInternal() throws Exception {
.getExecutionContext(taskLocation.getTaskGroupLocation());
Task task = groupContext.getTaskGroup().getTask(taskLocation.getTaskID());
ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(groupContext.getClassLoader());
Thread.currentThread()
.setContextClassLoader(
groupContext.getClassLoader(taskLocation.getTaskID()));
if (successful) {
task.notifyCheckpointComplete(checkpointId);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ public void runInternal() throws Exception {
() -> {
Thread.currentThread()
.setContextClassLoader(
groupContext.getClassLoader());
groupContext.getClassLoader(
task.getTaskID()));
try {
log.debug(
"NotifyTaskRestoreOperation.restoreState "
Expand Down
Loading

0 comments on commit 3be0d1c

Please sign in to comment.