From 8b6c234cbd9d155d7bff0bbc7a386b5a40d2bb63 Mon Sep 17 00:00:00 2001 From: zhengchenyu Date: Mon, 21 Oct 2024 14:36:58 +0800 Subject: [PATCH] merge master (#31) --- .../v1alpha1/remoteshuffleservice_types.go | 8 ++++ .../uniffle/v1alpha1/zz_generated.deepcopy.go | 26 ++++++++++ ...ffle.apache.org_remoteshuffleservices.yaml | 16 +++++++ .../sync/coordinator/coordinator.go | 27 +++++++++-- .../sync/coordinator/coordinator_test.go | 45 ++++++++++++++++++ .../apache/uniffle/test/DynamicConfTest.java | 15 +++--- .../apache/uniffle/test/HadoopConfTest.java | 14 +++--- .../apache/uniffle/test/LargeSorterTest.java | 14 +++--- .../uniffle/test/MRIntegrationTestBase.java | 47 ++++++++++++------- .../apache/uniffle/test/RMWordCountTest.java | 11 +++-- .../uniffle/test/SecondarySortTest.java | 12 +++-- .../apache/uniffle/test/WordCountTest.java | 12 +++-- 12 files changed, 194 insertions(+), 53 deletions(-) diff --git a/deploy/kubernetes/operator/api/uniffle/v1alpha1/remoteshuffleservice_types.go b/deploy/kubernetes/operator/api/uniffle/v1alpha1/remoteshuffleservice_types.go index 2903f09329..8ec293e964 100644 --- a/deploy/kubernetes/operator/api/uniffle/v1alpha1/remoteshuffleservice_types.go +++ b/deploy/kubernetes/operator/api/uniffle/v1alpha1/remoteshuffleservice_types.go @@ -106,6 +106,14 @@ type CoordinatorConfig struct { // HTTPNodePort defines http port of node port service used for coordinators' external access. // +optional HTTPNodePort []int32 `json:"httpNodePort,omitempty"` + + // NodePortServiceAnnotations is a list of annotations for the NodePort service. + // +optional + NodePortServiceAnnotations []map[string]string `json:"nodePortServiceAnnotations,omitempty"` + + // HeadlessServiceAnnotations is a list of annotations for the headless service. + // +optional + HeadlessServiceAnnotations []map[string]string `json:"headlessServiceAnnotations,omitempty"` } // ShuffleServerConfig records configuration used to generate workload of shuffle servers. diff --git a/deploy/kubernetes/operator/api/uniffle/v1alpha1/zz_generated.deepcopy.go b/deploy/kubernetes/operator/api/uniffle/v1alpha1/zz_generated.deepcopy.go index c9e15d7a5e..da70877ba3 100644 --- a/deploy/kubernetes/operator/api/uniffle/v1alpha1/zz_generated.deepcopy.go +++ b/deploy/kubernetes/operator/api/uniffle/v1alpha1/zz_generated.deepcopy.go @@ -99,6 +99,32 @@ func (in *CoordinatorConfig) DeepCopyInto(out *CoordinatorConfig) { *out = make([]int32, len(*in)) copy(*out, *in) } + if in.NodePortServiceAnnotations != nil { + in, out := &in.NodePortServiceAnnotations, &out.NodePortServiceAnnotations + *out = make([]map[string]string, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } + } + } + if in.HeadlessServiceAnnotations != nil { + in, out := &in.HeadlessServiceAnnotations, &out.HeadlessServiceAnnotations + *out = make([]map[string]string, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new CoordinatorConfig. diff --git a/deploy/kubernetes/operator/config/crd/bases/uniffle.apache.org_remoteshuffleservices.yaml b/deploy/kubernetes/operator/config/crd/bases/uniffle.apache.org_remoteshuffleservices.yaml index 46e3e23b83..e36e46b406 100644 --- a/deploy/kubernetes/operator/config/crd/bases/uniffle.apache.org_remoteshuffleservices.yaml +++ b/deploy/kubernetes/operator/config/crd/bases/uniffle.apache.org_remoteshuffleservices.yaml @@ -1786,6 +1786,14 @@ spec: description: ExcludeNodesFilePath indicates exclude nodes file path in coordinators' containers. type: string + headlessServiceAnnotations: + description: HeadlessServiceAnnotations is a list of annotations + for the headless service. + items: + additionalProperties: + type: string + type: object + type: array hostNetwork: default: true description: HostNetwork indicates whether we need to enable host @@ -1827,6 +1835,14 @@ spec: description: LogHostPath represents host path used to save logs of shuffle servers. type: string + nodePortServiceAnnotations: + description: NodePortServiceAnnotations is a list of annotations + for the NodePort service. + items: + additionalProperties: + type: string + type: object + type: array nodeSelector: additionalProperties: type: string diff --git a/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator.go b/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator.go index 84b6c84a68..b134d6ca44 100644 --- a/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator.go +++ b/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator.go @@ -105,10 +105,19 @@ func GenerateHeadlessSvc(rss *unifflev1alpha1.RemoteShuffleService, index int) * name := GenerateNameByIndex(rss, index) serviceName := appendHeadless(name) + annotations := map[string]string{} + + if len(rss.Spec.Coordinator.HeadlessServiceAnnotations) > index { + for key, value := range rss.Spec.Coordinator.HeadlessServiceAnnotations[index] { + annotations[key] = value + } + } + svc := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ - Name: serviceName, - Namespace: rss.Namespace, + Name: serviceName, + Namespace: rss.Namespace, + Annotations: annotations, }, Spec: corev1.ServiceSpec{ ClusterIP: corev1.ClusterIPNone, @@ -140,10 +149,20 @@ func GenerateHeadlessSvc(rss *unifflev1alpha1.RemoteShuffleService, index int) * // this function is skipped. func GenerateSvc(rss *unifflev1alpha1.RemoteShuffleService, index int) *corev1.Service { name := GenerateNameByIndex(rss, index) + + annotations := map[string]string{} + + if len(rss.Spec.Coordinator.NodePortServiceAnnotations) > index { + for key, value := range rss.Spec.Coordinator.NodePortServiceAnnotations[index] { + annotations[key] = value + } + } + svc := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ - Name: name, - Namespace: rss.Namespace, + Name: name, + Namespace: rss.Namespace, + Annotations: annotations, }, Spec: corev1.ServiceSpec{ Type: corev1.ServiceTypeNodePort, diff --git a/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator_test.go b/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator_test.go index 22caf5fc41..ebfdb3cf53 100644 --- a/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator_test.go +++ b/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator_test.go @@ -138,8 +138,24 @@ var ( "key1": "value1", "key2": "value2", } + + testSvcAnnotationsList = []map[string]string{ + { + "annotation1": "value1", + }, + { + "annotation2": "value2", + }, + } ) +func buildRssWithSvcAnnotations() *uniffleapi.RemoteShuffleService { + rss := utils.BuildRSSWithDefaultValue() + rss.Spec.Coordinator.NodePortServiceAnnotations = testSvcAnnotationsList + rss.Spec.Coordinator.HeadlessServiceAnnotations = testSvcAnnotationsList + return rss +} + func buildRssWithLabels() *uniffleapi.RemoteShuffleService { rss := utils.BuildRSSWithDefaultValue() rss.Spec.Coordinator.Labels = testLabels @@ -546,6 +562,35 @@ func TestGenerateSvcForCoordinator(t *testing.T) { } } +func TestGenerateSvcWithAnnotationsForCoordinator(t *testing.T) { + for _, tt := range []struct { + name string + rss *uniffleapi.RemoteShuffleService + expectedAnnotations []map[string]string + }{ + { + name: "nodeport and headless services with annotations", + rss: buildRssWithSvcAnnotations(), + expectedAnnotations: []map[string]string{ + {"annotation1": "value1"}, + {"annotation1": "value1"}, + {"annotation2": "value2"}, + {"annotation2": "value2"}}, + }, + } { + t.Run(tt.name, func(tc *testing.T) { + _, _, services, _ := GenerateCoordinators(tt.rss) + + for i, svc := range services { + match := reflect.DeepEqual(tt.expectedAnnotations[i], svc.Annotations) + if !match { + tc.Errorf("unexpected annotations: %v, expected: %v", svc.Annotations, tt.expectedAnnotations[i]) + } + } + }) + } +} + func TestGenerateAddresses(t *testing.T) { assertion := assert.New(t) rss := buildRssWithLabels() diff --git a/integration-test/mr/src/test/java/org/apache/uniffle/test/DynamicConfTest.java b/integration-test/mr/src/test/java/org/apache/uniffle/test/DynamicConfTest.java index c5421d113f..aad036794e 100644 --- a/integration-test/mr/src/test/java/org/apache/uniffle/test/DynamicConfTest.java +++ b/integration-test/mr/src/test/java/org/apache/uniffle/test/DynamicConfTest.java @@ -25,7 +25,8 @@ import org.apache.hadoop.mapreduce.RssMRConfig; import org.apache.hadoop.util.Tool; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.apache.uniffle.common.ClientType; import org.apache.uniffle.storage.util.StorageType; @@ -41,18 +42,18 @@ protected static Map getDynamicConf() { Map dynamicConf = new HashMap<>(); dynamicConf.put(RssMRConfig.RSS_REMOTE_STORAGE_PATH, HDFS_URI + "rss/test"); dynamicConf.put(RssMRConfig.RSS_STORAGE_TYPE, StorageType.MEMORY_LOCALFILE_HDFS.name()); - dynamicConf.put(RssMRConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name()); return dynamicConf; } - @Test - public void dynamicConfTest() throws Exception { - run(); + @ParameterizedTest + @MethodSource("clientTypeProvider") + public void dynamicConfTest(ClientType clientType) throws Exception { + run(clientType); } @Override - protected void updateRssConfiguration(Configuration jobConf) { - jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name()); + protected void updateRssConfiguration(Configuration jobConf, ClientType clientType) { + jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, clientType.name()); jobConf.setInt(LargeSorter.NUM_MAP_TASKS, 1); jobConf.setInt(LargeSorter.MBS_PER_MAP, 256); } diff --git a/integration-test/mr/src/test/java/org/apache/uniffle/test/HadoopConfTest.java b/integration-test/mr/src/test/java/org/apache/uniffle/test/HadoopConfTest.java index 892b9e19a2..be8227bdc8 100644 --- a/integration-test/mr/src/test/java/org/apache/uniffle/test/HadoopConfTest.java +++ b/integration-test/mr/src/test/java/org/apache/uniffle/test/HadoopConfTest.java @@ -25,7 +25,8 @@ import org.apache.hadoop.mapreduce.RssMRConfig; import org.apache.hadoop.util.Tool; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.apache.uniffle.common.ClientType; import org.apache.uniffle.storage.util.StorageType; @@ -41,14 +42,15 @@ protected static Map getDynamicConf() { return new HashMap<>(); } - @Test - public void hadoopConfTest() throws Exception { - run(); + @ParameterizedTest + @MethodSource("clientTypeProvider") + public void hadoopConfTest(ClientType clientType) throws Exception { + run(clientType); } @Override - protected void updateRssConfiguration(Configuration jobConf) { - jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name()); + protected void updateRssConfiguration(Configuration jobConf, ClientType clientType) { + jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, clientType.name()); jobConf.set(RssMRConfig.RSS_STORAGE_TYPE, StorageType.MEMORY_LOCALFILE_HDFS.name()); jobConf.set(RssMRConfig.RSS_REMOTE_STORAGE_PATH, HDFS_URI + "rss/test"); jobConf.setInt(LargeSorter.NUM_MAP_TASKS, 1); diff --git a/integration-test/mr/src/test/java/org/apache/uniffle/test/LargeSorterTest.java b/integration-test/mr/src/test/java/org/apache/uniffle/test/LargeSorterTest.java index be7ec83750..a1546b7629 100644 --- a/integration-test/mr/src/test/java/org/apache/uniffle/test/LargeSorterTest.java +++ b/integration-test/mr/src/test/java/org/apache/uniffle/test/LargeSorterTest.java @@ -23,7 +23,8 @@ import org.apache.hadoop.mapreduce.RssMRConfig; import org.apache.hadoop.util.Tool; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.apache.uniffle.common.ClientType; @@ -34,14 +35,15 @@ public static void setupServers() throws Exception { MRIntegrationTestBase.setupServers(MRIntegrationTestBase.getDynamicConf()); } - @Test - public void largeSorterTest() throws Exception { - run(); + @ParameterizedTest + @MethodSource("clientTypeProvider") + public void largeSorterTest(ClientType clientType) throws Exception { + run(clientType); } @Override - protected void updateRssConfiguration(Configuration jobConf) { - jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name()); + protected void updateRssConfiguration(Configuration jobConf, ClientType clientType) { + jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, clientType.name()); jobConf.setInt(LargeSorter.NUM_MAP_TASKS, 1); jobConf.setInt(LargeSorter.MBS_PER_MAP, 256); jobConf.set( diff --git a/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java b/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java index 0387eb9cb8..0c6e32652f 100644 --- a/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java +++ b/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Stream; import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; @@ -45,6 +46,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.provider.Arguments; import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.rpc.ServerType; @@ -75,6 +77,10 @@ public class MRIntegrationTestBase extends IntegrationTestBase { private static final String OUTPUT_ROOT_DIR = "/tmp/" + TestMRJobs.class.getSimpleName(); private static final Path TEST_RESOURCES_DIR = new Path(TEST_ROOT_DIR, "localizedResources"); + static Stream clientTypeProvider() { + return Stream.of(Arguments.of(ClientType.GRPC), Arguments.of(ClientType.GRPC_NETTY)); + } + @BeforeAll public static void setUpMRYarn() throws IOException { mrYarnCluster = new MiniMRYarnCluster("test"); @@ -99,29 +105,29 @@ public static void tearDown() throws IOException { } } - public void run() throws Exception { + public void run(ClientType clientType) throws Exception { JobConf appConf = new JobConf(mrYarnCluster.getConfig()); updateCommonConfiguration(appConf); runOriginApp(appConf); final String originPath = appConf.get("mapreduce.output.fileoutputformat.outputdir"); appConf = new JobConf(mrYarnCluster.getConfig()); updateCommonConfiguration(appConf); - runRssApp(appConf); + runRssApp(appConf, clientType); String rssPath = appConf.get("mapreduce.output.fileoutputformat.outputdir"); verifyResults(originPath, rssPath); appConf = new JobConf(mrYarnCluster.getConfig()); appConf.set("mapreduce.rss.reduce.remote.spill.enable", "true"); - runRssApp(appConf); + runRssApp(appConf, clientType); String rssRemoteSpillPath = appConf.get("mapreduce.output.fileoutputformat.outputdir"); verifyResults(originPath, rssRemoteSpillPath); } - public void runWithRemoteMerge() throws Exception { + public void runWithRemoteMerge(ClientType clientType) throws Exception { // 1 run application when remote merge is enable JobConf appConf = new JobConf(mrYarnCluster.getConfig()); updateCommonConfiguration(appConf); - runRssApp(appConf, true); + runRssApp(appConf, true, clientType); final String rssPath1 = appConf.get("mapreduce.output.fileoutputformat.outputdir"); // 2 run original application @@ -142,11 +148,12 @@ private void runOriginApp(Configuration jobConf) throws Exception { runMRApp(jobConf, getTestTool(), getTestArgs()); } - private void runRssApp(Configuration jobConf) throws Exception { - runRssApp(jobConf, false); + private void runRssApp(Configuration jobConf, ClientType clientType) throws Exception { + runRssApp(jobConf, false, clientType); } - private void runRssApp(Configuration jobConf, boolean remoteMerge) throws Exception { + private void runRssApp(Configuration jobConf, boolean remoteMerge, ClientType clientType) + throws Exception { URL url = MRIntegrationTestBase.class.getResource("/"); final String parentPath = new Path(url.getPath()).getParent().getParent().getParent().getParent().toString(); @@ -185,19 +192,19 @@ private void runRssApp(Configuration jobConf, boolean remoteMerge) throws Except } assertNotNull(localFile); String props = System.getProperty("java.class.path"); - String newProps = ""; + StringBuilder newProps = new StringBuilder(); String[] splittedProps = props.split(":"); for (String prop : splittedProps) { if (!prop.contains("classes") && !prop.contains("grpc") && !prop.contains("rss-") && !prop.contains("shuffle-storage")) { - newProps = newProps + ":" + prop; + newProps.append(":").append(prop); } else if (prop.contains("mr") && prop.contains("integration-test")) { - newProps = newProps + ":" + prop; + newProps.append(":").append(prop); } } - System.setProperty("java.class.path", newProps); + System.setProperty("java.class.path", newProps.toString()); Path newPath = new Path(HDFS_URI + "/rss.jar"); FileUtil.copy(file, fs, newPath, false, jobConf); DistributedCache.addFileToClassPath(new Path(newPath.toUri().getPath()), jobConf, fs); @@ -208,8 +215,9 @@ private void runRssApp(Configuration jobConf, boolean remoteMerge) throws Except + "," + MRJobConfig.DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH); jobConf.set(RssMRConfig.RSS_COORDINATOR_QUORUM, COORDINATOR_QUORUM); - updateRssConfiguration(jobConf); + updateRssConfiguration(jobConf, clientType); runMRApp(jobConf, getTestTool(), getTestArgs()); + fs.delete(newPath, true); } protected String[] getTestArgs() { @@ -225,11 +233,14 @@ protected static void setupServers(Map dynamicConf, ShuffleServe CoordinatorConf coordinatorConf = getCoordinatorConf(); addDynamicConf(coordinatorConf, dynamicConf); createCoordinatorServer(coordinatorConf); - ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC); + ShuffleServerConf grpcShuffleServerConf = getShuffleServerConf(ServerType.GRPC); + ShuffleServerConf nettyShuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY); if (serverConf != null) { - shuffleServerConf.addAll(serverConf); + grpcShuffleServerConf.addAll(serverConf); + nettyShuffleServerConf.addAll(serverConf); } - createShuffleServer(shuffleServerConf); + createShuffleServer(grpcShuffleServerConf); + createShuffleServer(nettyShuffleServerConf); startServers(); } @@ -240,8 +251,8 @@ protected static Map getDynamicConf() { return dynamicConf; } - protected void updateRssConfiguration(Configuration jobConf) { - jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name()); + protected void updateRssConfiguration(Configuration jobConf, ClientType clientType) { + jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, clientType.name()); } private void runMRApp(Configuration conf, Tool tool, String[] args) throws Exception { diff --git a/integration-test/mr/src/test/java/org/apache/uniffle/test/RMWordCountTest.java b/integration-test/mr/src/test/java/org/apache/uniffle/test/RMWordCountTest.java index 4ac8fc5896..9c2252920a 100644 --- a/integration-test/mr/src/test/java/org/apache/uniffle/test/RMWordCountTest.java +++ b/integration-test/mr/src/test/java/org/apache/uniffle/test/RMWordCountTest.java @@ -29,8 +29,10 @@ import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.util.Tool; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.apache.uniffle.common.ClientType; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.buffer.ShuffleBufferType; @@ -51,10 +53,11 @@ public static void setupServers() throws Exception { MRIntegrationTestBase.setupServers(MRIntegrationTestBase.getDynamicConf(), serverConf); } - @Test - public void wordCountTest() throws Exception { + @ParameterizedTest + @MethodSource("clientTypeProvider") + public void wordCountTest(ClientType clientType) throws Exception { generateInputFile(); - runWithRemoteMerge(); + runWithRemoteMerge(clientType); } @Override diff --git a/integration-test/mr/src/test/java/org/apache/uniffle/test/SecondarySortTest.java b/integration-test/mr/src/test/java/org/apache/uniffle/test/SecondarySortTest.java index 6e0bf5dcd9..4882ec5d62 100644 --- a/integration-test/mr/src/test/java/org/apache/uniffle/test/SecondarySortTest.java +++ b/integration-test/mr/src/test/java/org/apache/uniffle/test/SecondarySortTest.java @@ -32,7 +32,10 @@ import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.util.Tool; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.apache.uniffle.common.ClientType; public class SecondarySortTest extends MRIntegrationTestBase { @@ -43,10 +46,11 @@ public static void setupServers() throws Exception { MRIntegrationTestBase.setupServers(MRIntegrationTestBase.getDynamicConf()); } - @Test - public void secondarySortTest() throws Exception { + @ParameterizedTest + @MethodSource("clientTypeProvider") + public void secondarySortTest(ClientType clientType) throws Exception { generateInputFile(); - run(); + run(clientType); } private void generateInputFile() throws Exception { diff --git a/integration-test/mr/src/test/java/org/apache/uniffle/test/WordCountTest.java b/integration-test/mr/src/test/java/org/apache/uniffle/test/WordCountTest.java index 2ba76f5a92..47aef051ed 100644 --- a/integration-test/mr/src/test/java/org/apache/uniffle/test/WordCountTest.java +++ b/integration-test/mr/src/test/java/org/apache/uniffle/test/WordCountTest.java @@ -32,7 +32,10 @@ import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.util.Tool; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.apache.uniffle.common.ClientType; public class WordCountTest extends MRIntegrationTestBase { @@ -46,10 +49,11 @@ public static void setupServers() throws Exception { MRIntegrationTestBase.setupServers(MRIntegrationTestBase.getDynamicConf()); } - @Test - public void wordCountTest() throws Exception { + @ParameterizedTest + @MethodSource("clientTypeProvider") + public void wordCountTest(ClientType clientType) throws Exception { generateInputFile(); - run(); + run(clientType); } @Override