diff --git a/assert/assert_ext_test.go b/assert/assert_ext_test.go new file mode 100644 index 00000000..5903f70f --- /dev/null +++ b/assert/assert_ext_test.go @@ -0,0 +1,112 @@ +package assert_test + +import ( + "go/parser" + "go/token" + "io/ioutil" + "runtime" + "strings" + "testing" + + "gotest.tools/v3/assert" + "gotest.tools/v3/internal/source" +) + +func TestEqual_WithGoldenUpdate(t *testing.T) { + t.Run("assert failed with -update=false", func(t *testing.T) { + ft := &fakeTestingT{} + actual := `not this value` + assert.Equal(ft, actual, expectedOne) + assert.Assert(t, ft.failNowed) + }) + + t.Run("var is updated when -update=true", func(t *testing.T) { + patchUpdate(t) + t.Cleanup(func() { + resetVariable(t, "expectedOne", "") + }) + + actual := `this is the +actual value +that we are testing +` + assert.Equal(t, actual, expectedOne) + + raw, err := ioutil.ReadFile(fileName(t)) + assert.NilError(t, err) + + expected := "var expectedOne = `this is the\nactual value\nthat we are testing\n`" + assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw)) + }) + + t.Run("const is updated when -update=true", func(t *testing.T) { + patchUpdate(t) + t.Cleanup(func() { + resetVariable(t, "expectedTwo", "") + }) + + actual := `this is the new +expected value +` + assert.Equal(t, actual, expectedTwo) + + raw, err := ioutil.ReadFile(fileName(t)) + assert.NilError(t, err) + + expected := "const expectedTwo = `this is the new\nexpected value\n`" + assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw)) + }) +} + +// expectedOne is updated by running the tests with -update +var expectedOne = `` + +// expectedTwo is updated by running the tests with -update +const expectedTwo = `` + +func patchUpdate(t *testing.T) { + source.Update = true + t.Cleanup(func() { + source.Update = false + }) +} + +func fileName(t *testing.T) string { + t.Helper() + _, filename, _, ok := runtime.Caller(1) + assert.Assert(t, ok, "failed to get call stack") + return filename +} + +func resetVariable(t *testing.T, varName string, value string) { + t.Helper() + _, filename, _, ok := runtime.Caller(1) + assert.Assert(t, ok, "failed to get call stack") + + fileset := token.NewFileSet() + astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments) + assert.NilError(t, err) + + err = source.UpdateVariable(filename, fileset, astFile, varName, value) + assert.NilError(t, err, "failed to reset file") +} + +type fakeTestingT struct { + failNowed bool + failed bool + msgs []string +} + +func (f *fakeTestingT) FailNow() { + f.failNowed = true +} + +func (f *fakeTestingT) Fail() { + f.failed = true +} + +func (f *fakeTestingT) Log(args ...interface{}) { + f.msgs = append(f.msgs, args[0].(string)) +} + +func (f *fakeTestingT) Helper() {} diff --git a/assert/cmp/compare.go b/assert/cmp/compare.go index 1f42bd0c..78f76e4e 100644 --- a/assert/cmp/compare.go +++ b/assert/cmp/compare.go @@ -35,7 +35,7 @@ func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison { if diff == "" { return ResultSuccess } - return multiLineDiffResult(diff) + return multiLineDiffResult(diff, x, y) } } @@ -102,7 +102,7 @@ func Equal(x, y interface{}) Comparison { return ResultSuccess case isMultiLineStringCompare(x, y): diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)}) - return multiLineDiffResult(diff) + return multiLineDiffResult(diff, x, y) } return ResultFailureTemplate(` {{- printf "%v" .Data.x}} ( @@ -128,12 +128,12 @@ func isMultiLineStringCompare(x, y interface{}) bool { return strings.Contains(strX, "\n") || strings.Contains(strY, "\n") } -func multiLineDiffResult(diff string) Result { +func multiLineDiffResult(diff string, x, y interface{}) Result { return ResultFailureTemplate(` --- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}} +++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}} {{ .Data.diff }}`, - map[string]interface{}{"diff": diff}) + map[string]interface{}{"diff": diff, "x": x, "y": y}) } // Len succeeds if the sequence has the expected length. diff --git a/assert/cmp/result.go b/assert/cmp/result.go index 2b0eb7e3..28ef8d3d 100644 --- a/assert/cmp/result.go +++ b/assert/cmp/result.go @@ -69,6 +69,11 @@ func (r templatedResult) FailureMessage(args []ast.Expr) string { return msg } +func (r templatedResult) UpdatedExpected(stackIndex int) error { + // TODO: would be nice to have structured data instead of a map + return source.UpdateExpectedValue(stackIndex+1, r.data["x"], r.data["y"]) +} + // ResultFailureTemplate returns a Result with a template string and data which // can be used to format a failure message. The template may access data from .Data, // the comparison args with the callArg function, and the formatNode function may diff --git a/golden/golden.go b/golden/golden.go index 72ca05ff..47ea85fe 100644 --- a/golden/golden.go +++ b/golden/golden.go @@ -18,13 +18,11 @@ import ( "gotest.tools/v3/assert" "gotest.tools/v3/assert/cmp" "gotest.tools/v3/internal/format" + "gotest.tools/v3/internal/source" ) -var flagUpdate bool - func init() { - flag.BoolVar(&flagUpdate, "update", false, "update golden files") - flag.BoolVar(&flagUpdate, "test.update-golden", false, "deprecated flag") + flag.BoolVar(&source.Update, "test.update-golden", false, "deprecated flag") } type helperT interface { @@ -46,7 +44,7 @@ var NormalizeCRLFToLF = os.Getenv("GOTESTTOOLS_GOLDEN_NormalizeCRLFToLF") != "fa // FlagUpdate returns true when the -update flag has been set. func FlagUpdate() bool { - return flagUpdate + return source.Update } // Open opens the file in ./testdata @@ -180,7 +178,7 @@ func compare(actual []byte, filename string) (cmp.Result, []byte) { } func update(filename string, actual []byte) error { - if !flagUpdate { + if !source.Update { return nil } if dir := filepath.Dir(Path(filename)); dir != "." { diff --git a/golden/golden_test.go b/golden/golden_test.go index 54e807c5..3b0bd027 100644 --- a/golden/golden_test.go +++ b/golden/golden_test.go @@ -9,6 +9,7 @@ import ( "gotest.tools/v3/assert" "gotest.tools/v3/assert/cmp" "gotest.tools/v3/fs" + "gotest.tools/v3/internal/source" ) type fakeT struct { @@ -190,10 +191,10 @@ func TestGoldenAssertBytes(t *testing.T) { } func setUpdateFlag(t *testing.T) func() { - orig := flagUpdate - flagUpdate = true + orig := source.Update + source.Update = true undo := func() { - flagUpdate = orig + source.Update = orig } t.Cleanup(undo) return undo diff --git a/icmd/command_test.go b/icmd/command_test.go index 5619c3bc..1a5fef93 100644 --- a/icmd/command_test.go +++ b/icmd/command_test.go @@ -12,7 +12,6 @@ import ( exec "golang.org/x/sys/execabs" "gotest.tools/v3/assert" "gotest.tools/v3/fs" - "gotest.tools/v3/golden" "gotest.tools/v3/internal/maint" ) @@ -120,9 +119,22 @@ func TestResult_Match_NotMatched(t *testing.T) { } err := result.match(exp) assert.ErrorContains(t, err, "Failures") - golden.Assert(t, err.Error(), "result-match-no-match.golden") + assert.Equal(t, err.Error(), expectedMatch) } +var expectedMatch = ` +Command: binary arg1 +ExitCode: 99 (timeout) +Error: exit code 99 +Stdout: the output +Stderr: the stderr + +Failures: +ExitCode was 99 expected 101 +Expected command to finish, but it hit the timeout +Expected stdout to contain "Something else" +Expected stderr to contain "[NOTHING]"` + func newLockedBuffer(s string) *lockedBuffer { return &lockedBuffer{buf: *bytes.NewBufferString(s)} } @@ -140,9 +152,20 @@ func TestResult_Match_NotMatchedNoError(t *testing.T) { } err := result.match(exp) assert.ErrorContains(t, err, "Failures") - golden.Assert(t, err.Error(), "result-match-no-match-no-error.golden") + assert.Equal(t, err.Error(), expectedResultMatchNoMatch) } +var expectedResultMatchNoMatch = ` +Command: binary arg1 +ExitCode: 0 +Stdout: the output +Stderr: the stderr + +Failures: +ExitCode was 0 expected 101 +Expected stdout to contain "Something else" +Expected stderr to contain "[NOTHING]"` + func TestResult_Match_Match(t *testing.T) { result := &Result{ Cmd: exec.Command("binary", "arg1"), diff --git a/icmd/testdata/result-match-no-match-no-error.golden b/icmd/testdata/result-match-no-match-no-error.golden deleted file mode 100644 index 162d7665..00000000 --- a/icmd/testdata/result-match-no-match-no-error.golden +++ /dev/null @@ -1,10 +0,0 @@ - -Command: binary arg1 -ExitCode: 0 -Stdout: the output -Stderr: the stderr - -Failures: -ExitCode was 0 expected 101 -Expected stdout to contain "Something else" -Expected stderr to contain "[NOTHING]" \ No newline at end of file diff --git a/icmd/testdata/result-match-no-match.golden b/icmd/testdata/result-match-no-match.golden deleted file mode 100644 index 819f9fdd..00000000 --- a/icmd/testdata/result-match-no-match.golden +++ /dev/null @@ -1,12 +0,0 @@ - -Command: binary arg1 -ExitCode: 99 (timeout) -Error: exit code 99 -Stdout: the output -Stderr: the stderr - -Failures: -ExitCode was 99 expected 101 -Expected command to finish, but it hit the timeout -Expected stdout to contain "Something else" -Expected stderr to contain "[NOTHING]" \ No newline at end of file diff --git a/internal/assert/result.go b/internal/assert/result.go index 20cd5412..36032061 100644 --- a/internal/assert/result.go +++ b/internal/assert/result.go @@ -1,6 +1,7 @@ package assert import ( + "errors" "fmt" "go/ast" @@ -25,6 +26,22 @@ func RunComparison( return true } + if source.Update { + if updater, ok := result.(updateExpected); ok { + const stackIndex = 3 // Assert/Check, assert, RunComparison + err := updater.UpdatedExpected(stackIndex) + switch { + case err == nil: + return true + case errors.Is(err, source.ErrNotFound): + // do nothing, fallthrough to regular failure message + default: + t.Log("failed to update source", err) + return false + } + } + } + var message string switch typed := result.(type) { case resultWithComparisonArgs: @@ -52,6 +69,10 @@ type resultBasic interface { FailureMessage() string } +type updateExpected interface { + UpdatedExpected(stackIndex int) error +} + // filterPrintableExpr filters the ast.Expr slice to only include Expr that are // easy to read when printed and contain relevant information to an assertion. // diff --git a/internal/source/defers.go b/internal/source/defers.go index 8e5a6fb7..392d9fe0 100644 --- a/internal/source/defers.go +++ b/internal/source/defers.go @@ -28,7 +28,7 @@ func guessDefer(node ast.Node) (ast.Node, error) { defers := collectDefers(node) switch len(defers) { case 0: - return nil, fmt.Errorf("failed to expression in defer") + return nil, fmt.Errorf("failed to find expression in defer") case 1: return defers[0].Call, nil default: diff --git a/internal/source/source.go b/internal/source/source.go index 4dbc1bc4..a3f70086 100644 --- a/internal/source/source.go +++ b/internal/source/source.go @@ -10,12 +10,8 @@ import ( "go/token" "os" "runtime" - "strconv" - "strings" ) -const baseStackIndex = 1 - // FormattedCallExprArg returns the argument from an ast.CallExpr at the // index in the call stack. The argument is formatted using FormatNode. func FormattedCallExprArg(stackIndex int, argPos int) (string, error) { @@ -32,28 +28,26 @@ func FormattedCallExprArg(stackIndex int, argPos int) (string, error) { // CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at // the index in the call stack. func CallExprArgs(stackIndex int) ([]ast.Expr, error) { - _, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex) + _, filename, line, ok := runtime.Caller(stackIndex + 1) if !ok { return nil, errors.New("failed to get call stack") } - debug("call stack position: %s:%d", filename, lineNum) + debug("call stack position: %s:%d", filename, line) - node, err := getNodeAtLine(filename, lineNum) - if err != nil { - return nil, err - } - debug("found node: %s", debugFormatNode{node}) - - return getCallExprArgs(node) -} - -func getNodeAtLine(filename string, lineNum int) (ast.Node, error) { fileset := token.NewFileSet() astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors) if err != nil { return nil, fmt.Errorf("failed to parse source file %s: %w", filename, err) } + expr, err := getCallExprArgs(fileset, astFile, line) + if err != nil { + return nil, fmt.Errorf("call from %s:%d: %w", filename, line, err) + } + return expr, nil +} + +func getNodeAtLine(fileset *token.FileSet, astFile ast.Node, lineNum int) (ast.Node, error) { if node := scanToLine(fileset, astFile, lineNum); node != nil { return node, nil } @@ -63,8 +57,7 @@ func getNodeAtLine(filename string, lineNum int) (ast.Node, error) { return node, err } } - return nil, fmt.Errorf( - "failed to find an expression on line %d in %s", lineNum, filename) + return nil, nil } func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { @@ -73,7 +66,7 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { switch { case node == nil || matchedNode != nil: return false - case nodePosition(fileset, node).Line == lineNum: + case fileset.Position(node.Pos()).Line == lineNum: matchedNode = node return false } @@ -82,46 +75,17 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { return matchedNode } -// In golang 1.9 the line number changed from being the line where the statement -// ended to the line where the statement began. -func nodePosition(fileset *token.FileSet, node ast.Node) token.Position { - if goVersionBefore19 { - return fileset.Position(node.End()) - } - return fileset.Position(node.Pos()) -} - -// GoVersionLessThan returns true if runtime.Version() is semantically less than -// version major.minor. Returns false if a release version can not be parsed from -// runtime.Version(). -func GoVersionLessThan(major, minor int64) bool { - version := runtime.Version() - // not a release version - if !strings.HasPrefix(version, "go") { - return false - } - version = strings.TrimPrefix(version, "go") - parts := strings.Split(version, ".") - if len(parts) < 2 { - return false - } - rMajor, err := strconv.ParseInt(parts[0], 10, 32) - if err != nil { - return false - } - if rMajor != major { - return rMajor < major - } - rMinor, err := strconv.ParseInt(parts[1], 10, 32) - if err != nil { - return false +func getCallExprArgs(fileset *token.FileSet, astFile ast.Node, line int) ([]ast.Expr, error) { + node, err := getNodeAtLine(fileset, astFile, line) + switch { + case err != nil: + return nil, err + case node == nil: + return nil, fmt.Errorf("failed to find an expression") } - return rMinor < minor -} -var goVersionBefore19 = GoVersionLessThan(1, 9) + debug("found node: %s", debugFormatNode{node}) -func getCallExprArgs(node ast.Node) ([]ast.Expr, error) { visitor := &callExprVisitor{} ast.Walk(visitor, node) if visitor.expr == nil { @@ -172,6 +136,9 @@ type debugFormatNode struct { } func (n debugFormatNode) String() string { + if n.Node == nil { + return "none" + } out, err := FormatNode(n.Node) if err != nil { return fmt.Sprintf("failed to format %s: %s", n.Node, err) diff --git a/internal/source/update.go b/internal/source/update.go new file mode 100644 index 00000000..bd9678b8 --- /dev/null +++ b/internal/source/update.go @@ -0,0 +1,138 @@ +package source + +import ( + "bytes" + "errors" + "flag" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "runtime" + "strings" +) + +// Update is set by the -update flag. It indicates the user running the tests +// would like to update any golden values. +var Update bool + +func init() { + flag.BoolVar(&Update, "update", false, "update golden values") +} + +// ErrNotFound indicates that UpdateExpectedValue failed to find the +// variable to update, likely because it is not a package level variable. +var ErrNotFound = fmt.Errorf("failed to find variable for update of golden value") + +// UpdateExpectedValue looks for a package-level variable with a name that +// starts with expected in the arguments to the caller. If the variable is +// found, the value of the variable will be updated to value of the other +// argument to the caller. +func UpdateExpectedValue(stackIndex int, x, y interface{}) error { + _, filename, line, ok := runtime.Caller(stackIndex + 1) + if !ok { + return errors.New("failed to get call stack") + } + debug("call stack position: %s:%d", filename, line) + + fileset := token.NewFileSet() + astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments) + if err != nil { + return fmt.Errorf("failed to parse source file %s: %w", filename, err) + } + + expr, err := getCallExprArgs(fileset, astFile, line) + if err != nil { + return fmt.Errorf("call from %s:%d: %w", filename, line, err) + } + + if len(expr) < 3 { + debug("not enough arguments %d: %v", + len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}}) + return ErrNotFound + } + + argIndex, varName := getVarNameForExpectedValueArg(expr) + if argIndex < 0 || varName == "" { + debug("no arguments started with the word 'expected': %v", + debugFormatNode{Node: &ast.CallExpr{Args: expr}}) + return ErrNotFound + } + + value := x + if argIndex == 1 { + value = y + } + + strValue, ok := value.(string) + if !ok { + debug("value must be type string, got %T", value) + return ErrNotFound + } + return UpdateVariable(filename, fileset, astFile, varName, strValue) +} + +// UpdateVariable writes to filename the contents of astFile with the value of +// the variable updated to value. +func UpdateVariable( + filename string, + fileset *token.FileSet, + astFile *ast.File, + varName string, + value string, +) error { + obj := astFile.Scope.Objects[varName] + if obj == nil { + return ErrNotFound + } + if obj.Kind != ast.Con && obj.Kind != ast.Var { + debug("can only update var and const, found %v", obj.Kind) + return ErrNotFound + } + + spec, ok := obj.Decl.(*ast.ValueSpec) + if !ok { + debug("can only update *ast.ValueSpec, found %T", obj.Decl) + return ErrNotFound + } + if len(spec.Names) != 1 { + debug("more than one name in ast.ValueSpec") + return ErrNotFound + } + + spec.Values[0] = &ast.BasicLit{ + Kind: token.STRING, + Value: "`" + value + "`", + } + + var buf bytes.Buffer + if err := format.Node(&buf, fileset, astFile); err != nil { + return fmt.Errorf("failed to format file after update: %w", err) + } + + fh, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to open file %v: %w", filename, err) + } + if _, err = fh.Write(buf.Bytes()); err != nil { + return fmt.Errorf("failed to write file %v: %w", filename, err) + } + if err := fh.Sync(); err != nil { + return fmt.Errorf("failed to sync file %v: %w", filename, err) + } + return nil +} + +func getVarNameForExpectedValueArg(expr []ast.Expr) (int, string) { + for i := 1; i < 3; i++ { + switch e := expr[i].(type) { + case *ast.Ident: + if strings.HasPrefix(strings.ToLower(e.Name), "expected") { + return i, e.Name + } + } + } + return -1, "" +} diff --git a/internal/source/version.go b/internal/source/version.go new file mode 100644 index 00000000..5fa8a903 --- /dev/null +++ b/internal/source/version.go @@ -0,0 +1,35 @@ +package source + +import ( + "runtime" + "strconv" + "strings" +) + +// GoVersionLessThan returns true if runtime.Version() is semantically less than +// version major.minor. Returns false if a release version can not be parsed from +// runtime.Version(). +func GoVersionLessThan(major, minor int64) bool { + version := runtime.Version() + // not a release version + if !strings.HasPrefix(version, "go") { + return false + } + version = strings.TrimPrefix(version, "go") + parts := strings.Split(version, ".") + if len(parts) < 2 { + return false + } + rMajor, err := strconv.ParseInt(parts[0], 10, 32) + if err != nil { + return false + } + if rMajor != major { + return rMajor < major + } + rMinor, err := strconv.ParseInt(parts[1], 10, 32) + if err != nil { + return false + } + return rMinor < minor +}