Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SourcesToFilter support for network-blackhole-port fault #4408

Merged
merged 4 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3806,8 +3806,6 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) {
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
}
tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody)
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 20 additions & 8 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,19 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht
if err != nil {
return
}

// Validate the fault request
err = validateRequest(w, request, requestType)
if err != nil {
return
}

if aws.StringValue(request.TrafficType) == types.TrafficTypeEgress &&
aws.Uint16Value(request.Port) == tmds.PortForTasks {
// Add TMDS IP to SouresToFilter so that access to TMDS is not blocked for the task
request.AddSourceToFilterIfNotAlready(tmds.IPForTasks)
}

// Obtain the task metadata via the endpoint container ID
taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r)
if err != nil {
Expand Down Expand Up @@ -154,7 +161,8 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht
insertTable = "OUTPUT"
}

_, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName,
_, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol),
port, aws.StringValueSlice(request.SourcesToFilter), chainName,
networkMode, networkNSPath, insertTable, taskArn)
if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) {
statusCode = http.StatusInternalServerError
Expand Down Expand Up @@ -187,7 +195,10 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht
// 2. Creates a new chain via `iptables -N <chain>` (the chain name is in the form of "<trafficType>-<protocol>-<port>")
// 3. Appends a new rule to the newly created chain via `iptables -A <chain> -p <protocol> --dport <port> -j DROP`
// 4. Inserts the newly created chain into the built-in INPUT/OUTPUT table
func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, port, chain, networkMode, netNs, insertTable, taskArn string) (string, error) {
func (h *FaultHandler) startNetworkBlackholePort(
ctx context.Context, protocol, port string, sourcesToFilter []string,
chain, networkMode, netNs, insertTable, taskArn string,
) (string, error) {
running, cmdOutput, err := h.checkNetworkBlackHolePort(ctx, protocol, port, chain, networkMode, netNs, taskArn)
if err != nil {
return cmdOutput, err
Expand Down Expand Up @@ -246,12 +257,13 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol,
return "", nil
}

// Add a rule to accept all traffic to TMDS
protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks,
acceptTarget)
if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil {
return out, err
for _, sourceToFilter := range sourcesToFilter {
filterRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
requestTimeoutSeconds, chain, protocol, sourceToFilter, port,
acceptTarget)
if out, err := execRuleChangeCommand(filterRuleCmdString); err != nil {
return out, err
}
}

// Add a rule to drop all traffic to the port that the fault targets
Expand Down
136 changes: 132 additions & 4 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
mock_state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/mocks"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig"
mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks"
"github.com/aws/aws-sdk-go/aws"

"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
Expand Down Expand Up @@ -521,8 +522,6 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
Expand Down Expand Up @@ -556,8 +555,6 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
Expand Down Expand Up @@ -663,6 +660,137 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
)
},
},
{
name: "SourcesToFilter validation failure",
expectedStatusCode: 400,
requestBody: map[string]interface{}{
"Port": port,
"Protocol": protocol,
"TrafficType": trafficType,
"SourcesToFilter": aws.StringSlice([]string{"1.2.3.4", "bad"}),
},
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value bad for parameter SourcesToFilter"),
},
{
name: "TMDS IP is added to SourcesToFilter if needed",
requestBody: map[string]interface{}{
"Port": 80,
"Protocol": protocol,
"TrafficType": "egress",
},
expectedStatusCode: 200,
expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"),
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
Return(happyTaskResponse, nil).
Times(1)
},
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
gomock.InOrder(
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "egress-tcp-80",
"-p", "tcp", "-d", "169.254.170.2", "--dport", "80", "-j", "ACCEPT",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "egress-tcp-80",
"-p", "tcp", "-d", "0.0.0.0/0", "--dport", "80", "-j", "DROP",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
{
name: "Sources to filter are filtered",
requestBody: map[string]interface{}{
"Port": 443,
"Protocol": "udp",
"TrafficType": "ingress",
"SourcesToFilter": []string{"1.2.3.4/20", "8.8.8.8"},
},
expectedStatusCode: 200,
expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"),
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
Return(happyTaskResponse, nil).
Times(1)
},
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
gomock.InOrder(
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
"-p", "udp", "-d", "1.2.3.4/20", "--dport", "443", "-j", "ACCEPT",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
"-p", "udp", "-d", "8.8.8.8", "--dport", "443", "-j", "ACCEPT",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
"-p", "udp", "-d", "0.0.0.0/0", "--dport", "443", "-j", "DROP",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
{
name: "Error when filtering a source",
expectedStatusCode: 500,
requestBody: map[string]interface{}{
"Port": 443,
"Protocol": "udp",
"TrafficType": "ingress",
"SourcesToFilter": []string{"1.2.3.4/20"},
},
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
Return(happyTaskResponse, nil).
Times(1)
},
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
gomock.InOrder(
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
"-p", "udp", "-d", "1.2.3.4/20", "--dport", "443", "-j", "ACCEPT",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")),
)
},
},
}

return append(tcs, commonTcs...)
Expand Down
Loading
Loading