From 8a8987758e45e87d535cde6f1418a85058a94007 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Tue, 3 Aug 2021 20:31:28 +0200 Subject: [PATCH] Add WriteAnswer support for promoted fields (#366) --- core/write.go | 54 +++++++++----- core/write_test.go | 178 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 193 insertions(+), 39 deletions(-) diff --git a/core/write.go b/core/write.go index 47d0e498..dfe9b499 100644 --- a/core/write.go +++ b/core/write.go @@ -24,6 +24,11 @@ type OptionAnswer struct { Index int } +type reflectField struct { + value reflect.Value + fieldType reflect.StructField +} + func OptionAnswerList(incoming []string) []OptionAnswer { list := []OptionAnswer{} for i, opt := range incoming { @@ -63,13 +68,12 @@ func WriteAnswer(t interface{}, name string, v interface{}) (err error) { } // get the name of the field that matches the string we were given - fieldIndex, err := findFieldIndex(elem, name) + field, _, err := findField(elem, name) // if something went wrong if err != nil { // bubble up return err } - field := elem.Field(fieldIndex) // handle references to the Settable interface aswell if s, ok := field.Interface().(Settable); ok { // use the interface method @@ -156,37 +160,51 @@ func IsFieldNotMatch(err error) (string, bool) { // BUG(AlecAivazis): the current implementation might cause weird conflicts if there are // two fields with same name that only differ by casing. -func findFieldIndex(s reflect.Value, name string) (int, error) { - // the type of the value - sType := s.Type() +func findField(s reflect.Value, name string) (reflect.Value, reflect.StructField, error) { - // first look for matching tags so we can overwrite matching field names - for i := 0; i < sType.NumField(); i++ { - // the field we are current scanning - field := sType.Field(i) + fields := flattenFields(s) + // first look for matching tags so we can overwrite matching field names + for _, f := range fields { // the value of the survey tag - tag := field.Tag.Get(tagName) + tag := f.fieldType.Tag.Get(tagName) // if the tag matches the name we are looking for if tag != "" && tag == name { // then we found our index - return i, nil + return f.value, f.fieldType, nil } } // then look for matching names - for i := 0; i < sType.NumField(); i++ { - // the field we are current scanning - field := sType.Field(i) - + for _, f := range fields { // if the name of the field matches what we're looking for - if strings.ToLower(field.Name) == strings.ToLower(name) { - return i, nil + if strings.ToLower(f.fieldType.Name) == strings.ToLower(name) { + return f.value, f.fieldType, nil } } // we didn't find the field - return -1, errFieldNotMatch{name} + return reflect.Value{}, reflect.StructField{}, errFieldNotMatch{name} +} + +func flattenFields(s reflect.Value) []reflectField { + sType := s.Type() + numField := sType.NumField() + fields := make([]reflectField, 0, numField) + for i := 0; i < numField; i++ { + fieldType := sType.Field(i) + field := s.Field(i) + + if field.Kind() == reflect.Struct && fieldType.Anonymous { + // field is a promoted structure + for _, f := range flattenFields(field) { + fields = append(fields, f) + } + continue + } + fields = append(fields, reflectField{field, fieldType}) + } + return fields } // isList returns true if the element is something we can Len() diff --git a/core/write_test.go b/core/write_test.go index 87e902ed..05b40b4c 100644 --- a/core/write_test.go +++ b/core/write_test.go @@ -305,12 +305,12 @@ func TestWriteAnswer_returnsErrWhenFieldNotFound(t *testing.T) { } } -func TestFindFieldIndex_canFindExportedField(t *testing.T) { +func TestFindField_canFindExportedField(t *testing.T) { // create a reflective wrapper over the struct to look through - val := reflect.ValueOf(struct{ Name string }{}) + val := reflect.ValueOf(struct{ Name string }{Name: "Jack"}) // find the field matching "name" - fieldIndex, err := findFieldIndex(val, "name") + field, fieldType, err := findField(val, "name") // if something went wrong if err != nil { // the test failed @@ -319,20 +319,28 @@ func TestFindFieldIndex_canFindExportedField(t *testing.T) { } // make sure we got the right value - if val.Type().Field(fieldIndex).Name != "Name" { + if field.Interface() != "Jack" { // the test failed - t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", val.Type().Field(fieldIndex).Name) + t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface()) + } + + // make sure we got the right field type + if fieldType.Name != "Name" { + // the test failed + t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name) } } -func TestFindFieldIndex_canFindTaggedField(t *testing.T) { +func TestFindField_canFindTaggedField(t *testing.T) { // the struct to look through val := reflect.ValueOf(struct { Username string `survey:"name"` - }{}) + }{ + Username: "Jack", + }) // find the field matching "name" - fieldIndex, err := findFieldIndex(val, "name") + field, fieldType, err := findField(val, "name") // if something went wrong if err != nil { // the test failed @@ -341,52 +349,180 @@ func TestFindFieldIndex_canFindTaggedField(t *testing.T) { } // make sure we got the right value - if val.Type().Field(fieldIndex).Name != "Username" { + if field.Interface() != "Jack" { // the test failed - t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", val.Type().Field(fieldIndex).Name) + t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface()) + } + + // make sure we got the right fieldType + if fieldType.Name != "Username" { + // the test failed + t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name) } } -func TestFindFieldIndex_canHandleCapitalAnswerNames(t *testing.T) { +func TestFindField_canHandleCapitalAnswerNames(t *testing.T) { // create a reflective wrapper over the struct to look through - val := reflect.ValueOf(struct{ Name string }{}) + val := reflect.ValueOf(struct{ Name string }{Name: "Jack"}) // find the field matching "name" - fieldIndex, err := findFieldIndex(val, "Name") + field, fieldType, err := findField(val, "Name") // if something went wrong if err != nil { // the test failed t.Error(err.Error()) return } - // make sure we got the right value - if val.Type().Field(fieldIndex).Name != "Name" { + if field.Interface() != "Jack" { + // the test failed + t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface()) + } + + // make sure we got the right fieldType + if fieldType.Name != "Name" { // the test failed - t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", val.Type().Field(fieldIndex).Name) + t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name) } } -func TestFindFieldIndex_tagOverwriteFieldName(t *testing.T) { +func TestFindField_tagOverwriteFieldName(t *testing.T) { // the struct to look through val := reflect.ValueOf(struct { Name string Username string `survey:"name"` - }{}) + }{ + Name: "Ralf", + Username: "Jack", + }) + + // find the field matching "name" + field, fieldType, err := findField(val, "name") + // if something went wrong + if err != nil { + // the test failed + t.Error(err.Error()) + return + } + + // make sure we got the right value + if field.Interface() != "Jack" { + // the test failed + t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface()) + } + + // make sure we got the right fieldType + if fieldType.Name != "Username" { + // the test failed + t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name) + } +} + +func TestFindField_supportsPromotedFields(t *testing.T) { + // create a reflective wrapper over the struct to look through + type Common struct { + Name string + } + + type Strct struct { + Common // Name field added by composition + Username string + } + + val := reflect.ValueOf(Strct{Common: Common{Name: "Jack"}}) + + // find the field matching "name" + field, fieldType, err := findField(val, "Name") + // if something went wrong + if err != nil { + // the test failed + t.Error(err.Error()) + return + } + // make sure we got the right value + if field.Interface() != "Jack" { + // the test failed + t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface()) + } + + // make sure we got the right fieldType + if fieldType.Name != "Name" { + // the test failed + t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name) + } +} + +func TestFindField_promotedFieldsWithTag(t *testing.T) { + // create a reflective wrapper over the struct to look through + type Common struct { + Username string `survey:"name"` + } + + type Strct struct { + Common // Name field added by composition + Name string + } + + val := reflect.ValueOf(Strct{ + Common: Common{Username: "Jack"}, + Name: "Ralf", + }) // find the field matching "name" - fieldIndex, err := findFieldIndex(val, "name") + field, fieldType, err := findField(val, "name") // if something went wrong if err != nil { // the test failed t.Error(err.Error()) return } + // make sure we got the right value + if field.Interface() != "Jack" { + // the test failed + t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface()) + } + // make sure we got the right fieldType + if fieldType.Name != "Username" { + // the test failed + t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name) + } +} + +func TestFindField_promotedFieldsDontHavePriorityOverTags(t *testing.T) { + // create a reflective wrapper over the struct to look through + type Common struct { + Name string + } + + type Strct struct { + Common // Name field added by composition + Username string `survey:"name"` + } + + val := reflect.ValueOf(Strct{ + Common: Common{Name: "Ralf"}, + Username: "Jack", + }) + + // find the field matching "name" + field, fieldType, err := findField(val, "name") + // if something went wrong + if err != nil { + // the test failed + t.Error(err.Error()) + return + } // make sure we got the right value - if val.Type().Field(fieldIndex).Name != "Username" { + if field.Interface() != "Jack" { + // the test failed + t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface()) + } + + // make sure we got the right fieldType + if fieldType.Name != "Username" { // the test failed - t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", val.Type().Field(fieldIndex).Name) + t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name) } }