diff --git a/tag.go b/tag.go index 9e3ad4b..734f5b9 100644 --- a/tag.go +++ b/tag.go @@ -25,6 +25,7 @@ type structTag struct { name string size uint64 omitEmpty bool + stop bool } // ErrEmptyTag means that a tag string has empty item. @@ -39,23 +40,24 @@ func parseTag(rawtag string) (*structTag, error) { ts := strings.Split(rawtag, ",") for i, t := range ts { + if i == 0 { + tag.name = t + continue + } kv := strings.SplitN(t, "=", 2) - if len(kv) == 1 { - if i == 0 { - tag.name = kv[0] - } else { - switch kv[0] { - case "": - return nil, ErrEmptyTag - case "omitempty": - tag.omitEmpty = true - case "inf": - os.Stderr.WriteString("Deprecated: \"inf\" tag is replaced by \"size=unknown\"\n") - tag.size = SizeUnknown - default: - return nil, wrapErrorf(ErrInvalidTag, "parsing \"%s\"", t) - } + switch kv[0] { + case "": + return nil, ErrEmptyTag + case "omitempty": + tag.omitEmpty = true + case "inf": + os.Stderr.WriteString("Deprecated: \"inf\" tag is replaced by \"size=unknown\"\n") + tag.size = SizeUnknown + case "stop": + tag.stop = true + default: + return nil, wrapErrorf(ErrInvalidTag, "parsing \"%s\"", t) } continue } diff --git a/unmarshal.go b/unmarshal.go index 1a7c197..07c4c20 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -19,7 +19,6 @@ import ( "errors" "io" "reflect" - "strings" ) // ErrUnknownElement means that a decoded element is not known. @@ -90,29 +89,25 @@ func (vd *valueDecoder) readElement(r0 io.Reader, n int64, vo reflect.Value, dep switch vo.Kind() { case reflect.Struct: for i := 0; i < vo.NumField(); i++ { - var nn []string - if n, ok := vo.Type().Field(i).Tag.Lookup("ebml"); ok { - nn = strings.Split(n, ",") + f := fieldDef{ + v: vo.Field(i), } var name string - if len(nn) > 0 && len(nn[0]) > 0 { - name = nn[0] - } else { + if n, ok := vo.Type().Field(i).Tag.Lookup("ebml"); ok { + t, err := parseTag(n) + if err != nil { + return nil, err + } + name = t.name + f.stop = t.stop + } + if name == "" { name = vo.Type().Field(i).Name } t, err := ElementTypeFromString(name) if err != nil { return nil, err } - f := fieldDef{ - v: vo.Field(i), - } - for i := 1; i < len(nn); i++ { - switch nn[i] { - case "stop": - f.stop = true - } - } fieldMap[t] = f } case reflect.Map: diff --git a/unmarshal_test.go b/unmarshal_test.go index 6f2d5ac..f561ca7 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -450,6 +450,15 @@ func TestUnmarshal_Error(t *testing.T) { t.Errorf("Expected error: '%v', got: '%v'", ErrUnknownElementName, err) } }) + t.Run("InvalidTag", func(t *testing.T) { + input := &struct { + Header struct { + } `ebml:"EBML,ivalid"` + }{} + if err := Unmarshal(bytes.NewBuffer([]byte{}), input); !errs.Is(err, ErrInvalidTag) { + t.Errorf("Expected error: '%v', got: '%v'", ErrInvalidTag, err) + } + }) t.Run("UnknownElement", func(t *testing.T) { input := &TestEBML{} b := []byte{0x81}