From 5a409ce78fff87a9417969684357080c672e697a Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Wed, 23 Oct 2024 17:51:47 +0000 Subject: [PATCH 1/4] Add SourcesToFilter support for network-blackhole-port fault --- .../handlers/fault/v1/handlers/handlers.go | 28 +++-- .../tmds/handlers/fault/v1/types/types.go | 23 +++- .../amazon-ecs-agent/ecs-agent/tmds/server.go | 2 +- .../handlers/fault/v1/handlers/handlers.go | 28 +++-- .../fault/v1/handlers/handlers_test.go | 102 +++++++++++++++++- .../tmds/handlers/fault/v1/types/types.go | 23 +++- .../handlers/fault/v1/types/types_test.go | 43 ++++++++ ecs-agent/tmds/server.go | 2 +- 8 files changed, 227 insertions(+), 24 deletions(-) create mode 100644 ecs-agent/tmds/handlers/fault/v1/types/types_test.go diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 7841937fd38..02a0c0e9989 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -120,12 +120,19 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht if err != nil { return } + // Validate the fault request err = validateRequest(w, request, requestType) if err != nil { return } + if aws.StringValue(request.TrafficType) == types.TrafficTypeEgress && + aws.Uint16Value(request.Port) == tmds.PortForTasks { + // Add TMDS IP to SouresToFilter so that access to TMDS is not blocked for the task + request.AddSourceToFilterIfNotAlready(tmds.IPForTasks) + } + // Obtain the task metadata via the endpoint container ID taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { @@ -154,7 +161,8 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht insertTable = "OUTPUT" } - _, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + _, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), + port, aws.StringValueSlice(request.SourcesToFilter), chainName, networkMode, networkNSPath, insertTable, taskArn) if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { statusCode = http.StatusInternalServerError @@ -187,7 +195,10 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht // 2. Creates a new chain via `iptables -N ` (the chain name is in the form of "--") // 3. Appends a new rule to the newly created chain via `iptables -A -p --dport -j DROP` // 4. Inserts the newly created chain into the built-in INPUT/OUTPUT table -func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, port, chain, networkMode, netNs, insertTable, taskArn string) (string, error) { +func (h *FaultHandler) startNetworkBlackholePort( + ctx context.Context, protocol, port string, sourcesToFilter []string, + chain, networkMode, netNs, insertTable, taskArn string, +) (string, error) { running, cmdOutput, err := h.checkNetworkBlackHolePort(ctx, protocol, port, chain, networkMode, netNs, taskArn) if err != nil { return cmdOutput, err @@ -246,12 +257,13 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, return "", nil } - // Add a rule to accept all traffic to TMDS - protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, - requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks, - acceptTarget) - if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil { - return out, err + for _, sourceToFilter := range sourcesToFilter { + filterRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, + requestTimeoutSeconds, chain, protocol, sourceToFilter, port, + acceptTarget) + if out, err := execRuleChangeCommand(filterRuleCmdString); err != nil { + return out, err + } } // Add a rule to drop all traffic to the port that the fault targets diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types/types.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types/types.go index a69d417ea5b..d1bbef8c232 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types/types.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types/types.go @@ -32,6 +32,8 @@ const ( missingRequiredFieldError = "required parameter %s is missing" MissingRequestBodyError = "required request body is missing" invalidValueError = "invalid value %s for parameter %s" + TrafficTypeIngress = "ingress" + TrafficTypeEgress = "egress" ) type NetworkFaultRequest interface { @@ -43,6 +45,9 @@ type NetworkBlackholePortRequest struct { Port *uint16 `json:"Port"` Protocol *string `json:"Protocol"` TrafficType *string `json:"TrafficType"` + // SourcesToFilter is a list including IPv4 addresses or IPv4 CIDR blocks that will be excluded from the + // network latency fault. + SourcesToFilter []*string `json:"SourcesToFilter,omitempty"` } type NetworkFaultInjectionResponse struct { @@ -65,13 +70,29 @@ func (request NetworkBlackholePortRequest) ValidateRequest() error { return fmt.Errorf(invalidValueError, *request.Protocol, "Protocol") } - if *request.TrafficType != "ingress" && *request.TrafficType != "egress" { + if *request.TrafficType != TrafficTypeIngress && *request.TrafficType != TrafficTypeEgress { return fmt.Errorf(invalidValueError, *request.TrafficType, "TrafficType") } + if err := validateNetworkFaultRequestSources(request.SourcesToFilter, "SourcesToFilter"); err != nil { + return err + } return nil } +// Adds a source to SourcesToFilter +func (request *NetworkBlackholePortRequest) AddSourceToFilterIfNotAlready(source string) { + if request.SourcesToFilter == nil { + request.SourcesToFilter = []*string{} + } + for _, src := range request.SourcesToFilter { + if src != nil && *src == source { + return + } + } + request.SourcesToFilter = append(request.SourcesToFilter, aws.String(source)) +} + func (request NetworkBlackholePortRequest) ToString() string { data, err := json.Marshal(request) if err != nil { diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go index 7248bd7ff25..373c1659a6e 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go @@ -32,7 +32,7 @@ const ( IPv4 = "127.0.0.1" Port = 51679 IPForTasks = "169.254.170.2" - PortForTasks = "80" + PortForTasks = 80 ) // IPv4 address for TMDS diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 7841937fd38..02a0c0e9989 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -120,12 +120,19 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht if err != nil { return } + // Validate the fault request err = validateRequest(w, request, requestType) if err != nil { return } + if aws.StringValue(request.TrafficType) == types.TrafficTypeEgress && + aws.Uint16Value(request.Port) == tmds.PortForTasks { + // Add TMDS IP to SouresToFilter so that access to TMDS is not blocked for the task + request.AddSourceToFilterIfNotAlready(tmds.IPForTasks) + } + // Obtain the task metadata via the endpoint container ID taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { @@ -154,7 +161,8 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht insertTable = "OUTPUT" } - _, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + _, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), + port, aws.StringValueSlice(request.SourcesToFilter), chainName, networkMode, networkNSPath, insertTable, taskArn) if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { statusCode = http.StatusInternalServerError @@ -187,7 +195,10 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht // 2. Creates a new chain via `iptables -N ` (the chain name is in the form of "--") // 3. Appends a new rule to the newly created chain via `iptables -A -p --dport -j DROP` // 4. Inserts the newly created chain into the built-in INPUT/OUTPUT table -func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, port, chain, networkMode, netNs, insertTable, taskArn string) (string, error) { +func (h *FaultHandler) startNetworkBlackholePort( + ctx context.Context, protocol, port string, sourcesToFilter []string, + chain, networkMode, netNs, insertTable, taskArn string, +) (string, error) { running, cmdOutput, err := h.checkNetworkBlackHolePort(ctx, protocol, port, chain, networkMode, netNs, taskArn) if err != nil { return cmdOutput, err @@ -246,12 +257,13 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, return "", nil } - // Add a rule to accept all traffic to TMDS - protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, - requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks, - acceptTarget) - if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil { - return out, err + for _, sourceToFilter := range sourcesToFilter { + filterRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, + requestTimeoutSeconds, chain, protocol, sourceToFilter, port, + acceptTarget) + if out, err := execRuleChangeCommand(filterRuleCmdString); err != nil { + return out, err + } } // Add a rule to drop all traffic to the port that the fault targets diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index 6ba1242d33d..e44135bbc91 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -35,6 +35,7 @@ import ( mock_state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig" mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" + "github.com/aws/aws-sdk-go/aws" "github.com/golang/mock/gomock" "github.com/gorilla/mux" @@ -521,8 +522,6 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), - exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), - cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), ) }, }, @@ -556,8 +555,6 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), - exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), - cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), ) }, }, @@ -663,6 +660,103 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase ) }, }, + { + name: "SourcesToFilter validation failure", + expectedStatusCode: 400, + requestBody: map[string]interface{}{ + "Port": port, + "Protocol": protocol, + "TrafficType": trafficType, + "SourcesToFilter": aws.StringSlice([]string{"1.2.3.4", "bad"}), + }, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value bad for parameter SourcesToFilter"), + }, + { + name: "TMDS IP is added to SourcesToFilter if needed", + requestBody: map[string]interface{}{ + "Port": 80, + "Protocol": protocol, + "TrafficType": "egress", + }, + expectedStatusCode: 200, + expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), + "nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "egress-tcp-80", + "-p", "tcp", "-d", "169.254.170.2", "--dport", "80", "-j", "ACCEPT", + ).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), + "nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "egress-tcp-80", + "-p", "tcp", "-d", "0.0.0.0/0", "--dport", "80", "-j", "DROP", + ).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) + }, + }, + { + name: "Sources to filter are filtered", + requestBody: map[string]interface{}{ + "Port": 443, + "Protocol": "udp", + "TrafficType": "ingress", + "SourcesToFilter": []string{"1.2.3.4/20", "8.8.8.8"}, + }, + expectedStatusCode: 200, + expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), + "nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443", + "-p", "udp", "-d", "1.2.3.4/20", "--dport", "443", "-j", "ACCEPT", + ).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), + "nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443", + "-p", "udp", "-d", "8.8.8.8", "--dport", "443", "-j", "ACCEPT", + ).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), + "nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443", + "-p", "udp", "-d", "0.0.0.0/0", "--dport", "443", "-j", "DROP", + ).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) + }, + }, } return append(tcs, commonTcs...) diff --git a/ecs-agent/tmds/handlers/fault/v1/types/types.go b/ecs-agent/tmds/handlers/fault/v1/types/types.go index a69d417ea5b..d1bbef8c232 100644 --- a/ecs-agent/tmds/handlers/fault/v1/types/types.go +++ b/ecs-agent/tmds/handlers/fault/v1/types/types.go @@ -32,6 +32,8 @@ const ( missingRequiredFieldError = "required parameter %s is missing" MissingRequestBodyError = "required request body is missing" invalidValueError = "invalid value %s for parameter %s" + TrafficTypeIngress = "ingress" + TrafficTypeEgress = "egress" ) type NetworkFaultRequest interface { @@ -43,6 +45,9 @@ type NetworkBlackholePortRequest struct { Port *uint16 `json:"Port"` Protocol *string `json:"Protocol"` TrafficType *string `json:"TrafficType"` + // SourcesToFilter is a list including IPv4 addresses or IPv4 CIDR blocks that will be excluded from the + // network latency fault. + SourcesToFilter []*string `json:"SourcesToFilter,omitempty"` } type NetworkFaultInjectionResponse struct { @@ -65,13 +70,29 @@ func (request NetworkBlackholePortRequest) ValidateRequest() error { return fmt.Errorf(invalidValueError, *request.Protocol, "Protocol") } - if *request.TrafficType != "ingress" && *request.TrafficType != "egress" { + if *request.TrafficType != TrafficTypeIngress && *request.TrafficType != TrafficTypeEgress { return fmt.Errorf(invalidValueError, *request.TrafficType, "TrafficType") } + if err := validateNetworkFaultRequestSources(request.SourcesToFilter, "SourcesToFilter"); err != nil { + return err + } return nil } +// Adds a source to SourcesToFilter +func (request *NetworkBlackholePortRequest) AddSourceToFilterIfNotAlready(source string) { + if request.SourcesToFilter == nil { + request.SourcesToFilter = []*string{} + } + for _, src := range request.SourcesToFilter { + if src != nil && *src == source { + return + } + } + request.SourcesToFilter = append(request.SourcesToFilter, aws.String(source)) +} + func (request NetworkBlackholePortRequest) ToString() string { data, err := json.Marshal(request) if err != nil { diff --git a/ecs-agent/tmds/handlers/fault/v1/types/types_test.go b/ecs-agent/tmds/handlers/fault/v1/types/types_test.go new file mode 100644 index 00000000000..320e6a833f7 --- /dev/null +++ b/ecs-agent/tmds/handlers/fault/v1/types/types_test.go @@ -0,0 +1,43 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package types + +import ( + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/require" +) + +func TestNetworkBlackholePortAddSourceToFilterIfNotAlready(t *testing.T) { + t.Run("nil SourcesToFilter is initialized", func(t *testing.T) { + var req *NetworkBlackholePortRequest = &NetworkBlackholePortRequest{} + req.AddSourceToFilterIfNotAlready("1.2.3.4") + require.Equal(t, aws.StringValueSlice(req.SourcesToFilter), []string{"1.2.3.4"}) + }) + t.Run("Source can be added", func(t *testing.T) { + var req *NetworkBlackholePortRequest = &NetworkBlackholePortRequest{ + SourcesToFilter: aws.StringSlice([]string{"8.8.8.8"}), + } + req.AddSourceToFilterIfNotAlready("1.2.3.4") + require.Equal(t, aws.StringValueSlice(req.SourcesToFilter), []string{"8.8.8.8", "1.2.3.4"}) + }) + t.Run("Duplicate source is not added", func(t *testing.T) { + var req *NetworkBlackholePortRequest = &NetworkBlackholePortRequest{ + SourcesToFilter: aws.StringSlice([]string{"8.8.8.8", "1.2.3.4"}), + } + req.AddSourceToFilterIfNotAlready("1.2.3.4") + require.Equal(t, aws.StringValueSlice(req.SourcesToFilter), []string{"8.8.8.8", "1.2.3.4"}) + }) +} diff --git a/ecs-agent/tmds/server.go b/ecs-agent/tmds/server.go index 7248bd7ff25..373c1659a6e 100644 --- a/ecs-agent/tmds/server.go +++ b/ecs-agent/tmds/server.go @@ -32,7 +32,7 @@ const ( IPv4 = "127.0.0.1" Port = 51679 IPForTasks = "169.254.170.2" - PortForTasks = "80" + PortForTasks = 80 ) // IPv4 address for TMDS From e2bbc56ee278c579cba06c2e924126ec7dec305c Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Wed, 23 Oct 2024 18:27:14 +0000 Subject: [PATCH 2/4] Minor change --- ecs-agent/tmds/handlers/fault/v1/types/types_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ecs-agent/tmds/handlers/fault/v1/types/types_test.go b/ecs-agent/tmds/handlers/fault/v1/types/types_test.go index 320e6a833f7..c245a81a760 100644 --- a/ecs-agent/tmds/handlers/fault/v1/types/types_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/types/types_test.go @@ -22,19 +22,19 @@ import ( func TestNetworkBlackholePortAddSourceToFilterIfNotAlready(t *testing.T) { t.Run("nil SourcesToFilter is initialized", func(t *testing.T) { - var req *NetworkBlackholePortRequest = &NetworkBlackholePortRequest{} + var req NetworkBlackholePortRequest = NetworkBlackholePortRequest{} req.AddSourceToFilterIfNotAlready("1.2.3.4") require.Equal(t, aws.StringValueSlice(req.SourcesToFilter), []string{"1.2.3.4"}) }) t.Run("Source can be added", func(t *testing.T) { - var req *NetworkBlackholePortRequest = &NetworkBlackholePortRequest{ + var req NetworkBlackholePortRequest = NetworkBlackholePortRequest{ SourcesToFilter: aws.StringSlice([]string{"8.8.8.8"}), } req.AddSourceToFilterIfNotAlready("1.2.3.4") require.Equal(t, aws.StringValueSlice(req.SourcesToFilter), []string{"8.8.8.8", "1.2.3.4"}) }) t.Run("Duplicate source is not added", func(t *testing.T) { - var req *NetworkBlackholePortRequest = &NetworkBlackholePortRequest{ + var req NetworkBlackholePortRequest = NetworkBlackholePortRequest{ SourcesToFilter: aws.StringSlice([]string{"8.8.8.8", "1.2.3.4"}), } req.AddSourceToFilterIfNotAlready("1.2.3.4") From cdd8f8e26101d61721d2ad4b10fadf58f85e6124 Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Wed, 23 Oct 2024 18:33:23 +0000 Subject: [PATCH 3/4] Add a test --- .../fault/v1/handlers/handlers_test.go | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index e44135bbc91..99231fb536a 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -757,6 +757,40 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase ) }, }, + { + name: "Error when filtering a source", + expectedStatusCode: 500, + requestBody: map[string]interface{}{ + "Port": 443, + "Protocol": "udp", + "TrafficType": "ingress", + "SourcesToFilter": []string{"1.2.3.4/20"}, + }, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), + "nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443", + "-p", "udp", "-d", "1.2.3.4/20", "--dport", "443", "-j", "ACCEPT", + ).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")), + ) + }, + }, } return append(tcs, commonTcs...) From 76c4f2ee954084f36fcf9a2cf036aab4da6a9434 Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Wed, 23 Oct 2024 18:40:17 +0000 Subject: [PATCH 4/4] Test fix --- agent/handlers/task_server_setup_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 221382023a6..c328cd077a9 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -3806,8 +3806,6 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) { cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), - exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), - cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), ) } tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody)