diff --git a/copier.go b/copier.go index 43a14f1..638bc86 100644 --- a/copier.go +++ b/copier.go @@ -136,6 +136,15 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) return ErrInvalidCopyFrom } + cacheToValue := indirect(reflect.New(to.Type())) + cacheToValue.Set(to) + defer func() { + // if err occur, toValue needs to recover to init state. + if err != nil { + to.Set(cacheToValue) + } + }() + fromType, isPtrFrom := indirectType(from.Type()) toType, _ := indirectType(to.Type()) @@ -241,13 +250,6 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) return } - if len(converters) > 0 { - if ok, e := set(to, from, opt.DeepCopy, converters); e == nil && ok { - // converter supported - return - } - } - if from.Kind() == reflect.Slice || to.Kind() == reflect.Slice { isSlice = true if from.Kind() == reflect.Slice { @@ -272,24 +274,10 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) dest = indirect(to) } + isSet := false if len(converters) > 0 { - if ok, e := set(dest, source, opt.DeepCopy, converters); e == nil && ok { - if isSlice { - // FIXME: maybe should check the other types? - if to.Type().Elem().Kind() == reflect.Ptr { - to.Index(i).Set(dest.Addr()) - } else { - if to.Len() < i+1 { - reflect.Append(to, dest) - } else { - to.Index(i).Set(dest) - } - } - } else { - to.Set(dest) - } - - continue + if isSet, err = set(dest, source, opt.DeepCopy, converters); err != nil { + return err } } @@ -307,7 +295,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) } // check source - if source.IsValid() { + if source.IsValid() && !isSet { copyUnexportedStructFields(dest, source) // Copy from source field to dest field or method diff --git a/copier_issue170_test.go b/copier_issue170_test.go index f97278b..5e370b0 100644 --- a/copier_issue170_test.go +++ b/copier_issue170_test.go @@ -1,9 +1,10 @@ package copier_test import ( - "github.com/jinzhu/copier" "reflect" "testing" + + "github.com/jinzhu/copier" ) type A struct { diff --git a/copier_test.go b/copier_test.go index 4769eba..fd085da 100644 --- a/copier_test.go +++ b/copier_test.go @@ -1665,7 +1665,6 @@ func TestDeepCopyAnonymousFieldTime(t *testing.T) { } func TestSqlNullFiled(t *testing.T) { - type sqlStruct struct { MkId sql.NullInt64 MkExpiryDateType sql.NullInt32 @@ -1762,3 +1761,70 @@ func TestNestedNilPointerStruct(t *testing.T) { t.Errorf("to (%v) value should equal from (%v) value", to.Title, from.Title) } } + +func TestOccurErr(t *testing.T) { + t.Run("CopyWithOption err occur", func(t *testing.T) { + type srcTags struct { + Field string + Index int + } + type destTags struct { + Field string + Index string + } + + dst := &destTags{ + Field: "init", + Index: "0", + } + src := &srcTags{ + Field: "copied", + Index: 1, + } + err := copier.CopyWithOption(dst, src, copier.Option{ + Converters: []copier.TypeConverter{ + { + SrcType: 1, + DstType: "", + Fn: func(src interface{}) (dst interface{}, err error) { + return nil, fmt.Errorf("return err") + }, + }, + }, + }) + if err == nil { + t.Errorf("should return err") + } + if dst.Field != "init" || dst.Index != "0" { + t.Error("when err occur, the dst should be init") + } + + }) + t.Run("copy err occur", func(t *testing.T) { + type srcTags struct { + field string + Field2 string + } + + type destTags struct { + Field string `copier:"field"` + Field2 string `copier:"Field2"` + } + + dst := &destTags{ + Field: "init", + Field2: "init2", + } + src := &srcTags{ + field: "Field1->Field1", + Field2: "Field2->Field2", + } + err := copier.Copy(dst, src) + if err == nil { + t.Errorf("should return err") + } + if dst.Field != "init" || dst.Field2 != "init2" { + t.Error("when err occur, the dst should be init") + } + }) +}