diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 438d0cdbc74..389a7412af3 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -15,12 +15,15 @@ package handlers import ( "bytes" + "context" "encoding/json" "errors" "fmt" "io" "net/http" + "strings" "sync" + "time" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -30,6 +33,7 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" "github.com/gorilla/mux" ) @@ -42,6 +46,11 @@ const ( faultInjectionEnabledError = "fault injection is not enabled for task: %s" ) +var ( + tcBaseCommand = []string{"tc"} + tcCheckInjectionCommandString = "-j q" +) + type FaultHandler struct { // mutexMap is used to avoid multiple clients to manipulate same resource at same // time. The 'key' is the the network namespace path and 'value' is the RWMutex. @@ -49,6 +58,7 @@ type FaultHandler struct { mutexMap sync.Map AgentState state.AgentState MetricsFactory metrics.EntryFactory + OsExecWrapper execwrapper.Exec } func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { @@ -56,6 +66,7 @@ func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { AgentState: agentState, MetricsFactory: mf, mutexMap: sync.Map{}, + OsExecWrapper: execwrapper.NewExec(), } } @@ -472,7 +483,6 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return @@ -484,17 +494,32 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. rwMu.RLock() defer rwMu.RUnlock() - // TODO: Check status of current fault injection - // TODO: Return the correct status state - responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") - logger.Info("Successfully checked status for fault", logger.Fields{ + // Check status of current fault injection + faultStatus, err := h.checkPacketLossFault() + var responseBody types.NetworkFaultInjectionResponse + var stringToBeLogged string + var httpStatusCode int + if err != nil { + responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + stringToBeLogged = "Error: failed to check fault status" + httpStatusCode = http.StatusInternalServerError + } else { + stringToBeLogged = "Successfully checked status for fault" + httpStatusCode = http.StatusOK + if faultStatus { + responseBody = types.NewNetworkFaultInjectionSuccessResponse("running") + } else { + responseBody = types.NewNetworkFaultInjectionSuccessResponse("not-running") + } + } + logger.Info(stringToBeLogged, logger.Fields{ field.RequestType: requestType, field.Request: request.ToString(), field.Response: responseBody.ToString(), }) utils.WriteJSONResponse( w, - http.StatusOK, + httpStatusCode, responseBody, requestType, ) @@ -692,3 +717,35 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return nil } + +// checkPacketLossFault checks if there's existing network-packet-loss fault running. +func (h *FaultHandler) checkPacketLossFault() (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // We will run the following Linux command to assess if there existing fault. + // "tc -j q" + // The command above gives the output of "tc q" in json format. + // We will then unmarshall the json string and check if '"kind":"netem"' exists. + parameterList := strings.Split(tcCheckInjectionCommandString, " ") + cmdToExec := append(tcBaseCommand, parameterList...) + cmdExec := h.OsExecWrapper.CommandContext(ctx, cmdToExec[0], cmdToExec[1:]...) + outputInBytes, err := cmdExec.Output() + if err != nil { + return false, err + } + + var outputUnmarshalled []map[string]interface{} + err = json.Unmarshal(outputInBytes, &outputUnmarshalled) + if err != nil { + return false, errors.New("failed to unmarshal tc command output: " + err.Error()) + } + + for _, line := range outputUnmarshalled { + if line["kind"] == "netem" { + return true, nil + } + } + + return false, nil +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/exec.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/exec.go new file mode 100644 index 00000000000..674ceef8d7c --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/exec.go @@ -0,0 +1,114 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package execwrapper + +import ( + "context" + "io" + "os" + "os/exec" +) + +// Exec acts as a wrapper to functions exposed by the exec package. +// Having this interface enables us to create mock objects we can use +// for testing. +type Exec interface { + CommandContext(ctx context.Context, name string, arg ...string) Cmd +} + +// execWrapper is a placeholder struct which implements the Exec interface. +type execWrapper struct { +} + +func NewExec() Exec { + return &execWrapper{} +} + +// CommandContext essentially acts as a wrapper function for exec.CommandContext function. +func (e *execWrapper) CommandContext(ctx context.Context, name string, arg ...string) Cmd { + return NewCMDContext(ctx, name, arg...) +} + +// Cmd acts as a wrapper to functions exposed by the exec.Cmd object. +// Having this interface enables us to create mock objects we can use +// for testing. +type Cmd interface { + Run() error + Start() error + Wait() error + KillProcess() error + AppendExtraFiles(...*os.File) + Args() []string + SetIOStreams(io.Reader, io.Writer, io.Writer) + Output() ([]byte, error) + CombinedOutput() ([]byte, error) +} + +type cmdWrapper struct { + *exec.Cmd +} + +func NewCMDContext(ctx context.Context, name string, arg ...string) Cmd { + cmd := exec.CommandContext(ctx, name, arg...) + return &cmdWrapper{Cmd: cmd} +} + +func NewCMD(name string, arg ...string) Cmd { + cmd := exec.Command(name, arg...) + return &cmdWrapper{Cmd: cmd} +} + +func (c *cmdWrapper) Run() error { + return c.Cmd.Run() +} + +func (c *cmdWrapper) Start() error { + return c.Cmd.Start() +} + +func (c *cmdWrapper) Wait() error { + return c.Cmd.Wait() +} + +func (c *cmdWrapper) KillProcess() error { + return c.Cmd.Process.Kill() +} + +func (c *cmdWrapper) AppendExtraFiles(ef ...*os.File) { + c.ExtraFiles = append(c.ExtraFiles, ef...) +} + +func (c *cmdWrapper) Args() []string { + return c.Cmd.Args +} + +func (c *cmdWrapper) SetIOStreams(stdin io.Reader, stdout io.Writer, stderr io.Writer) { + if stdin != nil { + c.Stdin = stdin + } + if stdout != nil { + c.Stdout = stdout + } + if stderr != nil { + c.Stderr = stderr + } +} + +func (c *cmdWrapper) Output() ([]byte, error) { + return c.Cmd.Output() +} + +func (c *cmdWrapper) CombinedOutput() ([]byte, error) { + return c.Cmd.CombinedOutput() +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/generate_mocks.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/generate_mocks.go new file mode 100644 index 00000000000..0c0f9e9aba1 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/generate_mocks.go @@ -0,0 +1,16 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +//go:generate mockgen -build_flags=--mod=mod -destination=mocks/execwrapper_mocks.go -copyright_file=../../../scripts/copyright_file github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper Cmd,Exec + +package execwrapper diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index 54ddb3d910b..9ea6b1cbd6b 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -74,6 +74,7 @@ github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux github.com/aws/amazon-ecs-agent/ecs-agent/utils github.com/aws/amazon-ecs-agent/ecs-agent/utils/arn github.com/aws/amazon-ecs-agent/ecs-agent/utils/cipher +github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper github.com/aws/amazon-ecs-agent/ecs-agent/utils/httpproxy github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry/mock diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 438d0cdbc74..389a7412af3 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -15,12 +15,15 @@ package handlers import ( "bytes" + "context" "encoding/json" "errors" "fmt" "io" "net/http" + "strings" "sync" + "time" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -30,6 +33,7 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" "github.com/gorilla/mux" ) @@ -42,6 +46,11 @@ const ( faultInjectionEnabledError = "fault injection is not enabled for task: %s" ) +var ( + tcBaseCommand = []string{"tc"} + tcCheckInjectionCommandString = "-j q" +) + type FaultHandler struct { // mutexMap is used to avoid multiple clients to manipulate same resource at same // time. The 'key' is the the network namespace path and 'value' is the RWMutex. @@ -49,6 +58,7 @@ type FaultHandler struct { mutexMap sync.Map AgentState state.AgentState MetricsFactory metrics.EntryFactory + OsExecWrapper execwrapper.Exec } func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { @@ -56,6 +66,7 @@ func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { AgentState: agentState, MetricsFactory: mf, mutexMap: sync.Map{}, + OsExecWrapper: execwrapper.NewExec(), } } @@ -472,7 +483,6 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return @@ -484,17 +494,32 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. rwMu.RLock() defer rwMu.RUnlock() - // TODO: Check status of current fault injection - // TODO: Return the correct status state - responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") - logger.Info("Successfully checked status for fault", logger.Fields{ + // Check status of current fault injection + faultStatus, err := h.checkPacketLossFault() + var responseBody types.NetworkFaultInjectionResponse + var stringToBeLogged string + var httpStatusCode int + if err != nil { + responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + stringToBeLogged = "Error: failed to check fault status" + httpStatusCode = http.StatusInternalServerError + } else { + stringToBeLogged = "Successfully checked status for fault" + httpStatusCode = http.StatusOK + if faultStatus { + responseBody = types.NewNetworkFaultInjectionSuccessResponse("running") + } else { + responseBody = types.NewNetworkFaultInjectionSuccessResponse("not-running") + } + } + logger.Info(stringToBeLogged, logger.Fields{ field.RequestType: requestType, field.Request: request.ToString(), field.Response: responseBody.ToString(), }) utils.WriteJSONResponse( w, - http.StatusOK, + httpStatusCode, responseBody, requestType, ) @@ -692,3 +717,35 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return nil } + +// checkPacketLossFault checks if there's existing network-packet-loss fault running. +func (h *FaultHandler) checkPacketLossFault() (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // We will run the following Linux command to assess if there existing fault. + // "tc -j q" + // The command above gives the output of "tc q" in json format. + // We will then unmarshall the json string and check if '"kind":"netem"' exists. + parameterList := strings.Split(tcCheckInjectionCommandString, " ") + cmdToExec := append(tcBaseCommand, parameterList...) + cmdExec := h.OsExecWrapper.CommandContext(ctx, cmdToExec[0], cmdToExec[1:]...) + outputInBytes, err := cmdExec.Output() + if err != nil { + return false, err + } + + var outputUnmarshalled []map[string]interface{} + err = json.Unmarshal(outputInBytes, &outputUnmarshalled) + if err != nil { + return false, errors.New("failed to unmarshal tc command output: " + err.Error()) + } + + for _, line := range outputUnmarshalled { + if line["kind"] == "netem" { + return true, nil + } + } + + return false, nil +} diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index c810a48b94c..49a2bf9905f 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -24,6 +24,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks" @@ -31,6 +32,7 @@ import ( v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" mock_state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/mocks" + mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" "github.com/golang/mock/gomock" "github.com/gorilla/mux" @@ -39,17 +41,19 @@ import ( ) const ( - endpointId = "endpointId" - port = 1234 - protocol = "tcp" - trafficType = "ingress" - delayMilliseconds = 123456789 - jitterMilliseconds = 4567 - lossPercent = 6 - taskARN = "taskArn" - awsvpcNetworkMode = "awsvpc" - deviceName = "eth0" - invalidNetworkMode = "invalid" + endpointId = "endpointId" + port = 1234 + protocol = "tcp" + trafficType = "ingress" + delayMilliseconds = 123456789 + jitterMilliseconds = 4567 + lossPercent = 6 + taskARN = "taskArn" + awsvpcNetworkMode = "awsvpc" + deviceName = "eth0" + invalidNetworkMode = "invalid" + tcFaultExistsCommandOutput = `[{"kind":"netem","handle":"10:","dev":"eth0","parent":"1:1","options":{"limit":1000,"loss-random":{"loss":0.1,"correlation":0},"ecn":false,"gap":0}}]` + tcFaultDoesNotExistCommandOutput = `[{"kind":"dummyname"}]` ) var ( @@ -99,6 +103,7 @@ type networkFaultInjectionTestCase struct { requestBody interface{} expectedResponseBody types.NetworkFaultInjectionResponse setAgentStateExpectations func(agentState *mock_state.MockAgentState) + setExecExpectations func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) } // Tests the path for Fault Network Faults API @@ -885,20 +890,40 @@ func TestCheckNetworkLatency(t *testing.T) { } } -func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) []networkFaultInjectionTestCase { +func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string, expectedUnhappyResponseBody string) []networkFaultInjectionTestCase { happyNetworkPacketLossReqBody := map[string]interface{}{ "LossPercent": lossPercent, "Sources": ipSources, } tcs := []networkFaultInjectionTestCase{ { - name: fmt.Sprintf("%s success", name), + name: fmt.Sprintf("%s success-running", name), expectedStatusCode: 200, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().Output().Times(1).Return([]byte(tcFaultExistsCommandOutput), nil) + }, + }, + { + name: fmt.Sprintf("%s success-not-running", name), + expectedStatusCode: 200, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedUnhappyResponseBody), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().Output().Times(1).Return([]byte(tcFaultDoesNotExistCommandOutput), nil) + + }, }, { name: fmt.Sprintf("%s unknown request body", name), @@ -912,6 +937,31 @@ func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().Output().Times(1).Return([]byte(tcFaultExistsCommandOutput), nil) + + }, + }, + { + name: fmt.Sprintf("%s failed to unmarshal json", name), + expectedStatusCode: 500, + requestBody: map[string]interface{}{ + "LossPercent": lossPercent, + "Sources": ipSources, + "Unknown": "", + }, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("failed to unmarshal tc command output: unexpected end of JSON input"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().Output().Times(1).Return([]byte(""), nil) + + }, }, { name: fmt.Sprintf("%s malformed request body 1", name), @@ -1105,9 +1155,15 @@ func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) } func TestStartNetworkPacketLoss(t *testing.T) { - tcs := generateNetworkPacketLossTestCases("start network packet loss", "running") + tcs := generateNetworkPacketLossTestCases("start network packet loss", "running", "") for _, tc := range tcs { + // Currently the logic that the following test case covers is only implemented for CheckNetworkPacketLoss(). + // It will fail for Start and Stop. Thus, skipping them until the logic is fully implemented. + if strings.Contains(tc.name, "failed to unmarshal json") || + strings.Contains(tc.name, "success-not-running") { + continue + } t.Run(tc.name, func(t *testing.T) { // Mocks ctrl := gomock.NewController(t) @@ -1154,8 +1210,14 @@ func TestStartNetworkPacketLoss(t *testing.T) { } func TestStopNetworkPacketLoss(t *testing.T) { - tcs := generateNetworkPacketLossTestCases("stop network packet loss", "stopped") + tcs := generateNetworkPacketLossTestCases("stop network packet loss", "stopped", "") for _, tc := range tcs { + // Currently the logic that the following test case covers is only implemented for CheckNetworkPacketLoss(). + // It will fail for Start and Stop. Thus, skipping them until the logic is fully implemented. + if strings.Contains(tc.name, "failed to unmarshal json") || + strings.Contains(tc.name, "success-not-running") { + continue + } t.Run(tc.name, func(t *testing.T) { // Mocks ctrl := gomock.NewController(t) @@ -1202,7 +1264,7 @@ func TestStopNetworkPacketLoss(t *testing.T) { } func TestCheckNetworkPacketLoss(t *testing.T) { - tcs := generateNetworkPacketLossTestCases("check network packet loss", "running") + tcs := generateNetworkPacketLossTestCases("check network packet loss", "running", "not-running") for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { @@ -1214,10 +1276,15 @@ func TestCheckNetworkPacketLoss(t *testing.T) { metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) router := mux.NewRouter() + mockExec := mock_execwrapper.NewMockExec(ctrl) handler := New(agentState, metricsFactory) + handler.OsExecWrapper = mockExec if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) } + if tc.setExecExpectations != nil { + tc.setExecExpectations(mockExec, ctrl) + } router.HandleFunc( NetworkFaultPath(types.PacketLossFaultType), diff --git a/ecs-agent/utils/execwrapper/exec.go b/ecs-agent/utils/execwrapper/exec.go index 776a7dfb57f..674ceef8d7c 100644 --- a/ecs-agent/utils/execwrapper/exec.go +++ b/ecs-agent/utils/execwrapper/exec.go @@ -51,6 +51,8 @@ type Cmd interface { AppendExtraFiles(...*os.File) Args() []string SetIOStreams(io.Reader, io.Writer, io.Writer) + Output() ([]byte, error) + CombinedOutput() ([]byte, error) } type cmdWrapper struct { @@ -102,3 +104,11 @@ func (c *cmdWrapper) SetIOStreams(stdin io.Reader, stdout io.Writer, stderr io.W c.Stderr = stderr } } + +func (c *cmdWrapper) Output() ([]byte, error) { + return c.Cmd.Output() +} + +func (c *cmdWrapper) CombinedOutput() ([]byte, error) { + return c.Cmd.CombinedOutput() +} diff --git a/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go b/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go index 962077bd141..f61f2e777e3 100644 --- a/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go +++ b/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go @@ -81,6 +81,21 @@ func (mr *MockCmdMockRecorder) Args() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Args", reflect.TypeOf((*MockCmd)(nil).Args)) } +// CombinedOutput mocks base method. +func (m *MockCmd) CombinedOutput() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CombinedOutput") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CombinedOutput indicates an expected call of CombinedOutput. +func (mr *MockCmdMockRecorder) CombinedOutput() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CombinedOutput", reflect.TypeOf((*MockCmd)(nil).CombinedOutput)) +} + // KillProcess mocks base method. func (m *MockCmd) KillProcess() error { m.ctrl.T.Helper() @@ -95,6 +110,21 @@ func (mr *MockCmdMockRecorder) KillProcess() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KillProcess", reflect.TypeOf((*MockCmd)(nil).KillProcess)) } +// Output mocks base method. +func (m *MockCmd) Output() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Output") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Output indicates an expected call of Output. +func (mr *MockCmdMockRecorder) Output() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Output", reflect.TypeOf((*MockCmd)(nil).Output)) +} + // Run mocks base method. func (m *MockCmd) Run() error { m.ctrl.T.Helper()