diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java index 3b5e38b559..f9111e0406 100644 --- a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java +++ b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java @@ -201,6 +201,10 @@ public class RssTezConfig { public static final String RSS_SHUFFLE_DESTINATION_VERTEX_ID = TEZ_RSS_CONFIG_PREFIX + "rss.shuffle.destination.vertex.id"; + public static final String RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK = + TEZ_RSS_CONFIG_PREFIX + "rss.avoid.recompute.succeeded.task"; + public static final boolean RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT = false; + public static RssConf toRssConf(Configuration jobConf) { RssConf rssConf = new RssConf(); for (Map.Entry entry : jobConf) { diff --git a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java index d996ea7c38..608f26d044 100644 --- a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java +++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java @@ -37,12 +37,14 @@ import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.event.EventHandler; import org.apache.hadoop.yarn.util.Clock; import org.apache.hadoop.yarn.util.ConverterUtils; import org.apache.hadoop.yarn.util.SystemClock; import org.apache.log4j.LogManager; import org.apache.log4j.helpers.Loader; import org.apache.log4j.helpers.OptionConverter; +import org.apache.tez.common.AsyncDispatcher; import org.apache.tez.common.RssTezConfig; import org.apache.tez.common.RssTezUtils; import org.apache.tez.common.TezClassLoader; @@ -56,10 +58,15 @@ import org.apache.tez.dag.api.OutputDescriptor; import org.apache.tez.dag.api.TezConstants; import org.apache.tez.dag.api.TezUncheckedException; +import org.apache.tez.dag.api.oldrecords.TaskAttemptState; import org.apache.tez.dag.api.records.DAGProtos; import org.apache.tez.dag.api.records.DAGProtos.AMPluginDescriptorProto; import org.apache.tez.dag.app.dag.DAG; import org.apache.tez.dag.app.dag.DAGState; +import org.apache.tez.dag.app.dag.Task; +import org.apache.tez.dag.app.dag.TaskAttempt; +import org.apache.tez.dag.app.dag.event.TaskAttemptEvent; +import org.apache.tez.dag.app.dag.event.TaskAttemptEventType; import org.apache.tez.dag.app.dag.impl.DAGImpl; import org.apache.tez.dag.app.dag.impl.Edge; import org.apache.tez.dag.library.vertexmanager.ShuffleVertexManager; @@ -76,8 +83,12 @@ import static org.apache.log4j.LogManager.DEFAULT_CONFIGURATION_KEY; import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS; import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT; +import static org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK; +import static org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT; import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID; import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID; +import static org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS; +import static org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT; public class RssDAGAppMaster extends DAGAppMaster { private static final Logger LOG = LoggerFactory.getLogger(RssDAGAppMaster.class); @@ -125,6 +136,10 @@ public RssDAGAppMaster( @Override public synchronized void serviceInit(Configuration conf) throws Exception { super.serviceInit(conf); + if (conf.getBoolean( + RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT)) { + overrideTaskAttemptEventDispatcher(); + } initAndStartRSSClient(this, conf); } @@ -336,6 +351,16 @@ public static void main(String[] args) { } } + if (conf.getBoolean( + RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT) + && conf.getBoolean( + TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, + TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT)) { + LOG.info( + "When rss.avoid.recompute.succeeded.task is enable, " + + "we can not rescheduler succeeded task on unhealthy node"); + conf.setBoolean(TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, false); + } initAndStartAppMaster(appMaster, conf); } catch (Throwable t) { LOG.error("Error starting RssDAGAppMaster", t); @@ -476,7 +501,7 @@ private static Object getPrivateField(Object object, String name) { } } - private static void reconfigureLog4j() { + static void reconfigureLog4j() { String configuratorClassName = OptionConverter.getSystemProperty(CONFIGURATOR_CLASS_KEY, null); String configurationOptionStr = OptionConverter.getSystemProperty(DEFAULT_CONFIGURATION_KEY, null); @@ -484,4 +509,42 @@ private static void reconfigureLog4j() { OptionConverter.selectAndConfigure( url, configuratorClassName, LogManager.getLoggerRepository()); } + + protected void overrideTaskAttemptEventDispatcher() + throws NoSuchFieldException, IllegalAccessException { + AsyncDispatcher dispatcher = (AsyncDispatcher) this.getDispatcher(); + Field field = dispatcher.getClass().getDeclaredField("eventHandlers"); + field.setAccessible(true); + Map, EventHandler> eventHandlers = + (Map, EventHandler>) field.get(dispatcher); + eventHandlers.put(TaskAttemptEventType.class, new RssTaskAttemptEventDispatcher()); + } + + private class RssTaskAttemptEventDispatcher implements EventHandler { + @SuppressWarnings("unchecked") + @Override + public void handle(TaskAttemptEvent event) { + DAG dag = getContext().getCurrentDAG(); + int eventDagIndex = event.getTaskAttemptID().getTaskID().getVertexID().getDAGId().getId(); + if (dag == null || eventDagIndex != dag.getID().getId()) { + return; // event not relevant any more + } + Task task = + dag.getVertex(event.getTaskAttemptID().getTaskID().getVertexID()) + .getTask(event.getTaskAttemptID().getTaskID()); + TaskAttempt attempt = task.getAttempt(event.getTaskAttemptID()); + + if (attempt.getState() == TaskAttemptState.SUCCEEDED + && event.getType() == TaskAttemptEventType.TA_NODE_FAILED) { + // Here we only handle TA_NODE_FAILED. TA_KILL_REQUEST and TA_KILLED also could trigger + // TerminatedAfterSuccessTransition, but the reason is not about bad node. + LOG.info( + "We should not recompute the succeeded task attempt, though task attempt {} recieved envent {}", + attempt, + event); + return; + } + ((EventHandler) attempt).handle(event); + } + } } diff --git a/integration-test/tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterForWordCountWithFailures.java b/integration-test/tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterForWordCountWithFailures.java new file mode 100644 index 0000000000..9637a2f8bf --- /dev/null +++ b/integration-test/tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterForWordCountWithFailures.java @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.dag.app; + +import java.lang.reflect.Field; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.util.ShutdownHookManager; +import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.NodeId; +import org.apache.hadoop.yarn.api.records.NodeReport; +import org.apache.hadoop.yarn.api.records.NodeState; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.event.EventHandler; +import org.apache.hadoop.yarn.util.Clock; +import org.apache.hadoop.yarn.util.ConverterUtils; +import org.apache.hadoop.yarn.util.SystemClock; +import org.apache.tez.common.AsyncDispatcher; +import org.apache.tez.common.TezClassLoader; +import org.apache.tez.common.TezCommonUtils; +import org.apache.tez.common.TezUtilsInternal; +import org.apache.tez.common.VersionInfo; +import org.apache.tez.dag.api.TezConstants; +import org.apache.tez.dag.api.oldrecords.TaskAttemptState; +import org.apache.tez.dag.api.records.DAGProtos; +import org.apache.tez.dag.app.dag.DAG; +import org.apache.tez.dag.app.dag.Task; +import org.apache.tez.dag.app.dag.TaskAttempt; +import org.apache.tez.dag.app.dag.event.TaskAttemptEvent; +import org.apache.tez.dag.app.dag.event.TaskAttemptEventAttemptFailed; +import org.apache.tez.dag.app.dag.event.TaskAttemptEventType; +import org.apache.tez.dag.app.rm.node.AMNodeEventStateChanged; +import org.apache.tez.dag.records.TaskAttemptTerminationCause; +import org.apache.tez.runtime.api.TaskFailureType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.common.exception.RssException; + +import static org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK; +import static org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT; +import static org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS; +import static org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/* + * RssDAGAppMasterForWordCountWithFailures is only used for TezWordCountWithFailuresTest. + * We want to simulate that some task have succeeded, but the node which these task have run is label as black list. + * Then we will verify whether these task is recompute or not. + * + * Two test mode are supported: + * (a) testMode 0 + * The test example is WordCount. The parallelism of Tokenizer is 5, it means at lease one node run more than one + * container. Here if a task succeeded in node1, will kill the next container runs on node1. Because + * maxtaskfailures.per.node is set to 1, so the node1 will labeled as black list, then verify whether the succeeded + * task is recomputed or not. + * + * (b) testMode 1 + * The test example is WordCount. The parallelism of Tokenizer is 5, it means at lease one node run more than one + * container. Here if a task succeeded in node1, then will label node1 as DECOMMISSIONED. Then verify whether + * the succeeded task is recomputed or not. + * */ +public class RssDAGAppMasterForWordCountWithFailures extends RssDAGAppMaster { + + private static final Logger LOG = + LoggerFactory.getLogger(RssDAGAppMasterForWordCountWithFailures.class); + + private final int testMode; + + public RssDAGAppMasterForWordCountWithFailures( + ApplicationAttemptId applicationAttemptId, + ContainerId containerId, + String nmHost, + int nmPort, + int nmHttpPort, + Clock clock, + long appSubmitTime, + boolean isSession, + String workingDirectory, + String[] localDirs, + String[] logDirs, + String clientVersion, + Credentials credentials, + String jobUserName, + DAGProtos.AMPluginDescriptorProto pluginDescriptorProto, + int testMode) { + super( + applicationAttemptId, + containerId, + nmHost, + nmPort, + nmHttpPort, + clock, + appSubmitTime, + isSession, + workingDirectory, + localDirs, + logDirs, + clientVersion, + credentials, + jobUserName, + pluginDescriptorProto); + this.testMode = testMode; + } + + @Override + public synchronized void serviceInit(Configuration conf) throws Exception { + super.serviceInit(conf); + overrideTaskAttemptEventDispatcher(); + } + + public static void main(String[] args) { + int testMode = 0; + try { + // We use trick way to introduce RssDAGAppMaster by the config tez.am.launch.cmd-opts. + // It means some property which is set by command line will be ingored, so we must reload it. + boolean sessionModeCliOption = false; + for (int i = 0; i < args.length; i++) { + if (args[i].startsWith("-D")) { + String[] property = args[i].split("="); + if (property.length < 2) { + System.setProperty(property[0].substring(2), ""); + } else { + System.setProperty(property[0].substring(2), property[1]); + } + } else if (args[i].contains("--session") || args[i].contains("-s")) { + sessionModeCliOption = true; + } else if (args[i].startsWith("--testMode")) { + testMode = Integer.parseInt(args[i].substring(10)); + } + } + // Load the log4j config is only init in static code block of LogManager, so we must + // reconfigure. + reconfigureLog4j(); + + // Install the tez class loader, which can be used add new resources + TezClassLoader.setupTezClassLoader(); + Thread.setDefaultUncaughtExceptionHandler(new YarnUncaughtExceptionHandler()); + final String pid = System.getenv().get("JVM_PID"); + String containerIdStr = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()); + String appSubmitTimeStr = System.getenv(ApplicationConstants.APP_SUBMIT_TIME_ENV); + String clientVersion = System.getenv(TezConstants.TEZ_CLIENT_VERSION_ENV); + if (clientVersion == null) { + clientVersion = VersionInfo.UNKNOWN; + } + + Objects.requireNonNull( + appSubmitTimeStr, ApplicationConstants.APP_SUBMIT_TIME_ENV + " is null"); + + ContainerId containerId = ConverterUtils.toContainerId(containerIdStr); + ApplicationAttemptId applicationAttemptId = containerId.getApplicationAttemptId(); + + String jobUserName = System.getenv(ApplicationConstants.Environment.USER.name()); + + LOG.info( + "Creating RssDAGAppMaster for " + + "applicationId=" + + applicationAttemptId.getApplicationId() + + ", attemptNum=" + + applicationAttemptId.getAttemptId() + + ", AMContainerId=" + + containerId + + ", jvmPid=" + + pid + + ", userFromEnv=" + + jobUserName + + ", cliSessionOption=" + + sessionModeCliOption + + ", pwd=" + + System.getenv(ApplicationConstants.Environment.PWD.name()) + + ", localDirs=" + + System.getenv(ApplicationConstants.Environment.LOCAL_DIRS.name()) + + ", logDirs=" + + System.getenv(ApplicationConstants.Environment.LOG_DIRS.name())); + + Configuration conf = new Configuration(new YarnConfiguration()); + + DAGProtos.ConfigurationProto confProto = + TezUtilsInternal.readUserSpecifiedTezConfiguration( + System.getenv(ApplicationConstants.Environment.PWD.name())); + TezUtilsInternal.addUserSpecifiedTezConfiguration(conf, confProto.getConfKeyValuesList()); + + DAGProtos.AMPluginDescriptorProto amPluginDescriptorProto = null; + if (confProto.hasAmPluginDescriptor()) { + amPluginDescriptorProto = confProto.getAmPluginDescriptor(); + } + + UserGroupInformation.setConfiguration(conf); + Credentials credentials = UserGroupInformation.getCurrentUser().getCredentials(); + + TezUtilsInternal.setSecurityUtilConfigration(LOG, conf); + + String nodeHostString = System.getenv(ApplicationConstants.Environment.NM_HOST.name()); + String nodePortString = System.getenv(ApplicationConstants.Environment.NM_PORT.name()); + String nodeHttpPortString = + System.getenv(ApplicationConstants.Environment.NM_HTTP_PORT.name()); + long appSubmitTime = Long.parseLong(appSubmitTimeStr); + RssDAGAppMasterForWordCountWithFailures appMaster = + new RssDAGAppMasterForWordCountWithFailures( + applicationAttemptId, + containerId, + nodeHostString, + Integer.parseInt(nodePortString), + Integer.parseInt(nodeHttpPortString), + new SystemClock(), + appSubmitTime, + sessionModeCliOption, + System.getenv(ApplicationConstants.Environment.PWD.name()), + TezCommonUtils.getTrimmedStrings( + System.getenv(ApplicationConstants.Environment.LOCAL_DIRS.name())), + TezCommonUtils.getTrimmedStrings( + System.getenv(ApplicationConstants.Environment.LOG_DIRS.name())), + clientVersion, + credentials, + jobUserName, + amPluginDescriptorProto, + testMode); + ShutdownHookManager.get() + .addShutdownHook( + new RssDAGAppMaster.RssDAGAppMasterShutdownHook(appMaster), SHUTDOWN_HOOK_PRIORITY); + + // log the system properties + if (LOG.isInfoEnabled()) { + String systemPropsToLog = TezCommonUtils.getSystemPropertiesToLog(conf); + if (systemPropsToLog != null) { + LOG.info(systemPropsToLog); + } + } + + LOG.info( + "recompute is {}, reschedule is {}", + conf.getBoolean( + RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT), + conf.getBoolean( + TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, + TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT)); + if (conf.getBoolean( + RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT) + && conf.getBoolean( + TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, + TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT)) { + LOG.info( + "When rss.avoid.recompute.succeeded.task is enable, " + + "we can not rescheduler succeeded task on unhealthy node"); + conf.setBoolean(TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, false); + } + initAndStartAppMaster(appMaster, conf); + } catch (Throwable t) { + LOG.error("Error starting RssDAGAppMaster", t); + System.exit(1); + } + } + + public void overrideTaskAttemptEventDispatcher() + throws NoSuchFieldException, IllegalAccessException { + AsyncDispatcher dispatcher = (AsyncDispatcher) this.getDispatcher(); + Field field = dispatcher.getClass().getDeclaredField("eventHandlers"); + field.setAccessible(true); + Map, EventHandler> eventHandlers = + (Map, EventHandler>) field.get(dispatcher); + eventHandlers.put( + TaskAttemptEventType.class, new RssTaskAttemptEventDispatcher(this.getConfig())); + } + + private class RssTaskAttemptEventDispatcher implements EventHandler { + + Map succeed = new HashMap<>(); + boolean killed = false; + boolean avoidRecompute; + + RssTaskAttemptEventDispatcher(Configuration conf) { + avoidRecompute = + conf.getBoolean( + RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT); + } + + @SuppressWarnings("unchecked") + @Override + public void handle(TaskAttemptEvent event) { + DAG dag = getContext().getCurrentDAG(); + int eventDagIndex = event.getTaskAttemptID().getTaskID().getVertexID().getDAGId().getId(); + if (dag == null || eventDagIndex != dag.getID().getId()) { + return; // event not relevant any more + } + Task task = + dag.getVertex(event.getTaskAttemptID().getTaskID().getVertexID()) + .getTask(event.getTaskAttemptID().getTaskID()); + TaskAttempt attempt = task.getAttempt(event.getTaskAttemptID()); + + LOG.info("handle task attempt event: {}", event); + if (avoidRecompute) { + if (attempt.getState() == TaskAttemptState.SUCCEEDED + && event.getType() == TaskAttemptEventType.TA_NODE_FAILED) { + LOG.info( + "We should not recompute the succeeded task attempt, though taskattempt {} recieved event {}", + attempt, + event); + return; + } + } + ((EventHandler) attempt).handle(event); + // For Tokenizer, record the first succeeded task and its node. When next task runs on this + // node, will kill this task or label this node as unhealthy. + int vertexId = attempt.getVertexID().getId(); + if (vertexId == 0) { + if (attempt.getState() == TaskAttemptState.SUCCEEDED) { + NodeId nodeId = attempt.getAssignedContainer().getNodeId(); + if (!succeed.containsKey(nodeId)) { + succeed.put(nodeId, 1); + } else { + succeed.put(nodeId, succeed.get(nodeId) + 1); + } + } else if (attempt.getState() == TaskAttemptState.RUNNING) { + NodeId nodeId = attempt.getAssignedContainer().getNodeId(); + if (succeed.getOrDefault(nodeId, 0) == 1 && !killed) { + if (testMode == 0) { + TaskAttemptEventAttemptFailed eventAttemptFailed = + new TaskAttemptEventAttemptFailed( + attempt.getID(), + TaskAttemptEventType.TA_FAILED, + TaskFailureType.NON_FATAL, + "Triggerd by " + this.getClass().getName(), + TaskAttemptTerminationCause.CONTAINER_LAUNCH_FAILED); + LOG.info( + "Killing running task attempt: {} at node: {}", + attempt, + attempt.getAssignedContainer().getNodeId()); + ((EventHandler) attempt).handle(eventAttemptFailed); + dag.getEventHandler().handle(eventAttemptFailed); + } else if (testMode == 1) { + NodeReport nodeReport = mock(NodeReport.class); + when(nodeReport.getNodeState()).thenReturn(NodeState.DECOMMISSIONED); + when(nodeReport.getNodeId()).thenReturn(nodeId); + LOG.info( + "Label the node {} as DECOMMISSIONED." + + attempt.getAssignedContainer().getNodeId()); + dag.getEventHandler().handle(new AMNodeEventStateChanged(nodeReport, 0)); + } else { + throw new RssException("testMode " + testMode + " is not supported!"); + } + killed = true; + } + } + } + } + } +} diff --git a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java index f9584efd56..b5219efebf 100644 --- a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java +++ b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java @@ -61,7 +61,7 @@ public class TezIntegrationTestBase extends IntegrationTestBase { private static final Logger LOG = LoggerFactory.getLogger(TezIntegrationTestBase.class); private static String TEST_ROOT_DIR = - "target" + Path.SEPARATOR + TezWordCountTest.class.getName() + "-tmpDir"; + "target" + Path.SEPARATOR + TezIntegrationTestBase.class.getName() + "-tmpDir"; private Path remoteStagingDir = null; protected static MiniTezCluster miniTezCluster; diff --git a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java new file mode 100644 index 0000000000..14957df4b1 --- /dev/null +++ b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.test; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.ToolRunner; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.tez.client.CallerContext; +import org.apache.tez.client.TezClient; +import org.apache.tez.client.TezClientUtils; +import org.apache.tez.common.RssTezConfig; +import org.apache.tez.common.TezUtilsInternal; +import org.apache.tez.dag.api.DAG; +import org.apache.tez.dag.api.TezConfiguration; +import org.apache.tez.dag.api.TezException; +import org.apache.tez.dag.api.client.DAGClient; +import org.apache.tez.dag.api.client.DAGStatus; +import org.apache.tez.dag.api.client.Progress; +import org.apache.tez.dag.api.client.StatusGetOpts; +import org.apache.tez.dag.app.RssDAGAppMasterForWordCountWithFailures; +import org.apache.tez.examples.WordCount; +import org.apache.tez.hadoop.shim.HadoopShim; +import org.apache.tez.hadoop.shim.HadoopShimsLoader; +import org.apache.tez.test.MiniTezCluster; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.common.ClientType; +import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.server.ShuffleServerConf; +import org.apache.uniffle.storage.util.StorageType; + +import static org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK; +import static org.apache.tez.dag.api.TezConfiguration.TEZ_AM_MAX_TASK_FAILURES_PER_NODE; +import static org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_BLACKLISTING_ENABLED; +import static org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_BLACKLISTING_IGNORE_THRESHOLD; +import static org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class TezWordCountWithFailuresTest extends IntegrationTestBase { + + private static final Logger LOG = LoggerFactory.getLogger(TezIntegrationTestBase.class); + private static String TEST_ROOT_DIR = + "target" + Path.SEPARATOR + TezWordCountWithFailuresTest.class.getName() + "-tmpDir"; + + private Path remoteStagingDir = null; + private String inputPath = "word_count_input"; + private String outputPath = "word_count_output"; + private List wordTable = + Lists.newArrayList( + "apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan", "tomato"); + + protected static MiniTezCluster miniTezCluster; + + @BeforeAll + public static void beforeClass() throws Exception { + LOG.info("Starting mini tez clusters"); + if (miniTezCluster == null) { + miniTezCluster = new MiniTezCluster(TezIntegrationTestBase.class.getName(), 3, 1, 1); + miniTezCluster.init(conf); + miniTezCluster.start(); + } + LOG.info("Starting corrdinators and shuffer servers"); + CoordinatorConf coordinatorConf = getCoordinatorConf(); + Map dynamicConf = new HashMap(); + dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test"); + dynamicConf.put(RssTezConfig.RSS_STORAGE_TYPE, StorageType.MEMORY_LOCALFILE_HDFS.name()); + addDynamicConf(coordinatorConf, dynamicConf); + createCoordinatorServer(coordinatorConf); + ShuffleServerConf shuffleServerConf = getShuffleServerConf(); + createShuffleServer(shuffleServerConf); + startServers(); + } + + @AfterAll + public static void tearDown() throws Exception { + if (miniTezCluster != null) { + LOG.info("Stopping MiniTezCluster"); + miniTezCluster.stop(); + miniTezCluster = null; + } + } + + @BeforeEach + public void setup() throws Exception { + remoteStagingDir = + fs.makeQualified(new Path(TEST_ROOT_DIR, String.valueOf(new Random().nextInt(100000)))); + TezClientUtils.ensureStagingDirExists(conf, remoteStagingDir); + generateInputFile(); + } + + private void generateInputFile() throws Exception { + assertTrue(fs.mkdirs(new Path(inputPath))); + for (int j = 0; j < 5; j++) { + FSDataOutputStream outputStream = fs.create(new Path(inputPath + "/file." + j)); + Random random = new Random(); + for (int i = 0; i < 100; i++) { + int index = random.nextInt(wordTable.size()); + String str = wordTable.get(index) + "\n"; + outputStream.writeBytes(str); + } + outputStream.close(); + } + FileStatus[] fileStatus = fs.listStatus(new Path(inputPath)); + for (FileStatus status : fileStatus) { + System.out.println("status is " + status); + } + } + + @AfterEach + public void tearDownEach() throws Exception { + if (this.remoteStagingDir != null) { + fs.delete(this.remoteStagingDir, true); + } + for (int j = 0; j < 5; j++) { + fs.delete(new Path(inputPath + "/file." + j), true); + } + } + + @Test + public void wordCountTestWithTaskFailureWhenAvoidRecomputeEnable() throws Exception { + // 1 Run Tez examples based on rss + TezConfiguration appConf = new TezConfiguration(miniTezCluster.getConfig()); + updateRssConfiguration(appConf, 0, true, false, 1); + TezIntegrationTestBase.appendAndUploadRssJars(appConf); + runTezApp(appConf, getTestArgs("rss"), 0); + final String rssPath = getOutputDir("rss"); + + // 2 Run original Tez examples + appConf = new TezConfiguration(miniTezCluster.getConfig()); + updateCommonConfiguration(appConf); + runTezApp(appConf, getTestArgs("origin"), -1); + final String originPath = getOutputDir("origin"); + + // 3 verify the results + TezIntegrationTestBase.verifyResultEqual(originPath, rssPath); + } + + @Test + public void wordCountTestWithTaskFailureWhenAvoidRecomputeDisable() throws Exception { + // 1 Run Tez examples based on rss + TezConfiguration appConf = new TezConfiguration(miniTezCluster.getConfig()); + updateRssConfiguration(appConf, 0, false, false, 1); + TezIntegrationTestBase.appendAndUploadRssJars(appConf); + runTezApp(appConf, getTestArgs("rss"), 1); + final String rssPath = getOutputDir("rss"); + + // 2 Run original Tez examples + appConf = new TezConfiguration(miniTezCluster.getConfig()); + updateCommonConfiguration(appConf); + runTezApp(appConf, getTestArgs("origin"), -1); + final String originPath = getOutputDir("origin"); + + // 3 verify the results + TezIntegrationTestBase.verifyResultEqual(originPath, rssPath); + } + + @Test + public void wordCountTestWithNodeUnhealthyWhenAvoidRecomputeEnable() throws Exception { + // 1 Run Tez examples based on rss + TezConfiguration appConf = new TezConfiguration(miniTezCluster.getConfig()); + updateRssConfiguration(appConf, 1, true, true, 100); + TezIntegrationTestBase.appendAndUploadRssJars(appConf); + runTezApp(appConf, getTestArgs("rss"), 0); + final String rssPath = getOutputDir("rss"); + + // 2 Run original Tez examples + appConf = new TezConfiguration(miniTezCluster.getConfig()); + updateCommonConfiguration(appConf); + runTezApp(appConf, getTestArgs("origin"), -1); + final String originPath = getOutputDir("origin"); + + // 3 verify the results + TezIntegrationTestBase.verifyResultEqual(originPath, rssPath); + } + + @Test + public void wordCountTestWithNodeUnhealthyWhenAvoidRecomputeDisable() throws Exception { + // 1 Run Tez examples based on rss + TezConfiguration appConf = new TezConfiguration(miniTezCluster.getConfig()); + updateRssConfiguration(appConf, 1, false, true, 100); + TezIntegrationTestBase.appendAndUploadRssJars(appConf); + runTezApp(appConf, getTestArgs("rss"), 1); + final String rssPath = getOutputDir("rss"); + + // 2 Run original Tez examples + appConf = new TezConfiguration(miniTezCluster.getConfig()); + updateCommonConfiguration(appConf); + runTezApp(appConf, getTestArgs("origin"), -1); + final String originPath = getOutputDir("origin"); + + // 3 verify the results + TezIntegrationTestBase.verifyResultEqual(originPath, rssPath); + } + + /* + * Two verify mode are supported: + * (a) verifyMode 0 + * tez.rss.avoid.recompute.succeeded.task is enable, should not recompute the task when this node is + * blacke-listed for unhealthy. + * + * (b) verifyMode 1 + * tez.rss.avoid.recompute.succeeded.task is disable, will recompute the task when this node is + * blacke-listed for unhealthy. + * */ + protected void runTezApp(TezConfiguration tezConf, String[] args, int verifyMode) + throws Exception { + assertEquals( + 0, + ToolRunner.run(tezConf, new WordCountWithFailures(verifyMode), args), + "WordCountWithFailures failed"); + } + + public String[] getTestArgs(String uniqueOutputName) { + return new String[] { + "-disableSplitGrouping", inputPath, outputPath + "/" + uniqueOutputName, "2" + }; + } + + public String getOutputDir(String uniqueOutputName) { + return outputPath + "/" + uniqueOutputName; + } + + /* + * In this integration test, mini cluster have three NM with 4G + * (YarnConfiguration.DEFAULT_YARN_MINICLUSTER_NM_PMEM_MB). The request of am is 4G, the request of task is 2G. + * It means that one node only runs one am container so that won't lable the node which am container runs as + * black-list or uhealthy node. + * */ + public void updateRssConfiguration( + Configuration appConf, + int testMode, + boolean avoidRecompute, + boolean rescheduleWhenUnhealthy, + int maxFailures) + throws Exception { + appConf.set(TezConfiguration.TEZ_AM_STAGING_DIR, remoteStagingDir.toString()); + appConf.setInt(TezConfiguration.TEZ_AM_RESOURCE_MEMORY_MB, 4096); + appConf.setInt(TezConfiguration.TEZ_TASK_RESOURCE_MEMORY_MB, 4096); + appConf.setBoolean(TEZ_AM_NODE_BLACKLISTING_ENABLED, true); + appConf.setInt(TEZ_AM_NODE_BLACKLISTING_IGNORE_THRESHOLD, 99); + appConf.setInt(TEZ_AM_MAX_TASK_FAILURES_PER_NODE, maxFailures); + appConf.set(RssTezConfig.RSS_COORDINATOR_QUORUM, COORDINATOR_QUORUM); + appConf.set(RssTezConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name()); + appConf.set( + TezConfiguration.TEZ_AM_LAUNCH_CMD_OPTS, + TezConfiguration.TEZ_AM_LAUNCH_CMD_OPTS_DEFAULT + + " " + + RssDAGAppMasterForWordCountWithFailures.class.getName() + + " --testMode" + + testMode); + appConf.setBoolean(RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, avoidRecompute); + appConf.setBoolean(TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, rescheduleWhenUnhealthy); + } + + public void updateCommonConfiguration(Configuration appConf) { + appConf.set(TezConfiguration.TEZ_AM_STAGING_DIR, remoteStagingDir.toString()); + appConf.setInt(TezConfiguration.TEZ_AM_RESOURCE_MEMORY_MB, 512); + appConf.set(TezConfiguration.TEZ_AM_LAUNCH_CMD_OPTS, " -Xmx384m"); + appConf.setInt(TezConfiguration.TEZ_TASK_RESOURCE_MEMORY_MB, 512); + appConf.set(TezConfiguration.TEZ_TASK_LAUNCH_CMD_OPTS, " -Xmx384m"); + } + + public class WordCountWithFailures extends WordCount { + + TezClient tezClientInternal = null; + private HadoopShim hadoopShim; + int verifyMode = -1; + + WordCountWithFailures(int assertMode) { + this.verifyMode = assertMode; + } + + @Override + protected int runJob(String[] args, TezConfiguration tezConf, TezClient tezClient) + throws Exception { + this.tezClientInternal = tezClient; + Method method = + WordCount.class.getDeclaredMethod( + "createDAG", TezConfiguration.class, String.class, String.class, int.class); + method.setAccessible(true); + DAG dag = + (DAG) + method.invoke( + this, + tezConf, + args[0], + args[1], + args.length == 3 ? Integer.parseInt(args[2]) : 1); + LOG.info("Running WordCountWithFailures"); + return runDag(dag, isCountersLog(), LOG); + } + + public int runDag(DAG dag, boolean printCounters, Logger logger) + throws TezException, InterruptedException, IOException { + tezClientInternal.waitTillReady(); + + CallerContext callerContext = + CallerContext.create("TezExamples", "Tez Example DAG: " + dag.getName()); + ApplicationId appId = tezClientInternal.getAppMasterApplicationId(); + if (hadoopShim == null) { + Configuration conf = (getConf() == null ? new Configuration(false) : getConf()); + hadoopShim = new HadoopShimsLoader(conf).getHadoopShim(); + } + + if (appId != null) { + TezUtilsInternal.setHadoopCallerContext(hadoopShim, appId); + callerContext.setCallerIdAndType(appId.toString(), "TezExampleApplication"); + } + dag.setCallerContext(callerContext); + + DAGClient dagClient = tezClientInternal.submitDAG(dag); + Set getOpts = Sets.newHashSet(); + if (printCounters) { + getOpts.add(StatusGetOpts.GET_COUNTERS); + } + + DAGStatus dagStatus = dagClient.waitForCompletionWithStatusUpdates(getOpts); + if (dagStatus.getState() != DAGStatus.State.SUCCEEDED) { + logger.info("DAG diagnostics: " + dagStatus.getDiagnostics()); + return -1; + } + + Map progressMap = dagStatus.getVertexProgress(); + if (verifyMode == 0) { + // verifyMode is 0: avoid recompute succeeded task is true + Assertions.assertEquals(0, progressMap.get("Tokenizer").getKilledTaskAttemptCount()); + } else if (verifyMode == 1) { + // verifyMode is 1: avoid recompute succeeded task is true + Assertions.assertTrue(progressMap.get("Tokenizer").getKilledTaskAttemptCount() > 0); + } + return 0; + } + } +}