diff --git a/common/tags.go b/common/tags.go index db1fecf..fab50f5 100644 --- a/common/tags.go +++ b/common/tags.go @@ -1,5 +1,22 @@ // Package common provides constants that are used among different dials sources package common -// DialsTagName is the name of the dials tag. -const DialsTagName = "dials" +const ( + // DialsTagName is the name of the dials tag. + DialsTagName = "dials" + + // DialsEnvTagName is the name of the dialsenv tag. + DialsEnvTagName = "dialsenv" + + // DialsFlagTagName is the name of the dialsflag tag. + DialsFlagTagName = "dialsflag" + + // DialsPFlagTagName is the name of the dialspflag tag. + DialsPFlagTag = "dialspflag" + + // DialsFlagAliasTag is the name of the dialsflagalias tag. + DialsPFlagShortTag = "dialspflagshort" + + // HelpTextTag is the name of the struct tag for flag descriptions + DialsHelpTextTag = "dialsdesc" +) diff --git a/ez/ez.go b/ez/ez.go index cbd77be..bdc18a4 100644 --- a/ez/ez.go +++ b/ez/ez.go @@ -171,16 +171,6 @@ func ConfigFileEnvFlagDecoderFactoryParams[T any, TP ConfigWithConfigPath[T]](ct flagSrc = fset } - // If file-watching is not enabled, we should shutdown the monitor - // goroutine when exiting this function. - // Usually `dials.Config` is smart enough not to start a monitor when - // there are no `Watcher` implementations in the source-list, but the - // `Blank` source uses `Watcher` for its core functionality, so we need - // to shutdown the blank source to actually clean up resources. - if !params.WatchConfigFile { - defer blank.Done(ctx) - } - dp := dials.Params[T]{ // Set the OnNewConfig callback. It'll be suppressed by the // CallGlobalCallbacksAfterVerificationEnabled until just before we return. @@ -199,6 +189,16 @@ func ConfigFileEnvFlagDecoderFactoryParams[T any, TP ConfigWithConfigPath[T]](ct return nil, err } + // If file-watching is not enabled, we should shutdown the monitor + // goroutine when exiting this function. + // Usually `dials.Config` is smart enough not to start a monitor when + // there are no `Watcher` implementations in the source-list, but the + // `Blank` source uses `Watcher` for its core functionality, so we need + // to shutdown the blank source to actually clean up resources. + if !params.WatchConfigFile { + defer blank.Done(ctx) + } + basecfg := d.View() cfgPath, filepathSet := (TP)(basecfg).ConfigPath() if !filepathSet { @@ -219,7 +219,8 @@ func ConfigFileEnvFlagDecoderFactoryParams[T any, TP ConfigWithConfigPath[T]](ct return nil, fmt.Errorf("decoderFactory provided a nil decoder for path: %s", cfgPath) } - manglers := make([]transform.Mangler, 0, 2) + manglers := make([]transform.Mangler, 0, 3) + manglers = append(manglers, &transform.AliasMangler{}) if params.FileFieldNameEncoder != nil { tagDecoder := params.DialsTagNameDecoder diff --git a/ez/ez_test.go b/ez/ez_test.go index 60440c1..66a7f37 100644 --- a/ez/ez_test.go +++ b/ez/ez_test.go @@ -18,7 +18,7 @@ import ( type config struct { // Path will contain the path to the config file and will be set by // environment variable - Path string `dials:"CONFIGPATH"` + Path string `dials:"CONFIGPATH" dialsalias:"ALTCONFIGPATH"` Val1 int `dials:"Val1"` Val2 string `dials:"Val2"` Set map[string]struct{} `dials:"Set"` @@ -59,6 +59,33 @@ func TestYAMLConfigEnvFlagWithValidConfig(t *testing.T) { assert.EqualValues(t, expectedConfig, *populatedConf) } +func TestYAMLConfigEnvFlagWithValidConfigAndAlias(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + envErr := os.Setenv("ALTCONFIGPATH", "../testhelper/testconfig.yaml") + require.NoError(t, envErr) + defer os.Unsetenv("ALTCONFIGPATH") + + c := &config{} + view, dialsErr := YAMLConfigEnvFlag(ctx, c, Params[config]{}) + require.NoError(t, dialsErr) + + // Val1 and Val2 come from the config file and Path will be populated from env variable + expectedConfig := config{ + Path: "../testhelper/testconfig.yaml", + Val1: 456, + Val2: "hello-world", + Set: map[string]struct{}{ + "Keith": {}, + "Gary": {}, + "Jack": {}, + }, + } + populatedConf := view.View() + assert.EqualValues(t, expectedConfig, *populatedConf) +} + type beatlesConfig struct { YAMLPath string BeatlesMembers map[string]string diff --git a/sources/env/env.go b/sources/env/env.go index 7e619ea..ff2d443 100644 --- a/sources/env/env.go +++ b/sources/env/env.go @@ -13,8 +13,6 @@ import ( "github.com/vimeo/dials/transform" ) -const envTagName = "dialsenv" - // Source implements the dials.Source interface to set configuration from // environment variables. type Source struct { @@ -36,10 +34,12 @@ func (e *Source) Value(_ context.Context, t *dials.Type) (reflect.Value, error) // reformat the tags so they are SCREAMING_SNAKE_CASE reformatTagMangler := tagformat.NewTagReformattingMangler(common.DialsTagName, caseconversion.DecodeGoTags, caseconversion.EncodeUpperSnakeCase) // copy tags from "dials" to "dialsenv" tag - tagCopyingMangler := &tagformat.TagCopyingMangler{SrcTag: common.DialsTagName, NewTag: envTagName} + tagCopyingMangler := &tagformat.TagCopyingMangler{SrcTag: common.DialsTagName, NewTag: common.DialsEnvTagName} // convert all the fields in the flattened struct to string type so the environment variables can be set stringCastingMangler := &transform.StringCastingMangler{} - tfmr := transform.NewTransformer(t.Type(), flattenMangler, reformatTagMangler, tagCopyingMangler, stringCastingMangler) + // allow aliasing to migrate from one name to another + aliasMangler := &transform.AliasMangler{} + tfmr := transform.NewTransformer(t.Type(), aliasMangler, flattenMangler, reformatTagMangler, tagCopyingMangler, stringCastingMangler) val, err := tfmr.Translate() if err != nil { @@ -49,11 +49,11 @@ func (e *Source) Value(_ context.Context, t *dials.Type) (reflect.Value, error) valType := val.Type() for i := 0; i < val.NumField(); i++ { sf := valType.Field(i) - envTagVal := sf.Tag.Get(envTagName) + envTagVal := sf.Tag.Get(common.DialsEnvTagName) if envTagVal == "" { // dialsenv tag should be populated because dials tag is populated // after flatten mangler and we copy from dials to dialsenv tag - panic(fmt.Errorf("empty %s tag for field name %s", envTagName, sf.Name)) + panic(fmt.Errorf("empty %s tag for field name %s", common.DialsEnvTagName, sf.Name)) } if e.Prefix != "" { diff --git a/sources/flag/flag.go b/sources/flag/flag.go index b6948dd..a065936 100644 --- a/sources/flag/flag.go +++ b/sources/flag/flag.go @@ -67,8 +67,6 @@ var ( _ dials.Source = (*Set)(nil) ) -const dialsFlagTag = "dialsflag" - // NameConfig defines the parameters for separating components of a flag-name type NameConfig struct { // FieldNameEncodeCasing is for the field names used by the flatten mangler @@ -204,7 +202,7 @@ func (s *Set) parse() error { func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error { fm := transform.NewFlattenMangler(common.DialsTagName, s.NameCfg.FieldNameEncodeCasing, s.NameCfg.TagEncodeCasing) - tfmr := transform.NewTransformer(ptyp, fm) + tfmr := transform.NewTransformer(ptyp, &transform.AliasMangler{}, fm) val, TrnslErr := tfmr.Translate() if TrnslErr != nil { return TrnslErr @@ -241,7 +239,7 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error { // If the field's dialsflag tag is a hyphen (ex: `dialsflag:"-"`), // don't register the flag. Currently nested fields with "-" tag will // still be registered - if dft, ok := sf.Tag.Lookup(dialsFlagTag); ok && (dft == "-") { + if dft, ok := sf.Tag.Lookup(common.DialsFlagTagName); ok && (dft == "-") { continue } @@ -506,7 +504,7 @@ func willOverflow(val, target reflect.Value) bool { // decoded field name and converting it into kebab case func (s *Set) mkname(sf reflect.StructField) string { // use the name from the dialsflag tag for the flag name - if name, ok := sf.Tag.Lookup(dialsFlagTag); ok { + if name, ok := sf.Tag.Lookup(common.DialsFlagTagName); ok { return name } // check if the dials tag is populated (it should be once it goes through diff --git a/sources/pflag/pflag.go b/sources/pflag/pflag.go index 410933c..269ee7a 100644 --- a/sources/pflag/pflag.go +++ b/sources/pflag/pflag.go @@ -70,13 +70,9 @@ var ( ) const ( - dialsPFlagTag = "dialspflag" - dialsPFlagShortTag = "dialspflagshort" - // HelpTextTag is the name of the struct tag for flag descriptions - HelpTextTag = "dialsdesc" // DefaultFlagHelpText is the default help-text for fields with an // unset dialsdesc tag. - DefaultFlagHelpText = "unset description (`" + HelpTextTag + "` struct tag)" + DefaultFlagHelpText = "unset description (`" + common.DialsHelpTextTag + "` struct tag)" ) // NameConfig defines the parameters for separating components of a flag-name @@ -214,7 +210,7 @@ func (s *Set) parse() error { func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error { fm := transform.NewFlattenMangler(common.DialsTagName, s.NameCfg.FieldNameEncodeCasing, s.NameCfg.TagEncodeCasing) - tfmr := transform.NewTransformer(ptyp, fm) + tfmr := transform.NewTransformer(ptyp, &transform.AliasMangler{}, fm) val, TrnslErr := tfmr.Translate() if TrnslErr != nil { return TrnslErr @@ -235,7 +231,7 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error { for i := 0; i < t.NumField(); i++ { sf := t.Field(i) help := DefaultFlagHelpText - if x, ok := sf.Tag.Lookup(HelpTextTag); ok { + if x, ok := sf.Tag.Lookup(common.DialsHelpTextTag); ok { help = x } @@ -251,7 +247,7 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error { // If the field's dialspflag tag is a hyphen (ex: `dialspflag:"-"`), // don't register the flag. Currently nested fields with "-" tag will // still be registered - if dpt, ok := sf.Tag.Lookup(dialsPFlagTag); ok && (dpt == "-") { + if dpt, ok := sf.Tag.Lookup(common.DialsPFlagTag); ok && (dpt == "-") { continue } @@ -267,7 +263,7 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error { // get the concrete value of the field from the template fieldVal := transform.GetField(sf, tmpl) - shorthand, _ := sf.Tag.Lookup(dialsPFlagShortTag) + shorthand, _ := sf.Tag.Lookup(common.DialsPFlagShortTag) var f interface{} switch { @@ -516,7 +512,7 @@ func stripTypePtr(t reflect.Type) reflect.Type { // decoded field name and converting it into kebab case func (s *Set) mkname(sf reflect.StructField) string { // use the name from the dialspflag tag for the flag name - if name, ok := sf.Tag.Lookup(dialsPFlagTag); ok { + if name, ok := sf.Tag.Lookup(common.DialsPFlagTag); ok { return name } // check if the dials tag is populated (it should be once it goes through diff --git a/transform/alias_mangler.go b/transform/alias_mangler.go new file mode 100644 index 0000000..775da6a --- /dev/null +++ b/transform/alias_mangler.go @@ -0,0 +1,146 @@ +package transform + +import ( + "fmt" + "reflect" + + "github.com/fatih/structtag" + "github.com/vimeo/dials/common" +) + +const ( + dialsAliasTagSuffix = "alias" + aliasFieldSuffix = "_alias9wr876rw3" // a random string to append to the alias field to avoid collisions +) + +// the list of tags that we should search for aliases +var aliasSourceTags = []string{ + common.DialsTagName, + common.DialsFlagTagName, + common.DialsEnvTagName, + common.DialsFlagTagName, + common.DialsPFlagTag, + common.DialsPFlagShortTag, +} + +// AliasMangler manages aliases for dials, dialsenv, dialsflag, and dialspflag +// struct tags to make it possible to migrate from one name to another +// conveniently. +type AliasMangler struct{} + +// Mangle implements the Mangler interface. If an alias tag is defined, the +// struct field will be copied with the non-aliased tag set to the alias's +// value. +func (a AliasMangler) Mangle(sf reflect.StructField) ([]reflect.StructField, error) { + originalVals := map[string]string{} + aliasVals := map[string]string{} + + sfTags, parseErr := structtag.Parse(string(sf.Tag)) + if parseErr != nil { + return nil, fmt.Errorf("error parsing source tags %w", parseErr) + } + + anyAliasFound := false + for _, tag := range aliasSourceTags { + if originalVal, getErr := sfTags.Get(tag); getErr == nil { + originalVals[tag] = originalVal.Name + } + + if aliasVal, getErr := sfTags.Get(tag + dialsAliasTagSuffix); getErr == nil { + aliasVals[tag] = aliasVal.Name + anyAliasFound = true + + // remove the alias tag from the definition + sfTags.Delete(tag + dialsAliasTagSuffix) + } + } + + if !anyAliasFound { + // we didn't find any aliases so just get out early + return []reflect.StructField{sf}, nil + } + + aliasField := sf + aliasField.Name += aliasFieldSuffix + + // now that we've copied it, reset the struct tags on the source field to + // not include the alias tags + sf.Tag = reflect.StructTag(sfTags.String()) + + tags, parseErr := structtag.Parse(string(aliasField.Tag)) + if parseErr != nil { + return nil, fmt.Errorf("error parsing struct tags: %w", parseErr) + } + + for _, tag := range aliasSourceTags { + // remove the alias tag so it's not left on the copied StructField + tags.Delete(tag + dialsAliasTagSuffix) + + if aliasVals[tag] == "" { + // if the particular flag isn't set at all just move on... + continue + } + + newDialsTag := &structtag.Tag{ + Key: tag, + Name: aliasVals[tag], + } + + if setErr := tags.Set(newDialsTag); setErr != nil { + return nil, fmt.Errorf("error setting new value for dials tag: %w", setErr) + } + + // update dialsdesc if there is one + if desc, getErr := tags.Get("dialsdesc"); getErr == nil { + newDesc := &structtag.Tag{ + Key: "dialsdesc", + Name: desc.Name + " (alias of `" + originalVals[tag] + "`)", + } + if setErr := tags.Set(newDesc); setErr != nil { + return nil, fmt.Errorf("error setting amended dialsdesc for tag %q: %w", tag, setErr) + } + } + } + + // set the new flags on the alias field + aliasField.Tag = reflect.StructTag(tags.String()) + + return []reflect.StructField{sf, aliasField}, nil +} + +// Unmangle implements the Mangler interface and unwinds the alias copying +// operation. Note that if both the source and alias are both set in the +// configuration, an error will be returned. +func (a AliasMangler) Unmangle(sf reflect.StructField, fvs []FieldValueTuple) (reflect.Value, error) { + switch len(fvs) { + case 1: + // if there's only one tuple that means there was no alias, so just + // return... + return fvs[0].Value, nil + case 2: + // two means there's an alias so we should continue on... + default: + return reflect.Value{}, fmt.Errorf("expected 1 or 2 tuples, got %d", len(fvs)) + } + + if !fvs[0].Value.IsNil() && !fvs[1].Value.IsNil() { + return reflect.Value{}, fmt.Errorf("both alias and original set for field %q", sf.Name) + } + + // return the first one that isn't nil + for _, fv := range fvs { + if !fv.Value.IsNil() { + return fv.Value, nil + } + } + + // if we made it this far, they were both nil, which is fine -- just return + // one of them. + return fvs[0].Value, nil +} + +// ShouldRecurse is called after Mangle for each field so nested struct +// fields get iterated over after any transformation done by Mangle(). +func (a AliasMangler) ShouldRecurse(_ reflect.StructField) bool { + return true +} diff --git a/transform/alias_mangler_test.go b/transform/alias_mangler_test.go new file mode 100644 index 0000000..25f7ec8 --- /dev/null +++ b/transform/alias_mangler_test.go @@ -0,0 +1,156 @@ +package transform + +import ( + "reflect" + "testing" + + "github.com/fatih/structtag" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAliasManglerMangle(t *testing.T) { + + for testName, itbl := range map[string]struct { + tag string + expectedOrig map[string]string + expectedAlias map[string]string + }{ + "dialsOnly": { + tag: `dials:"name" dialsalias:"anothername"`, + expectedOrig: map[string]string{"dials": "name"}, + expectedAlias: map[string]string{"dials": "anothername"}, + }, + "withDialsDesc": { + tag: `dials:"name" dialsalias:"anothername" dialsdesc:"the name for this"`, + expectedOrig: map[string]string{"dials": "name", "dialsdesc": "the name for this"}, + expectedAlias: map[string]string{"dials": "anothername", "dialsdesc": "the name for this (alias of `name`)"}, + }, + "dialsDialsEnvFlagPFlag": { + tag: `dials:"name" dialsflag:"flagname" dialspflag:"pflagname" dialsenv:"envname" dialsalias:"anothername" dialsflagalias:"flagalias" dialspflagalias:"pflagalias" dialsenvalias:"envalias"`, + expectedOrig: map[string]string{"dials": "name", "dialsflag": "flagname", "dialspflag": "pflagname", "dialsenv": "envname"}, + expectedAlias: map[string]string{"dials": "anothername", "dialsflag": "flagalias", "dialspflag": "pflagalias", "dialsenv": "envalias"}, + }, + } { + tbl := itbl + t.Run(testName, func(t *testing.T) { + sf := reflect.StructField{ + Name: "Foo", + Type: reflect.TypeOf(""), + Tag: reflect.StructTag(tbl.tag), + } + + aliasMangler := AliasMangler{} + fields, mangleErr := aliasMangler.Mangle(sf) + require.NoError(t, mangleErr) + + require.Len(t, fields, 2) + + originalTags, parseErr := structtag.Parse(string(fields[0].Tag)) + require.NoError(t, parseErr) + + for k, v := range tbl.expectedOrig { + val, err := originalTags.Get(k) + require.NoError(t, err) + assert.Equal(t, v, val.Name) + } + assert.Equal(t, len(tbl.expectedOrig), originalTags.Len()) + + aliasTags, parseErr := structtag.Parse(string(fields[1].Tag)) + require.NoError(t, parseErr) + + for k, v := range tbl.expectedAlias { + val, err := aliasTags.Get(k) + require.NoError(t, err) + assert.Equal(t, v, val.Name) + } + assert.Equal(t, len(tbl.expectedAlias), aliasTags.Len()) + }) + } +} + +func TestAliasManglerUnmangle(t *testing.T) { + sf := reflect.StructField{ + Name: "Foo", + Type: reflect.TypeOf(""), + } + + num := 42 + var nilInt *int + + aliasMangler := &AliasMangler{} + + originalSet := []FieldValueTuple{ + { + Field: sf, + Value: reflect.ValueOf(&num), + }, + { + Field: sf, + Value: reflect.ValueOf(nilInt), + }, + } + + val, err := aliasMangler.Unmangle(sf, originalSet) + require.NoError(t, err) + + assert.Equal(t, 42, val.Elem().Interface()) + + aliasSet := []FieldValueTuple{ + { + Field: sf, + Value: reflect.ValueOf(nilInt), + }, + { + Field: sf, + Value: reflect.ValueOf(&num), + }, + } + + val, err = aliasMangler.Unmangle(sf, aliasSet) + require.NoError(t, err) + + assert.Equal(t, 42, val.Elem().Interface()) + + bothSet := []FieldValueTuple{ + { + Field: sf, + Value: reflect.ValueOf(&num), + }, + { + Field: sf, + Value: reflect.ValueOf(&num), + }, + } + + _, err = aliasMangler.Unmangle(sf, bothSet) + assert.NotNil(t, err) // there should be an error if both are set! + + neitherSet := []FieldValueTuple{ + { + Field: sf, + Value: reflect.ValueOf(nilInt), + }, + { + Field: sf, + Value: reflect.ValueOf(nilInt), + }, + } + + val, err = aliasMangler.Unmangle(sf, neitherSet) + require.NoError(t, err) + + assert.True(t, val.IsNil()) + + noAlias := []FieldValueTuple{ + { + Field: sf, + Value: reflect.ValueOf(&num), + }, + } + + val, err = aliasMangler.Unmangle(sf, noAlias) + require.NoError(t, err) + + assert.Equal(t, 42, val.Elem().Interface()) +} diff --git a/transform/transformer.go b/transform/transformer.go index 3e98da6..cb345ae 100644 --- a/transform/transformer.go +++ b/transform/transformer.go @@ -128,12 +128,20 @@ func (t *Transformer) maybeRecursivelyMangle(mangler Mangler, state *transformMa if !mangler.ShouldRecurse(field) { continue } + ft := field.Type + + // also don't recurse into TextUnarshaler types + if ft.Implements(textMReflectType) || reflect.PointerTo(ft).Implements(textMReflectType) { + continue + } + // strip any outer pointerification, slice or array switch ft.Kind() { case reflect.Ptr, reflect.Array, reflect.Slice: ft = ft.Elem() } + fieldTransformer := Transformer{ manglers: []Mangler{mangler}, mState: nil,