Skip to content

Commit

Permalink
Source updated to use iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
rusq committed Feb 21, 2025
1 parent 5f046b0 commit 203fa0f
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 62 deletions.
40 changes: 34 additions & 6 deletions internal/source/chunkdir.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"iter"
"os"
"time"

Expand Down Expand Up @@ -38,19 +39,46 @@ func NewChunkDir(d *chunk.Directory, fast bool) *ChunkDir {
// it expects all messages for the requested channel to be in the file
// ID.json.gz. If messages for the channel are scattered across multiple
// files, it will not return all of them.
func (c *ChunkDir) AllMessages(ctx context.Context, channelID string) ([]slack.Message, error) {
func (c *ChunkDir) AllMessages(ctx context.Context, channelID string) (iter.Seq2[slack.Message, error], error) {
var (
mm []slack.Message
err error
)
if c.fast {
return c.d.FastAllMessages(channelID)
mm, err = c.d.FastAllMessages(channelID)
} else {
return c.d.AllMessages(ctx, channelID)
mm, err = c.d.AllMessages(ctx, channelID)
}
if err != nil {
return nil, err
}
return toIter(mm), nil
}

func (c *ChunkDir) AllThreadMessages(ctx context.Context, channelID, threadID string) ([]slack.Message, error) {
// toIter adapts an in-memory slice of messages to the iterator form. The
// messages are yielded in slice order, each paired with a nil error, and the
// iteration stops as soon as the consumer's yield function returns false.
func toIter(mm []slack.Message) iter.Seq2[slack.Message, error] {
	return func(yield func(slack.Message, error) bool) {
		for i := range mm {
			if ok := yield(mm[i], nil); !ok {
				return
			}
		}
	}
}

// AllThreadMessages returns an iterator over the messages of the thread
// threadID in the channel channelID. Depending on the fast flag, the messages
// are loaded with FastAllThreadMessages or AllThreadMessages of the underlying
// chunk directory. They are loaded in full before the iterator is returned,
// so a loading failure is reported through the error return value and the
// iterator itself only yields nil errors.
func (c *ChunkDir) AllThreadMessages(ctx context.Context, channelID, threadID string) (iter.Seq2[slack.Message, error], error) {
	var (
		mm  []slack.Message
		err error
	)
	if c.fast {
		// The fast variant does not take a context.
		mm, err = c.d.FastAllThreadMessages(channelID, threadID)
	} else {
		mm, err = c.d.AllThreadMessages(ctx, channelID, threadID)
	}
	if err != nil {
		return nil, err
	}
	return toIter(mm), nil
}

func (c *ChunkDir) ChannelInfo(_ context.Context, channelID string) (*slack.Channel, error) {
Expand Down
5 changes: 3 additions & 2 deletions internal/source/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package source

import (
"context"
"iter"
"os"
"path/filepath"
"time"
Expand Down Expand Up @@ -69,11 +70,11 @@ func (d *Database) Users(ctx context.Context) ([]slack.User, error) {
return d.s.Users(ctx)
}

func (d *Database) AllMessages(ctx context.Context, channelID string) ([]slack.Message, error) {
func (d *Database) AllMessages(ctx context.Context, channelID string) (iter.Seq2[slack.Message, error], error) {
return d.s.AllMessages(ctx, channelID)
}

func (d *Database) AllThreadMessages(ctx context.Context, channelID, threadID string) ([]slack.Message, error) {
func (d *Database) AllThreadMessages(ctx context.Context, channelID, threadID string) (iter.Seq2[slack.Message, error], error) {
return d.s.AllThreadMessages(ctx, channelID, threadID)
}

Expand Down
18 changes: 11 additions & 7 deletions internal/source/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"io/fs"
"iter"
"log/slog"
"os"
"path"
Expand Down Expand Up @@ -96,7 +97,7 @@ func (d Dump) Users(context.Context) ([]slack.User, error) {
return u, nil
}

func (d Dump) AllMessages(_ context.Context, channelID string) ([]slack.Message, error) {
func (d Dump) AllMessages(_ context.Context, channelID string) (iter.Seq2[slack.Message, error], error) {
var cm []types.Message
c, err := unmarshalOne[types.Conversation](d.fs, d.channelFile(channelID))
if err != nil {
Expand Down Expand Up @@ -142,15 +143,18 @@ func (d Dump) threadHeadMessages(channelID string) ([]types.Message, error) {
return cm, nil
}

func convertMessages(cm []types.Message) []slack.Message {
mm := make([]slack.Message, len(cm))
for i := range cm {
mm[i] = cm[i].Message
func convertMessages(cm []types.Message) iter.Seq2[slack.Message, error] {
iterFn := func(yield func(slack.Message, error) bool) {
for _, m := range cm {
if !yield(m.Message, nil) {
return
}
}
}
return mm
return iterFn
}

func (d Dump) AllThreadMessages(_ context.Context, channelID, threadID string) ([]slack.Message, error) {
func (d Dump) AllThreadMessages(_ context.Context, channelID, threadID string) (iter.Seq2[slack.Message, error], error) {
cm, err := d.findThreadInChannel(channelID, threadID)
if err != nil {
if !os.IsNotExist(err) {
Expand Down
93 changes: 48 additions & 45 deletions internal/source/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"path"
"time"
Expand Down Expand Up @@ -88,71 +89,73 @@ func (e *Export) Type() string {
}

// AllMessages returns all channel messages without thread messages.
func (e *Export) AllMessages(_ context.Context, channelID string) ([]slack.Message, error) {
var mm []slack.Message
if err := e.walkChannelMessages(channelID, func(m *slack.Message) error {
if isThreadMessage(&m.Msg) && m.SubType != structures.SubTypeThreadBroadcast {
// removes thread messages, except the broadcast ones, which are
// included in the channel message list.
return nil
}
mm = append(mm, *m)
return nil
}); err != nil {
return nil, fmt.Errorf("AllMessages: walk: %s", err)
}
return mm, nil
func (e *Export) AllMessages(_ context.Context, channelID string) (iter.Seq2[slack.Message, error], error) {
return e.walkChannelMessages(channelID)
}

func (e *Export) walkChannelMessages(channelID string, fn func(m *slack.Message) error) error {
func (e *Export) walkChannelMessages(channelID string) (iter.Seq2[slack.Message, error], error) {
name, ok := e.chanNames[channelID]
if !ok {
return fmt.Errorf("%w: %s", fs.ErrNotExist, channelID)
return nil, fmt.Errorf("%w: %s", fs.ErrNotExist, channelID)
}
_, err := fs.Stat(e.fs, name)
if err != nil {
return fmt.Errorf("%w: %s", fs.ErrNotExist, name)
return nil, fmt.Errorf("%w: %s", fs.ErrNotExist, name)
}
return fs.WalkDir(e.fs, name, func(pth string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() || path.Ext(pth) != ".json" {
return nil
}
// read the file
em, err := unmarshal[[]export.ExportMessage](e.fs, pth)
if err != nil {
return err
}
for i, m := range em {
if m.Msg == nil {
slog.Default().Debug("skipping an empty message", "pth", pth, "index", i)
continue
iterFn := func(yield func(slack.Message, error) bool) {
err := fs.WalkDir(e.fs, name, func(pth string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if err := fn(&slack.Message{Msg: *m.Msg}); err != nil {
if d.IsDir() || path.Ext(pth) != ".json" {
return nil
}
// read the file
em, err := unmarshal[[]export.ExportMessage](e.fs, pth)
if err != nil {
return err
}
for i, m := range em {
if m.Msg == nil {
slog.Default().Debug("skipping an empty message", "pth", pth, "index", i)
continue
}
if !yield(slack.Message{Msg: *m.Msg}, nil) {
return fs.SkipAll
}
}
return nil
})
if err != nil {
yield(slack.Message{}, err)
}
return nil
})
}
return iterFn, nil
}

// isThreadMessage reports whether m has a thread timestamp that is set and
// differs from the message's own timestamp.
func isThreadMessage(m *slack.Msg) bool {
	if m.ThreadTimestamp == "" {
		// No thread timestamp at all.
		return false
	}
	return m.ThreadTimestamp != m.Timestamp
}

func (e *Export) AllThreadMessages(_ context.Context, channelID, threadID string) ([]slack.Message, error) {
var tm []slack.Message
if err := e.walkChannelMessages(channelID, func(m *slack.Message) error {
if m.ThreadTimestamp == threadID {
tm = append(tm, *m)
func (e *Export) AllThreadMessages(_ context.Context, channelID, threadID string) (iter.Seq2[slack.Message, error], error) {
it, err := e.walkChannelMessages(channelID)
if err != nil {
return nil, err
}
iterFn := func(yield func(slack.Message, error) bool) {
for m, err := range it {
if err != nil {
yield(slack.Message{}, err)
return
}
if m.ThreadTimestamp == threadID {
if !yield(m, nil) {
return
}
}
}
return nil
}); err != nil {
return nil, fmt.Errorf("AllThreadMessages: walk: %s", err)
}
return tm, nil
return iterFn, nil
}

func (e *Export) ChannelInfo(ctx context.Context, channelID string) (*slack.Channel, error) {
Expand Down
5 changes: 3 additions & 2 deletions internal/source/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"fmt"
"io"
"io/fs"
"iter"
"log/slog"
"os"
"path"
Expand All @@ -39,10 +40,10 @@ type Sourcer interface {
// Users should return all users.
Users(ctx context.Context) ([]slack.User, error)
// AllMessages should return all messages for the given channel id.
AllMessages(ctx context.Context, channelID string) ([]slack.Message, error)
AllMessages(ctx context.Context, channelID string) (iter.Seq2[slack.Message, error], error)
// AllThreadMessages should return all messages for the given tuple
// (channelID, threadID).
AllThreadMessages(ctx context.Context, channelID, threadID string) ([]slack.Message, error)
AllThreadMessages(ctx context.Context, channelID, threadID string) (iter.Seq2[slack.Message, error], error)
// ChannelInfo should return the channel information for the given channel
// id.
ChannelInfo(ctx context.Context, channelID string) (*slack.Channel, error)
Expand Down

0 comments on commit 203fa0f

Please sign in to comment.