diff --git a/ez/ez.go b/ez/ez.go index cbd77be..e201773 100644 --- a/ez/ez.go +++ b/ez/ez.go @@ -220,6 +220,7 @@ func ConfigFileEnvFlagDecoderFactoryParams[T any, TP ConfigWithConfigPath[T]](ct } manglers := make([]transform.Mangler, 0, 2) + // manglers[0] = &transform.AliasMangler{} if params.FileFieldNameEncoder != nil { tagDecoder := params.DialsTagNameDecoder diff --git a/ez/ez_test.go b/ez/ez_test.go index 60440c1..66a7f37 100644 --- a/ez/ez_test.go +++ b/ez/ez_test.go @@ -18,7 +18,7 @@ import ( type config struct { // Path will contain the path to the config file and will be set by // environment variable - Path string `dials:"CONFIGPATH"` + Path string `dials:"CONFIGPATH" dialsalias:"ALTCONFIGPATH"` Val1 int `dials:"Val1"` Val2 string `dials:"Val2"` Set map[string]struct{} `dials:"Set"` @@ -59,6 +59,33 @@ func TestYAMLConfigEnvFlagWithValidConfig(t *testing.T) { assert.EqualValues(t, expectedConfig, *populatedConf) } +func TestYAMLConfigEnvFlagWithValidConfigAndAlias(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + envErr := os.Setenv("ALTCONFIGPATH", "../testhelper/testconfig.yaml") + require.NoError(t, envErr) + defer os.Unsetenv("ALTCONFIGPATH") + + c := &config{} + view, dialsErr := YAMLConfigEnvFlag(ctx, c, Params[config]{}) + require.NoError(t, dialsErr) + + // Val1 and Val2 come from the config file and Path will be populated from env variable + expectedConfig := config{ + Path: "../testhelper/testconfig.yaml", + Val1: 456, + Val2: "hello-world", + Set: map[string]struct{}{ + "Keith": {}, + "Gary": {}, + "Jack": {}, + }, + } + populatedConf := view.View() + assert.EqualValues(t, expectedConfig, *populatedConf) +} + type beatlesConfig struct { YAMLPath string BeatlesMembers map[string]string diff --git a/sources/env/env.go b/sources/env/env.go index 7e619ea..b748b9c 100644 --- a/sources/env/env.go +++ b/sources/env/env.go @@ -39,7 +39,9 @@ func (e *Source) Value(_ context.Context, t *dials.Type) (reflect.Value, error) tagCopyingMangler := &tagformat.TagCopyingMangler{SrcTag: common.DialsTagName, NewTag: envTagName} // convert all the fields in the flattened struct to string type so the environment variables can be set stringCastingMangler := &transform.StringCastingMangler{} - tfmr := transform.NewTransformer(t.Type(), flattenMangler, reformatTagMangler, tagCopyingMangler, stringCastingMangler) + // allow aliasing to migrate from one name to another + aliasMangler := &transform.AliasMangler{} + tfmr := transform.NewTransformer(t.Type(), aliasMangler, flattenMangler, reformatTagMangler, tagCopyingMangler, stringCastingMangler) val, err := tfmr.Translate() if err != nil { diff --git a/transform/alias_mangler.go b/transform/alias_mangler.go new file mode 100644 index 0000000..4f26808 --- /dev/null +++ b/transform/alias_mangler.go @@ -0,0 +1,128 @@ +package transform + +import ( + "fmt" + "reflect" + + "github.com/fatih/structtag" +) + +const ( + dialsAliasTagSuffix = "alias" + aliasFieldSuffix = "_alias9wr876rw3" // a random string to append to the alias field to avoid collisions +) + +// AliasMangler manages aliases for dials, dialsenv, dialsflag, and dialspflag +// struct tags to make it possible to migrate from one name to another +// conveniently. +type AliasMangler struct{} + +// Mangle implements the Mangler interface. If an alias tag is defined, the +// struct field will be copied with the non-aliased tag set to the alias's +// value. +func (a AliasMangler) Mangle(sf reflect.StructField) ([]reflect.StructField, error) { + sourceTags := []string{"dials", "dialsenv", "dialsflag", "dialspflag"} + originalVals := map[string]string{} + aliasVals := map[string]string{} + + sfTags, parseErr := structtag.Parse(string(sf.Tag)) + if parseErr != nil { + return nil, fmt.Errorf("error parsing source tags %w", parseErr) + } + + anyAliasFound := false + for _, tag := range sourceTags { + if originalVal, getErr := sfTags.Get(tag); getErr == nil { + originalVals[tag] = originalVal.Name + } + + if aliasVal, getErr := sfTags.Get(tag + dialsAliasTagSuffix); getErr == nil { + aliasVals[tag] = aliasVal.Name + anyAliasFound = true + + // remove the alias tag from the definition + sfTags.Delete(tag + dialsAliasTagSuffix) + } + } + + if !anyAliasFound { + // we didn't find any aliases so just get out early + return []reflect.StructField{sf}, nil + } + + aliasField := sf + aliasField.Name += aliasFieldSuffix + + // now that we've copied it, reset the struct tags on the source field to + // not include the alias tags + sf.Tag = reflect.StructTag(sfTags.String()) + + tags, parseErr := structtag.Parse(string(aliasField.Tag)) + if parseErr != nil { + return nil, fmt.Errorf("error parsing struct tags: %w", parseErr) + } + + for _, tag := range sourceTags { + // remove the alias tag so it's not left on the copied StructField + tags.Delete(tag + dialsAliasTagSuffix) + + if aliasVals[tag] == "" { + // if the particular flag isn't set at all just move on... + continue + } + + newDialsTag := &structtag.Tag{ + Key: tag, + Name: aliasVals[tag], + } + + if setErr := tags.Set(newDialsTag); setErr != nil { + return nil, fmt.Errorf("error setting new value for dials tag: %w", setErr) + } + + // update dialsdesc if there is one + if desc, getErr := tags.Get("dialsdesc"); getErr == nil { + newDesc := &structtag.Tag{ + Key: "dialsdesc", + Name: desc.Name + " (alias of `" + originalVals[tag] + "`)", + } + if setErr := tags.Set(newDesc); setErr != nil { + return nil, fmt.Errorf("error setting amended dialsdesc for tag %q: %w", tag, setErr) + } + } + } + + // set the new flags on the alias field + aliasField.Tag = reflect.StructTag(tags.String()) + + return []reflect.StructField{sf, aliasField}, nil +} + +// Unmangle implements the Mangler interface and unwinds the alias copying +// operation. Note that if both the source and alias are both set in the +// configuration, an error will be returned. +func (a AliasMangler) Unmangle(sf reflect.StructField, fvs []FieldValueTuple) (reflect.Value, error) { + if len(fvs) != 2 { + return reflect.Value{}, fmt.Errorf("expected 2 tuples, got %d", len(fvs)) + } + + if !fvs[0].Value.IsNil() && !fvs[1].Value.IsNil() { + return reflect.Value{}, fmt.Errorf("both alias and original set for field %q", sf.Name) + } + + for _, fv := range fvs { + if !fv.Value.IsNil() { + return fv.Value, nil + } + } + + // if we made it this far, they were both nil, which is fine -- just return + // one of them. + return fvs[0].Value, nil +} + +// ShouldRecurse is called after Mangle for each field so nested struct +// fields get iterated over after any transformation done by Mangle(). +func (a AliasMangler) ShouldRecurse(_ reflect.StructField) bool { + return true +} diff --git a/transform/alias_mangler_test.go b/transform/alias_mangler_test.go new file mode 100644 index 0000000..f544260 --- /dev/null +++ b/transform/alias_mangler_test.go @@ -0,0 +1,148 @@ +package transform + +import ( + "reflect" + "testing" + + "github.com/fatih/structtag" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAliasManglerMangle(t *testing.T) { + + for testName, tbl := range map[string]struct { + tag string + expectedOrig map[string]string + expectedAlias map[string]string + }{ + "dialsOnly": { + tag: `dials:"name" dialsalias:"anothername"`, + expectedOrig: map[string]string{"dials": "name"}, + expectedAlias: map[string]string{"dials": "anothername"}, + }, + "withDialsDesc": { + tag: `dials:"name" dialsalias:"anothername" dialsdesc:"the name for this"`, + expectedOrig: map[string]string{"dials": "name", "dialsdesc": "the name for this"}, + expectedAlias: map[string]string{"dials": "anothername", "dialsdesc": "the name for this (alias of `name`)"}, + }, + "dialsDialsEnvFlagPFlag": { + tag: `dials:"name" dialsflag:"flagname" dialspflag:"pflagname" dialsenv:"envname" dialsalias:"anothername" dialsflagalias:"flagalias" dialspflagalias:"pflagalias" dialsenvalias:"envalias"`, + expectedOrig: map[string]string{"dials": "name", "dialsflag": "flagname", "dialspflag": "pflagname", "dialsenv": "envname"}, + expectedAlias: map[string]string{"dials": "anothername", "dialsflag": "flagalias", "dialspflag": "pflagalias", "dialsenv": "envalias"}, + }, + } { + t.Run(testName, func(t *testing.T) { + sf := reflect.StructField{ + Name: "Foo", + Type: reflect.TypeOf(""), + Tag: reflect.StructTag(tbl.tag), + } + + aliasMangler := AliasMangler{} + fields, mangleErr := aliasMangler.Mangle(sf) + require.NoError(t, mangleErr) + + require.Len(t, fields, 2) + + originalTags, parseErr := structtag.Parse(string(fields[0].Tag)) + require.NoError(t, parseErr) + + for k, v := range tbl.expectedOrig { + val, err := originalTags.Get(k) + require.NoError(t, err) + assert.Equal(t, v, val.Name) + } + assert.Equal(t, len(tbl.expectedOrig), originalTags.Len()) + + aliasTags, parseErr := structtag.Parse(string(fields[1].Tag)) + require.NoError(t, parseErr) + + for k, v := range tbl.expectedAlias { + val, err := aliasTags.Get(k) + require.NoError(t, err) + assert.Equal(t, v, val.Name) + } + assert.Equal(t, len(tbl.expectedAlias), aliasTags.Len()) + + // for _, f := range fields { + // t.Logf("name: %s tag: %s", f.Name, string(f.Tag)) + // } + }) + } +} + +func TestAliasManglerUnmangle(t *testing.T) { + sf := reflect.StructField{ + Name: "Foo", + Type: reflect.TypeOf(""), + } + + num := 42 + var nilInt *int + + aliasMangler := &AliasMangler{} + + originalSet := []FieldValueTuple{ + { + Field: sf, + Value: reflect.ValueOf(&num), + }, + { + Field: sf, + Value: reflect.ValueOf(nilInt), + }, + } + + val, err := aliasMangler.Unmangle(sf, originalSet) + require.NoError(t, err) + + assert.Equal(t, 42, val.Elem().Interface()) + + aliasSet := []FieldValueTuple{ + { + Field: sf, + Value: reflect.ValueOf(nilInt), + }, + { + Field: sf, + Value: reflect.ValueOf(&num), + }, + } + + val, err = aliasMangler.Unmangle(sf, aliasSet) + require.NoError(t, err) + + assert.Equal(t, 42, val.Elem().Interface()) + + bothSet := []FieldValueTuple{ + { + Field: sf, + Value: reflect.ValueOf(&num), + }, + { + Field: sf, + Value: reflect.ValueOf(&num), + }, + } + + _, err = aliasMangler.Unmangle(sf, bothSet) + assert.NotNil(t, err) // there should be an error if both are set! + + neitherSet := []FieldValueTuple{ + { + Field: sf, + Value: reflect.ValueOf(nilInt), + }, + { + Field: sf, + Value: reflect.ValueOf(nilInt), + }, + } + + val, err = aliasMangler.Unmangle(sf, neitherSet) + require.NoError(t, err) + + assert.True(t, val.IsNil()) + +}