From 817f998d49546b66d5ad55613757a82e67b27f60 Mon Sep 17 00:00:00 2001 From: Milos Gajdos Date: Sat, 23 Mar 2024 22:18:25 +0000 Subject: [PATCH] update: add splitter tests (#21) * update: add splitter tests * fix: bugfixing sadness Signed-off-by: Milos Gajdos --------- Signed-off-by: Milos Gajdos --- document/text/character.go | 23 +++---- document/text/character_test.go | 110 ++++++++++++++++++++++++++++++++ document/text/recursive.go | 34 +++++----- document/text/recursive_test.go | 76 ++++++++++++++++++++++ document/text/splitter.go | 88 ++++++++++++++++++------- document/text/text.go | 25 ++++++-- 6 files changed, 295 insertions(+), 61 deletions(-) create mode 100644 document/text/character_test.go create mode 100644 document/text/recursive_test.go diff --git a/document/text/character.go b/document/text/character.go index 8f183fe..66becfc 100644 --- a/document/text/character.go +++ b/document/text/character.go @@ -10,8 +10,7 @@ import ( // or a regular expression. type CharSplitter struct { *Splitter - sep string - isSepRegex bool + sep Sep } // NewSplitter creates a new splitter @@ -30,22 +29,20 @@ func (s *CharSplitter) WithSplitter(splitter *Splitter) *CharSplitter { } // WithSep sets the separator. -func (s *CharSplitter) WithSep(sep string, isSepRegex bool) *CharSplitter { +func (s *CharSplitter) WithSep(sep Sep) *CharSplitter { s.sep = sep - s.isSepRegex = isSepRegex - return nil + return s } // Split splits text into chunks. 
func (s *CharSplitter) Split(text string) []string { - sep := s.sep - if !s.isSepRegex { - sep = regexp.QuoteMeta(s.sep) + sep := Sep{Value: s.sep.Value, IsRegexp: s.sep.IsRegexp} + if !sep.IsRegexp { + sep.Value = regexp.QuoteMeta(sep.Value) } - chunks := s.splitText(text, sep) - sep = "" - if !s.keepSep { - sep = s.sep + splits := s.Splitter.Split(text, sep) + if s.keepSep { + sep.Value = "" } - return s.merge(chunks, sep) + return s.merge(splits, sep) } diff --git a/document/text/character_test.go b/document/text/character_test.go new file mode 100644 index 0000000..7b89cdc --- /dev/null +++ b/document/text/character_test.go @@ -0,0 +1,110 @@ +package text + +import ( + "fmt" + "reflect" + "testing" +) + +func TestCharSplitter(t *testing.T) { + t.Parallel() + var testCases = []struct { + size int + overlap int + trim bool + keepSep bool + sep Sep + input string + exp []string + }{ + { + size: 7, + overlap: 3, + sep: Sep{Value: " "}, + input: "foo bar baz 123", + exp: []string{"foo bar", "bar baz", "baz 123"}, + }, + { + size: 2, + overlap: 0, + sep: Sep{Value: " "}, + input: "foo bar", + exp: []string{"foo", "bar"}, + }, + { + size: 3, + overlap: 1, + sep: Sep{Value: " "}, + input: "foo bar baz a a", + exp: []string{"foo", "bar", "baz", "a a"}, + }, + { + size: 3, + overlap: 1, + sep: Sep{Value: " "}, + input: "a a foo bar baz", + exp: []string{"a a", "foo", "bar", "baz"}, + }, + { + size: 1, + overlap: 1, + sep: Sep{Value: " "}, + input: "foo bar baz 123", + exp: []string{"foo", "bar", "baz", "123"}, + }, + { + size: 1, + overlap: 0, + keepSep: true, + sep: Sep{Value: ".", IsRegexp: false}, + input: "foo.bar.baz.123", + exp: []string{"foo", ".bar", ".baz", ".123"}, + }, + { + size: 1, + overlap: 0, + keepSep: true, + sep: Sep{Value: `\.`, IsRegexp: true}, + input: "foo.bar.baz.123", + exp: []string{"foo", ".bar", ".baz", ".123"}, + }, + { + size: 1, + overlap: 0, + sep: Sep{Value: ".", IsRegexp: false}, + input: "foo.bar.baz.123", + exp: []string{"foo", "bar", 
"baz", "123"}, + }, + { + size: 1, + overlap: 0, + sep: Sep{Value: `\.`, IsRegexp: true}, + input: "foo.bar.baz.123", + exp: []string{"foo", "bar", "baz", "123"}, + }, + } + + for _, tc := range testCases { + tc := tc + s := NewSplitterWithConfig(Config{ + ChunkSize: tc.size, + ChunkOverlap: tc.overlap, + TrimSpace: tc.trim, + KeepSep: tc.keepSep, + LenFunc: DefaultLenFunc, + }) + cs := NewCharSplitter(). + WithSplitter(s). + WithSep(tc.sep) + + t.Run(fmt.Sprintf("sep=%#v,size=%d,overlap=%d,trim=%v,keepSep=%v", + tc.sep, tc.size, tc.overlap, tc.trim, tc.keepSep), + func(t *testing.T) { + t.Parallel() + splits := cs.Split(tc.input) + if !reflect.DeepEqual(splits, tc.exp) { + t.Errorf("expected: %#v, got: %#v", tc.exp, splits) + } + }) + } +} diff --git a/document/text/recursive.go b/document/text/recursive.go index c6efbf6..c4b1545 100644 --- a/document/text/recursive.go +++ b/document/text/recursive.go @@ -10,8 +10,7 @@ import ( // separators to find one that works. type RecursiveCharSplitter struct { *Splitter - seps []string - isSepRegex bool + seps []Sep } // NewSplitter creates a new splitter and returns it. @@ -22,36 +21,35 @@ func NewRecursiveCharSplitter() *RecursiveCharSplitter { } } -// WithSplitter sets the splitter +// WithSplitter sets the splitter. func (r *RecursiveCharSplitter) WithSplitter(splitter *Splitter) *RecursiveCharSplitter { r.Splitter = splitter return r } -// WithSeps sets separators -func (r *RecursiveCharSplitter) WithSeps(seps []string, isSepRegex bool) *RecursiveCharSplitter { +// WithSeps sets separators. 
+func (r *RecursiveCharSplitter) WithSeps(seps []Sep) *RecursiveCharSplitter { r.seps = seps - r.isSepRegex = isSepRegex - return nil + return r } -func (r *RecursiveCharSplitter) split(text string, seps []string) []string { +func (r *RecursiveCharSplitter) split(text string, seps []Sep) []string { var ( resChunks []string - newSeps []string + newSeps []Sep ) sep := seps[len(seps)-1] for i, s := range seps { - if !r.isSepRegex { - s = regexp.QuoteMeta(s) + if !s.IsRegexp { + s.Value = regexp.QuoteMeta(s.Value) } - if s == "" { + if s.Value == "" { sep = s break } - if match, _ := regexp.MatchString(s, text); match { + if match, _ := regexp.MatchString(s.Value, text); match { sep = s newSeps = seps[i+1:] break @@ -59,16 +57,16 @@ func (r *RecursiveCharSplitter) split(text string, seps []string) []string { } // TODO should we escape again? Seems weird. - newSep := sep - if !r.isSepRegex { - newSep = regexp.QuoteMeta(sep) + newSep := Sep{Value: sep.Value, IsRegexp: sep.IsRegexp} + if !sep.IsRegexp { + newSep.Value = regexp.QuoteMeta(sep.Value) } - chunks := r.splitText(text, newSep) + chunks := r.Splitter.Split(text, newSep) var goodChunks []string if r.keepSep { - newSep = "" + newSep.Value = "" } for _, chunk := range chunks { diff --git a/document/text/recursive_test.go b/document/text/recursive_test.go new file mode 100644 index 0000000..4a4d8a6 --- /dev/null +++ b/document/text/recursive_test.go @@ -0,0 +1,76 @@ +package text + +import ( + "fmt" + "reflect" + "testing" +) + +func TestRecursiveCharSplitter(t *testing.T) { + t.Parallel() + var testCases = []struct { + size int + overlap int + trim bool + keepSep bool + seps []Sep + input string + exp []string + }{ + { + size: 10, + overlap: 1, + trim: true, + keepSep: true, + seps: DefaultSeparators, + input: `Hi.` + "\n\n" + `I'm Harrison.` + "\n\n" + `How? Are? You?` + "\n" + `Okay then f f f f. +This is a weird text to write, but gotta test the splittingggg some how. 
+ +Bye!` + "\n\n" + `-H.`, + exp: []string{ + "Hi.", + "I'm", + "Harrison.", + "How? Are?", + "You?", + "Okay then", + "f f f f.", + "This is a", + "weird", + "text to", + "write,", + "but gotta", + "test the", + "splitting", + "gggg", + "some how.", + "Bye!", + "-H.", + }, + }, + } + + for _, tc := range testCases { + tc := tc + s := NewSplitterWithConfig(Config{ + ChunkSize: tc.size, + ChunkOverlap: tc.overlap, + TrimSpace: tc.trim, + KeepSep: tc.keepSep, + LenFunc: DefaultLenFunc, + }) + cs := NewRecursiveCharSplitter(). + WithSplitter(s). + WithSeps(tc.seps) + + t.Run(fmt.Sprintf("sep=%#v,size=%d,overlap=%d,trim=%v,keepSep=%v", + tc.seps, tc.size, tc.overlap, tc.trim, tc.keepSep), + func(t *testing.T) { + t.Parallel() + splits := cs.Split(tc.input) + if !reflect.DeepEqual(splits, tc.exp) { + t.Errorf("expected: %#v, got: %#v", tc.exp, splits) + } + }) + } +} diff --git a/document/text/splitter.go b/document/text/splitter.go index 1f02a1f..02fed14 100644 --- a/document/text/splitter.go +++ b/document/text/splitter.go @@ -1,8 +1,10 @@ package text import ( + "errors" "log" "regexp" + "regexp/syntax" "strings" ) @@ -16,8 +18,8 @@ type Splitter struct { } // Config configures the splitter -// NOTE: this is used to prevent situation -// where values in constructors accideentally +// NOTE: this is used to prevent situations +// where values in constructors accidentally // mix the order of parameters of the same type // leading to unpredicable behaviour. type Config struct { @@ -36,6 +38,7 @@ func NewSplitter() *Splitter { chunkSize: DefaultChunkSize, chunkOverlap: DefaultChunkOverlap, lenFunc: DefaultLenFunc, + trimSpace: true, } } @@ -90,23 +93,23 @@ func (s *Splitter) join(chunks []string, sep string) string { return text } -// merge merges chunks over the given separator and returns +// merge merges splits over the given separator and returns // the new slice of chunks taking into account chunk overlap. 
// It ignores empty string chunks and warns if a chunk is generated // that exceeds the set chunk size. -func (s *Splitter) merge(chunks []string, sep string) []string { +func (s *Splitter) merge(splits []string, sep Sep) []string { // nolint:prealloc var ( // resulting chunk slice - resChunks []string + chunks []string // buffer of chunks chunkBuffer []string ) totalChunksLen := 0 - sepLen := s.lenFunc(sep) + sepLen := s.lenFunc(sep.Value) - for _, chunk := range chunks { + for _, chunk := range splits { if chunk == "" { continue } @@ -115,12 +118,12 @@ func (s *Splitter) merge(chunks []string, sep string) []string { // if it does and if the buffer contains any chunks, we'll pop them add them into resulting chunk set. if totalChunksLen+splitLen+(sepLen*boolToInt(len(chunkBuffer) > 0)) > s.chunkSize { if totalChunksLen > s.chunkSize { - log.Printf("Created a chunk of size %d, which is longer than the requested %d\n", totalChunksLen, s.chunkSize) + log.Printf("created chunk is longer (%d) than requested: %d\n", totalChunksLen, s.chunkSize) } if len(chunkBuffer) > 0 { - doc := s.join(chunkBuffer, sep) + doc := s.join(chunkBuffer, sep.Value) if doc != "" { - resChunks = append(resChunks, doc) + chunks = append(chunks, doc) } // Keep on popping chunks from the bffer if: // - we have a larger chunk than in the chunk overlap @@ -138,36 +141,43 @@ func (s *Splitter) merge(chunks []string, sep string) []string { totalChunksLen += splitLen + sepLen*boolToInt(len(chunkBuffer) > 1) } - chunk := s.join(chunkBuffer, sep) + chunk := s.join(chunkBuffer, sep.Value) if chunk != "" { - resChunks = append(resChunks, chunk) + chunks = append(chunks, chunk) } - return resChunks + return chunks } -// splitText splits the text over a separator optionally keeping +// Split splits the text over a separator optionally keeping // the separator and returns the the chunks in a slice. // If the separator is empty string it splits on individual characters. 
-func (s *Splitter) splitText(text string, sep string) []string { - if sep != "" { +// Empty strings are removed from the returned splits. +func (s *Splitter) Split(text string, sep Sep) []string { + if sep.Value != "" { if s.keepSep { + // NOTE: we must do this to unescape + // the escaped separator so we keep the raw separator. + sepVal, err := unquoteMeta(sep.Value) + if err != nil { + panic(err) + } + var results []string - re := regexp.MustCompile("(" + sep + ")") - splits := re.Split(text, -1) + splits := regexp.MustCompile("("+sep.Value+")").Split(text, -1) + // NOTE: we start iterating from 1, not 0! for i := 1; i < len(splits); i++ { // make sure the separator remains in the result split // because Go reasons: https://github.com/golang/go/issues/18868 - results = append(results, sep+splits[i]) + results = append(results, sepVal+splits[i]) } results = append([]string{splits[0]}, results...) - return results + return filterEmptyStrings(results) } - re := regexp.MustCompile(sep) - return re.Split(text, -1) + return filterEmptyStrings(regexp.MustCompile(sep.Value).Split(text, -1)) } // If separator is empty, split into individual characters. - return strings.Split(text, "") + return filterEmptyStrings(strings.Split(text, "")) } // boolToInt returns 1 if b is true @@ -178,3 +188,35 @@ func boolToInt(b bool) int { } return 0 } + +// filterEmptyStrings removes empty strings from a slice of strings. +func filterEmptyStrings(slice []string) []string { + count := 0 + for _, s := range slice { + if s != "" { + count++ + } + } + + result := make([]string, 0, count) + + for _, s := range slice { + if s != "" { + result = append(result, s) + } + } + + return result +} + +// unquoteMeta returns the literal text of s, or an error if s is not a plain literal regexp. 
+func unquoteMeta(s string) (string, error) { + r, err := syntax.Parse(s, 0) + if err != nil { + return "", err + } + if r.Op != syntax.OpLiteral { + return "", errors.New("not a quoted meta") + } + return string(r.Rune), nil +} diff --git a/document/text/text.go b/document/text/text.go index 6e633f4..452ca10 100644 --- a/document/text/text.go +++ b/document/text/text.go @@ -5,9 +5,6 @@ import ( ) const ( - // DefaultSeparator is default text separator. - // It's intention is to splitt by paragraphs. - DefaultSeparator = "\n\n" // DefaultChunkSize is default chunk size. DefaultChunkSize = 1 // DefaultChunkOverlap is default chunk overlap. @@ -21,11 +18,25 @@ var ( // StringBytesLenFunc counts number of bytes in a string. // Faster for some documents, but less accurate for multiling. StringBytesLenFunc = func(s string) int { return len(s) } - // DefaultSeparators are used in RecursiveSplitter. - // The splitter recursively keeps splitting document - // using the separators until done. - DefaultSeparators = []string{"\n\n", "\n", " ", ""} + // DefaultSeparator is default text separator. + // Its intention is to split by paragraphs. + DefaultSeparator = Sep{Value: "\n\n"} + // DefaultSeparators are used in RecursiveCharSplitter. + // RecursiveCharSplitter keeps splitting document + // recursively using the separators until done. + DefaultSeparators = []Sep{ + {Value: "\n\n"}, + {Value: "\n"}, + {Value: " "}, + {Value: ""}, + } ) +// Sep is a text separator. +type Sep struct { + Value string + IsRegexp bool +} + // LenFunc is used for funcs that calculate string lengths. type LenFunc func(s string) int