Skip to content

Commit

Permalink
copy field name tags
Browse files Browse the repository at this point in the history
  • Loading branch information
tomtwinkle committed Mar 1, 2021
1 parent db116d0 commit 0053d8c
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 21 deletions.
137 changes: 116 additions & 21 deletions copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package copier
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strings"
"unicode"
)

// These flags define options for tag handling
Expand All @@ -32,6 +34,19 @@ type Option struct {
DeepCopy bool
}

// Tag Flags
type Flags struct {
BitFlags map[string]uint8
SrcNames TagNameMapping
DestNames TagNameMapping
}

// Field Tag name mapping
type TagNameMapping struct {
FieldNameToTag map[string]string
TagToFieldName map[string]string
}

// Copy copy things
func Copy(toValue interface{}, fromValue interface{}) (err error) {
return copier(toValue, fromValue, Option{})
Expand Down Expand Up @@ -170,9 +185,9 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
}

// Get tag options
tagBitFlags := map[string]uint8{}
if dest.IsValid() {
tagBitFlags = getBitFlags(toType)
flags, err := getFlags(dest, source, toType, fromType)
if err != nil {
return err
}

// check source
Expand All @@ -183,17 +198,18 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
name := field.Name

// Get bit flags for field
fieldFlags, _ := tagBitFlags[name]
fieldFlags, _ := flags.BitFlags[name]

// Check if we should ignore copying
if (fieldFlags & tagIgnore) != 0 {
continue
}

if fromField := source.FieldByName(name); fromField.IsValid() && !shouldIgnore(fromField, opt.IgnoreEmpty) {
srcFieldName, destFieldName := getFieldName(name, flags)
if fromField := source.FieldByName(srcFieldName); fromField.IsValid() && !shouldIgnore(fromField, opt.IgnoreEmpty) {
// process for nested anonymous field
destFieldNotSet := false
if f, ok := dest.Type().FieldByName(name); ok {
if f, ok := dest.Type().FieldByName(destFieldName); ok {
for idx, x := range f.Index {
if x >= dest.NumField() {
continue
Expand Down Expand Up @@ -222,7 +238,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
break
}

toField := dest.FieldByName(name)
toField := dest.FieldByName(destFieldName)
if toField.IsValid() {
if toField.CanSet() {
if !set(toField, fromField, opt.DeepCopy) {
Expand All @@ -232,17 +248,17 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
} else {
if fieldFlags != 0 {
// Note that a copy was made
tagBitFlags[name] = fieldFlags | hasCopied
flags.BitFlags[name] = fieldFlags | hasCopied
}
}
}
} else {
// try to set to method
var toMethod reflect.Value
if dest.CanAddr() {
toMethod = dest.Addr().MethodByName(name)
toMethod = dest.Addr().MethodByName(destFieldName)
} else {
toMethod = dest.MethodByName(name)
toMethod = dest.MethodByName(destFieldName)
}

if toMethod.IsValid() && toMethod.Type().NumIn() == 1 && fromField.Type().AssignableTo(toMethod.Type().In(0)) {
Expand All @@ -255,16 +271,17 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
// Copy from from method to dest field
for _, field := range deepFields(toType) {
name := field.Name
srcFieldName, destFieldName := getFieldName(name, flags)

var fromMethod reflect.Value
if source.CanAddr() {
fromMethod = source.Addr().MethodByName(name)
fromMethod = source.Addr().MethodByName(srcFieldName)
} else {
fromMethod = source.MethodByName(name)
fromMethod = source.MethodByName(srcFieldName)
}

if fromMethod.IsValid() && fromMethod.Type().NumIn() == 0 && fromMethod.Type().NumOut() == 1 && !shouldIgnore(fromMethod, opt.IgnoreEmpty) {
if toField := dest.FieldByName(name); toField.IsValid() && toField.CanSet() {
if toField := dest.FieldByName(destFieldName); toField.IsValid() && toField.CanSet() {
values := fromMethod.Call([]reflect.Value{})
if len(values) >= 1 {
set(toField, values[0], opt.DeepCopy)
Expand All @@ -284,7 +301,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
to.Set(dest)
}

err = checkBitFlags(tagBitFlags)
err = checkBitFlags(flags.BitFlags)
}

return
Expand Down Expand Up @@ -416,7 +433,7 @@ func set(to, from reflect.Value, deepCopy bool) bool {
}

// parseTags Parses struct tags and returns uint8 bit flags.
func parseTags(tag string) (flags uint8) {
func parseTags(tag string) (flags uint8, name string, err error) {
for _, t := range strings.Split(tag, ",") {
switch t {
case "-":
Expand All @@ -426,24 +443,68 @@ func parseTags(tag string) (flags uint8) {
flags = flags | tagMust
case "nopanic":
flags = flags | tagNoPanic
default:
if unicode.IsUpper([]rune(t)[0]) {
name = strings.TrimSpace(t)
} else {
err = errors.New("copier field name tag must be start upper case")
}
}
}
return
}

// getBitFlags Parses struct tags for bit flags.
func getBitFlags(toType reflect.Type) map[string]uint8 {
flags := map[string]uint8{}
toTypeFields := deepFields(toType)
// getTagFlags Parses struct tags for bit flags, field name.
func getFlags(dest, src reflect.Value, toType, fromType reflect.Type) (Flags, error) {
flags := Flags{
BitFlags: map[string]uint8{},
SrcNames: TagNameMapping{
FieldNameToTag: map[string]string{},
TagToFieldName: map[string]string{},
},
DestNames: TagNameMapping{
FieldNameToTag: map[string]string{},
TagToFieldName: map[string]string{},
},
}
var toTypeFields, fromTypeFields []reflect.StructField
if dest.IsValid() {
toTypeFields = deepFields(toType)
}
if src.IsValid() {
fromTypeFields = deepFields(fromType)
}

// Get a list dest of tags
for _, field := range toTypeFields {
tags := field.Tag.Get("copier")
if tags != "" {
flags[field.Name] = parseTags(tags)
var name string
var err error
if flags.BitFlags[field.Name], name, err = parseTags(tags); err != nil {
return Flags{}, err
} else if name != "" {
flags.DestNames.FieldNameToTag[field.Name] = name
flags.DestNames.TagToFieldName[name] = field.Name
}
}
}
return flags

// Get a list source of tags
for _, field := range fromTypeFields {
tags := field.Tag.Get("copier")
if tags != "" {
var name string
var err error
if _, name, err = parseTags(tags); err != nil {
return Flags{}, err
} else if name != "" {
flags.SrcNames.FieldNameToTag[field.Name] = name
flags.SrcNames.TagToFieldName[name] = field.Name
}
}
}
return flags, nil
}

// checkBitFlags Checks flags for error or panic conditions.
Expand All @@ -463,6 +524,40 @@ func checkBitFlags(flagsList map[string]uint8) (err error) {
return
}

func getFieldName(fieldName string, flags Flags) (srcFieldName string, destFieldName string) {
// get dest field name
if srcTagName, ok := flags.SrcNames.FieldNameToTag[fieldName]; ok {
destFieldName = srcTagName
if destTagName, ok := flags.DestNames.TagToFieldName[srcTagName]; ok {
destFieldName = destTagName
}
} else {
if destTagName, ok := flags.DestNames.TagToFieldName[fieldName]; ok {
destFieldName = destTagName
}
}
if destFieldName == "" {
destFieldName = fieldName
}

// get source field name
if destTagName, ok := flags.DestNames.FieldNameToTag[fieldName]; ok {
srcFieldName = destTagName
if srcField, ok := flags.SrcNames.TagToFieldName[destTagName]; ok {
srcFieldName = srcField
}
} else {
if srcField, ok := flags.SrcNames.TagToFieldName[fieldName]; ok {
srcFieldName = srcField
}
}

if srcFieldName == "" {
srcFieldName = fieldName
}
return
}

func driverValuer(v reflect.Value) (i driver.Valuer, ok bool) {

if !v.CanAddr() {
Expand Down
56 changes: 56 additions & 0 deletions copier_tags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,59 @@ func TestCopyTagMust(t *testing.T) {
}()
copier.Copy(employee, user)
}

func TestCopyTagFieldName(t *testing.T) {
t.Run("another name field copy", func(t *testing.T) {
type SrcTags struct {
FieldA string
FieldB string `copier:"Field2"`
FieldC string `copier:"FieldTagMatch"`
}

type DestTags struct {
Field1 string `copier:"FieldA"`
Field2 string
Field3 string `copier:"FieldTagMatch"`
}

dst := &DestTags{}
src := &SrcTags{
FieldA: "FieldA->Field1",
FieldB: "FieldB->Field2",
FieldC: "FieldC->Field3",
}
err := copier.Copy(dst, src)
if err != nil {
t.Fatal(err)
}

if dst.Field1 != src.FieldA {
t.Error("Field1 no copy")
}
if dst.Field2 != src.FieldB {
t.Error("Field2 no copy")
}
if dst.Field3 != src.FieldC {
t.Error("Field3 no copy")
}
})

t.Run("validate error flag name", func(t *testing.T) {
type SrcTags struct {
field string
}

type DestTags struct {
Field1 string `copier:"field"`
}

dst := &DestTags{}
src := &SrcTags{
field: "field->Field1",
}
err := copier.Copy(dst, src)
if err == nil {
t.Fatal("must validate error")
}
})
}

0 comments on commit 0053d8c

Please sign in to comment.