From 0053d8c8d63c6f244131c773324eb62327a6af31 Mon Sep 17 00:00:00 2001 From: tom twinkle <47764757+tomtwinkle@users.noreply.github.com> Date: Mon, 1 Mar 2021 23:06:59 +0900 Subject: [PATCH] copy field name tags --- copier.go | 137 +++++++++++++++++++++++++++++++++++++------- copier_tags_test.go | 56 ++++++++++++++++++ 2 files changed, 172 insertions(+), 21 deletions(-) diff --git a/copier.go b/copier.go index 723f1d1..13e1369 100644 --- a/copier.go +++ b/copier.go @@ -3,9 +3,11 @@ package copier import ( "database/sql" "database/sql/driver" + "errors" "fmt" "reflect" "strings" + "unicode" ) // These flags define options for tag handling @@ -32,6 +34,19 @@ type Option struct { DeepCopy bool } +// Tag Flags +type Flags struct { + BitFlags map[string]uint8 + SrcNames TagNameMapping + DestNames TagNameMapping +} + +// Field Tag name mapping +type TagNameMapping struct { + FieldNameToTag map[string]string + TagToFieldName map[string]string +} + // Copy copy things func Copy(toValue interface{}, fromValue interface{}) (err error) { return copier(toValue, fromValue, Option{}) @@ -170,9 +185,9 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) } // Get tag options - tagBitFlags := map[string]uint8{} - if dest.IsValid() { - tagBitFlags = getBitFlags(toType) + flags, err := getFlags(dest, source, toType, fromType) + if err != nil { + return err } // check source @@ -183,17 +198,18 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) name := field.Name // Get bit flags for field - fieldFlags, _ := tagBitFlags[name] + fieldFlags, _ := flags.BitFlags[name] // Check if we should ignore copying if (fieldFlags & tagIgnore) != 0 { continue } - if fromField := source.FieldByName(name); fromField.IsValid() && !shouldIgnore(fromField, opt.IgnoreEmpty) { + srcFieldName, destFieldName := getFieldName(name, flags) + if fromField := source.FieldByName(srcFieldName); fromField.IsValid() && !shouldIgnore(fromField, opt.IgnoreEmpty) { // process for nested anonymous field destFieldNotSet := false - if f, ok := dest.Type().FieldByName(name); ok { + if f, ok := dest.Type().FieldByName(destFieldName); ok { for idx, x := range f.Index { if x >= dest.NumField() { continue @@ -222,7 +238,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) break } - toField := dest.FieldByName(name) + toField := dest.FieldByName(destFieldName) if toField.IsValid() { if toField.CanSet() { if !set(toField, fromField, opt.DeepCopy) { @@ -232,7 +248,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) } else { if fieldFlags != 0 { // Note that a copy was made - tagBitFlags[name] = fieldFlags | hasCopied + flags.BitFlags[name] = fieldFlags | hasCopied } } } @@ -240,9 +256,9 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) // try to set to method var toMethod reflect.Value if dest.CanAddr() { - toMethod = dest.Addr().MethodByName(name) + toMethod = dest.Addr().MethodByName(destFieldName) } else { - toMethod = dest.MethodByName(name) + toMethod = dest.MethodByName(destFieldName) } if toMethod.IsValid() && toMethod.Type().NumIn() == 1 && fromField.Type().AssignableTo(toMethod.Type().In(0)) { @@ -255,16 +271,17 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) // Copy from from method to dest field for _, field := range deepFields(toType) { name := field.Name + srcFieldName, destFieldName := getFieldName(name, flags) var fromMethod reflect.Value if source.CanAddr() { - fromMethod = source.Addr().MethodByName(name) + fromMethod = source.Addr().MethodByName(srcFieldName) } else { - fromMethod = source.MethodByName(name) + fromMethod = source.MethodByName(srcFieldName) } if fromMethod.IsValid() && fromMethod.Type().NumIn() == 0 && fromMethod.Type().NumOut() == 1 && !shouldIgnore(fromMethod, opt.IgnoreEmpty) { - if toField := dest.FieldByName(name); toField.IsValid() && toField.CanSet() { + if toField := dest.FieldByName(destFieldName); toField.IsValid() && toField.CanSet() { values := fromMethod.Call([]reflect.Value{}) if len(values) >= 1 { set(toField, values[0], opt.DeepCopy) @@ -284,7 +301,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) to.Set(dest) } - err = checkBitFlags(tagBitFlags) + err = checkBitFlags(flags.BitFlags) } return @@ -416,7 +433,7 @@ func set(to, from reflect.Value, deepCopy bool) bool { } // parseTags Parses struct tags and returns uint8 bit flags. -func parseTags(tag string) (flags uint8) { +func parseTags(tag string) (flags uint8, name string, err error) { for _, t := range strings.Split(tag, ",") { switch t { case "-": @@ -426,24 +443,68 @@ func parseTags(tag string) (flags uint8) { flags = flags | tagMust case "nopanic": flags = flags | tagNoPanic + default: + if unicode.IsUpper([]rune(t)[0]) { + name = strings.TrimSpace(t) + } else { + err = errors.New("copier field name tag must be start upper case") + } } } return } -// getBitFlags Parses struct tags for bit flags. -func getBitFlags(toType reflect.Type) map[string]uint8 { - flags := map[string]uint8{} - toTypeFields := deepFields(toType) +// getTagFlags Parses struct tags for bit flags, field name. +func getFlags(dest, src reflect.Value, toType, fromType reflect.Type) (Flags, error) { + flags := Flags{ + BitFlags: map[string]uint8{}, + SrcNames: TagNameMapping{ + FieldNameToTag: map[string]string{}, + TagToFieldName: map[string]string{}, + }, + DestNames: TagNameMapping{ + FieldNameToTag: map[string]string{}, + TagToFieldName: map[string]string{}, + }, + } + var toTypeFields, fromTypeFields []reflect.StructField + if dest.IsValid() { + toTypeFields = deepFields(toType) + } + if src.IsValid() { + fromTypeFields = deepFields(fromType) + } // Get a list dest of tags for _, field := range toTypeFields { tags := field.Tag.Get("copier") if tags != "" { - flags[field.Name] = parseTags(tags) + var name string + var err error + if flags.BitFlags[field.Name], name, err = parseTags(tags); err != nil { + return Flags{}, err + } else if name != "" { + flags.DestNames.FieldNameToTag[field.Name] = name + flags.DestNames.TagToFieldName[name] = field.Name + } } } - return flags + + // Get a list source of tags + for _, field := range fromTypeFields { + tags := field.Tag.Get("copier") + if tags != "" { + var name string + var err error + if _, name, err = parseTags(tags); err != nil { + return Flags{}, err + } else if name != "" { + flags.SrcNames.FieldNameToTag[field.Name] = name + flags.SrcNames.TagToFieldName[name] = field.Name + } + } + } + return flags, nil } // checkBitFlags Checks flags for error or panic conditions. @@ -463,6 +524,40 @@ func checkBitFlags(flagsList map[string]uint8) (err error) { return } +func getFieldName(fieldName string, flags Flags) (srcFieldName string, destFieldName string) { + // get dest field name + if srcTagName, ok := flags.SrcNames.FieldNameToTag[fieldName]; ok { + destFieldName = srcTagName + if destTagName, ok := flags.DestNames.TagToFieldName[srcTagName]; ok { + destFieldName = destTagName + } + } else { + if destTagName, ok := flags.DestNames.TagToFieldName[fieldName]; ok { + destFieldName = destTagName + } + } + if destFieldName == "" { + destFieldName = fieldName + } + + // get source field name + if destTagName, ok := flags.DestNames.FieldNameToTag[fieldName]; ok { + srcFieldName = destTagName + if srcField, ok := flags.SrcNames.TagToFieldName[destTagName]; ok { + srcFieldName = srcField + } + } else { + if srcField, ok := flags.SrcNames.TagToFieldName[fieldName]; ok { + srcFieldName = srcField + } + } + + if srcFieldName == "" { + srcFieldName = fieldName + } + return +} + func driverValuer(v reflect.Value) (i driver.Valuer, ok bool) { if !v.CanAddr() { diff --git a/copier_tags_test.go b/copier_tags_test.go index c6214f0..c8fe704 100644 --- a/copier_tags_test.go +++ b/copier_tags_test.go @@ -48,3 +48,59 @@ func TestCopyTagMust(t *testing.T) { }() copier.Copy(employee, user) } + +func TestCopyTagFieldName(t *testing.T) { + t.Run("another name field copy", func(t *testing.T) { + type SrcTags struct { + FieldA string + FieldB string `copier:"Field2"` + FieldC string `copier:"FieldTagMatch"` + } + + type DestTags struct { + Field1 string `copier:"FieldA"` + Field2 string + Field3 string `copier:"FieldTagMatch"` + } + + dst := &DestTags{} + src := &SrcTags{ + FieldA: "FieldA->Field1", + FieldB: "FieldB->Field2", + FieldC: "FieldC->Field3", + } + err := copier.Copy(dst, src) + if err != nil { + t.Fatal(err) + } + + if dst.Field1 != src.FieldA { + t.Error("Field1 no copy") + } + if dst.Field2 != src.FieldB { + t.Error("Field2 no copy") + } + if dst.Field3 != src.FieldC { + t.Error("Field3 no copy") + } + }) + + t.Run("validate error flag name", func(t *testing.T) { + type SrcTags struct { + field string + } + + type DestTags struct { + Field1 string `copier:"field"` + } + + dst := &DestTags{} + src := &SrcTags{ + field: "field->Field1", + } + err := copier.Copy(dst, src) + if err == nil { + t.Fatal("must validate error") + } + }) +}