Skip to content

Commit

Permalink
feat: add InitTrace to Langfuse
Browse files Browse the repository at this point in the history
  • Loading branch information
meguminnnnnnnnn committed Mar 6, 2025
1 parent 331b373 commit c0721da
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 47 deletions.
2 changes: 1 addition & 1 deletion callbacks/langfuse/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.18
require (
github.com/bytedance/mockey v1.2.13
github.com/bytedance/sonic v1.12.7
github.com/cloudwego/eino v0.3.10
github.com/cloudwego/eino v0.3.13
github.com/cloudwego/eino-ext/libs/acl/langfuse v0.0.0-20250113033825-eb19b2b6b386
github.com/golang/mock v1.6.0
github.com/stretchr/testify v1.10.0
Expand Down
4 changes: 2 additions & 2 deletions callbacks/langfuse/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyY
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/eino v0.3.10 h1:KQoc+FXt+5VkoStAxkle0J21HjHumu6+cdVHjBT7BuA=
github.com/cloudwego/eino v0.3.10/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/eino v0.3.13 h1:5fq5hM+UzbLtv4nXMhU6tAxgb7Q3AyaJ6/566XsJqc0=
github.com/cloudwego/eino v0.3.13/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
github.com/cloudwego/eino-ext/libs/acl/langfuse v0.0.0-20250113033825-eb19b2b6b386 h1:dF//5iW+PCS8ZnZ0PwmO2enn3Oek++mbgB6dmaJAz6o=
github.com/cloudwego/eino-ext/libs/acl/langfuse v0.0.0-20250113033825-eb19b2b6b386/go.mod h1:77jqGUJZjxg+V/sJ8S6dd0JtRLO782yVWHmhuFgb9ig=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
Expand Down
75 changes: 34 additions & 41 deletions callbacks/langfuse/langfuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ type Config struct {
Public bool
}

func NewLangfuseHandler(cfg *Config) (handler callbacks.Handler, flusher func()) {
func NewLangfuseHandler(cfg *Config) (handler *CallbackHandler, flusher func()) {
var langfuseOpts []langfuse.Option
if cfg.Threads > 0 {
langfuseOpts = append(langfuseOpts, langfuse.WithThreads(cfg.Threads))
Expand Down Expand Up @@ -157,7 +157,7 @@ func NewLangfuseHandler(cfg *Config) (handler callbacks.Handler, flusher func())
langfuseOpts...,
)

return &langfuseHandler{
return &CallbackHandler{
cli: cli,

name: cfg.Name,
Expand All @@ -169,7 +169,7 @@ func NewLangfuseHandler(cfg *Config) (handler callbacks.Handler, flusher func())
}, cli.Flush
}

type langfuseHandler struct {
type CallbackHandler struct {
cli langfuse.Langfuse

name string
Expand All @@ -186,18 +186,18 @@ type langfuseState struct {
observationID string
}

func (l *langfuseHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
func (c *CallbackHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
if info == nil {
return ctx
}

ctx, state := l.getOrInitState(ctx, getName(info))
ctx, state := c.getOrInitState(ctx, getName(info))
if state == nil {
return ctx
}
if info.Component == components.ComponentOfChatModel {
mcbi := model.ConvCallbackInput(input)
generationID, err := l.cli.CreateGeneration(&langfuse.GenerationEventBody{
generationID, err := c.cli.CreateGeneration(&langfuse.GenerationEventBody{
BaseObservationEventBody: langfuse.BaseObservationEventBody{
BaseEventBody: langfuse.BaseEventBody{
Name: getName(info),
Expand Down Expand Up @@ -226,7 +226,7 @@ func (l *langfuseHandler) OnStart(ctx context.Context, info *callbacks.RunInfo,
log.Printf("marshal input error: %v, runinfo: %+v", err, info)
return ctx
}
spanID, err := l.cli.CreateSpan(&langfuse.SpanEventBody{
spanID, err := c.cli.CreateSpan(&langfuse.SpanEventBody{
BaseObservationEventBody: langfuse.BaseObservationEventBody{
BaseEventBody: langfuse.BaseEventBody{
Name: getName(info),
Expand All @@ -247,7 +247,7 @@ func (l *langfuseHandler) OnStart(ctx context.Context, info *callbacks.RunInfo,
})
}

func (l *langfuseHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
func (c *CallbackHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
if info == nil {
return ctx
}
Expand Down Expand Up @@ -278,7 +278,7 @@ func (l *langfuseHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou
}
}

err := l.cli.EndGeneration(body)
err := c.cli.EndGeneration(body)
if err != nil {
log.Printf("end generation error: %v, runinfo: %+v", err, info)
}
Expand All @@ -290,7 +290,7 @@ func (l *langfuseHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou
log.Printf("marshal output error: %v, runinfo: %+v", err, info)
return ctx
}
err = l.cli.EndSpan(&langfuse.SpanEventBody{
err = c.cli.EndSpan(&langfuse.SpanEventBody{
BaseObservationEventBody: langfuse.BaseObservationEventBody{
BaseEventBody: langfuse.BaseEventBody{
ID: state.observationID,
Expand All @@ -305,7 +305,7 @@ func (l *langfuseHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou
return ctx
}

func (l *langfuseHandler) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
func (c *CallbackHandler) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
if info == nil {
return ctx
}
Expand All @@ -329,14 +329,14 @@ func (l *langfuseHandler) OnError(ctx context.Context, info *callbacks.RunInfo,
CompletionStartTime: time.Now(),
}

reportErr := l.cli.EndGeneration(body)
reportErr := c.cli.EndGeneration(body)
if reportErr != nil {
log.Printf("end generation fail: %v, runinfo: %+v, execute error: %v", reportErr, info, err)
}
return ctx
}

reportErr := l.cli.EndSpan(&langfuse.SpanEventBody{
reportErr := c.cli.EndSpan(&langfuse.SpanEventBody{
BaseObservationEventBody: langfuse.BaseObservationEventBody{
BaseEventBody: langfuse.BaseEventBody{
ID: state.observationID,
Expand All @@ -352,18 +352,18 @@ func (l *langfuseHandler) OnError(ctx context.Context, info *callbacks.RunInfo,
return ctx
}

func (l *langfuseHandler) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
if info == nil {
return ctx
}

ctx, state := l.getOrInitState(ctx, getName(info))
ctx, state := c.getOrInitState(ctx, getName(info))
if state == nil {
return ctx
}

if info.Component == components.ComponentOfChatModel {
generationID, err := l.cli.CreateGeneration(&langfuse.GenerationEventBody{
generationID, err := c.cli.CreateGeneration(&langfuse.GenerationEventBody{
BaseObservationEventBody: langfuse.BaseObservationEventBody{
BaseEventBody: langfuse.BaseEventBody{
Name: getName(info),
Expand Down Expand Up @@ -404,7 +404,7 @@ func (l *langfuseHandler) OnStartWithStreamInput(ctx context.Context, info *call
log.Printf("extract stream model input error: %v, runinfo: %+v", err_, info)
return
}
err = l.cli.EndGeneration(&langfuse.GenerationEventBody{
err = c.cli.EndGeneration(&langfuse.GenerationEventBody{
BaseObservationEventBody: langfuse.BaseObservationEventBody{
BaseEventBody: langfuse.BaseEventBody{
ID: generationID,
Expand All @@ -426,7 +426,7 @@ func (l *langfuseHandler) OnStartWithStreamInput(ctx context.Context, info *call
})
}

spanID, err := l.cli.CreateSpan(&langfuse.SpanEventBody{
spanID, err := c.cli.CreateSpan(&langfuse.SpanEventBody{
BaseObservationEventBody: langfuse.BaseObservationEventBody{
BaseEventBody: langfuse.BaseEventBody{
Name: getName(info),
Expand Down Expand Up @@ -467,7 +467,7 @@ func (l *langfuseHandler) OnStartWithStreamInput(ctx context.Context, info *call
log.Printf("marshal input error: %v, runinfo: %+v", err_, info)
return
}
err = l.cli.EndSpan(&langfuse.SpanEventBody{
err = c.cli.EndSpan(&langfuse.SpanEventBody{
BaseObservationEventBody: langfuse.BaseObservationEventBody{
BaseEventBody: langfuse.BaseEventBody{
ID: spanID,
Expand All @@ -486,7 +486,7 @@ func (l *langfuseHandler) OnStartWithStreamInput(ctx context.Context, info *call
})
}

func (l *langfuseHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
func (c *CallbackHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
if info == nil {
return ctx
}
Expand Down Expand Up @@ -539,7 +539,7 @@ func (l *langfuseHandler) OnEndWithStreamOutput(ctx context.Context, info *callb
}
}

err = l.cli.EndGeneration(body)
err = c.cli.EndGeneration(body)
if err != nil {
log.Printf("end stream generation error: %v, runinfo: %+v", err, info)
}
Expand Down Expand Up @@ -571,7 +571,7 @@ func (l *langfuseHandler) OnEndWithStreamOutput(ctx context.Context, info *callb
if err != nil {
log.Printf("marshal stream output error: %v, runinfo: %+v", err, info)
}
err = l.cli.EndSpan(&langfuse.SpanEventBody{
err = c.cli.EndSpan(&langfuse.SpanEventBody{
BaseObservationEventBody: langfuse.BaseObservationEventBody{
BaseEventBody: langfuse.BaseEventBody{
ID: state.observationID,
Expand All @@ -588,33 +588,26 @@ func (l *langfuseHandler) OnEndWithStreamOutput(ctx context.Context, info *callb
return ctx
}

func (l *langfuseHandler) getOrInitState(ctx context.Context, curName string) (context.Context, *langfuseState) {
func (c *CallbackHandler) getOrInitState(ctx context.Context, curName string) (context.Context, *langfuseState) {
state := ctx.Value(langfuseStateKey{})
if state == nil {
name := l.name
name := c.name
if len(name) == 0 {
name = curName
}
traceID, err := l.cli.CreateTrace(&langfuse.TraceEventBody{
BaseEventBody: langfuse.BaseEventBody{
Name: name,
},
TimeStamp: time.Now(),
UserID: l.userID,
SessionID: l.sessionID,
Release: l.release,
Tags: l.tags,
Public: l.public,
nState, err := initState(ctx, c.cli, &traceOptions{
Name: c.name,
UserID: c.userID,
SessionID: c.sessionID,
Release: c.release,
Tags: c.tags,
Public: c.public,
})
if err != nil {
log.Printf("create trace error: %v", err)
return ctx, nil
}
s := &langfuseState{
traceID: traceID,
log.Printf("init state fail: %v", err)
}
ctx = context.WithValue(ctx, langfuseStateKey{}, s)
return ctx, s
ctx = context.WithValue(ctx, langfuseStateKey{}, nState)
return ctx, nState
}
return ctx, state.(*langfuseState)
}
Expand Down
33 changes: 30 additions & 3 deletions callbacks/langfuse/langfuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestLangfuseCallback(t *testing.T) {
}

mockey.PatchConvey("test span", t, func() {
mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace id", nil).Times(1)
mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace id", nil).Times(2)
createSpanTimes := 0
mockLangfuse.EXPECT().CreateSpan(gomock.Any()).DoAndReturn(func(body *langfuse.SpanEventBody) (string, error) {
defer func() {
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestLangfuseCallback(t *testing.T) {
})

mockey.PatchConvey("test span stream", t, func() {
mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace id", nil).AnyTimes()
mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace id", nil).Times(1)
mockLangfuse.EXPECT().CreateSpan(gomock.Any()).DoAndReturn(func(body *langfuse.SpanEventBody) (string, error) {
return "", nil
}).AnyTimes()
Expand Down Expand Up @@ -221,7 +221,7 @@ func TestLangfuseCallback(t *testing.T) {
})

mockey.PatchConvey("test generation stream", t, func() {
mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace id", nil).AnyTimes()
mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace id", nil).Times(1)
mockLangfuse.EXPECT().CreateGeneration(gomock.Any()).DoAndReturn(func(body *langfuse.GenerationEventBody) (string, error) {
return "generation id", nil
}).AnyTimes()
Expand Down Expand Up @@ -263,4 +263,31 @@ func TestLangfuseCallback(t *testing.T) {
ctx2 := cbh.OnStartWithStreamInput(ctx, &callbacks.RunInfo{Component: components.ComponentOfChatModel}, insr)
cbh.OnEndWithStreamOutput(ctx2, &callbacks.RunInfo{Component: components.ComponentOfChatModel}, outsr)
})
mockey.PatchConvey("test init trace", t, func() {
mockLangfuse.EXPECT().CreateTrace(gomock.Any()).
DoAndReturn(func(body *langfuse.TraceEventBody) (string, error) {
assert.Equal(t, map[string]string{"key": "value"}, body.MetaData)
assert.Equal(t, "name", body.Name)
assert.Equal(t, "release", body.Release)
assert.Equal(t, "traceid", body.ID)
assert.Equal(t, "userid", body.UserID)
assert.Equal(t, "sessionid", body.SessionID)
assert.Equal(t, []string{"tags"}, body.Tags)
assert.Equal(t, true, body.Public)
return "trace id", nil
}).Times(1)

ctx, err = (&CallbackHandler{cli: mockLangfuse}).InitTrace(context.Background(),
WithMetadata(map[string]string{"key": "value"}),
WithName("name"),
WithRelease("release"),
WithID("traceid"),
WithUserID("userid"),
WithSessionID("sessionid"),
WithTags("tags"),
WithPublic(true),
)
assert.NoError(t, err)
assert.Equal(t, "trace id", ctx.Value(langfuseStateKey{}).(*langfuseState).traceID)
})
}
Loading

0 comments on commit c0721da

Please sign in to comment.