Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: workflow can create FieldMapping from or to fields of maps #95

Merged
merged 1 commit into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Invokable = "Invokable"
invokable = "invokable"
InvokableLambda = "InvokableLambda"
InvokableRun = "InvokableRun"
typ = "typ"

[files]
extend-exclude = ["go.mod", "go.sum", "check_branch_name.sh"]
125 changes: 98 additions & 27 deletions compose/field_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ func (m *FieldMapping) String() string {
sb.WriteString("(field) of ")
}

sb.WriteString(m.fromNodeKey)

if m.to != "" {
sb.WriteString(" to ")
sb.WriteString(m.to)
Expand All @@ -60,20 +62,23 @@ func (m *FieldMapping) String() string {
// FromField builds a FieldMapping that takes one named field of the
// predecessor's output and uses it as the successor's entire input.
// Because the successor input is fully covered by this mapping, it is
// exclusive: no further field mappings can be added for that successor.
// The name in from refers either to a struct field or to a map key.
func FromField(from string) *FieldMapping {
	mapping := new(FieldMapping)
	mapping.from = from
	return mapping
}

// ToField builds a FieldMapping that places the predecessor's entire
// output into one named field of the successor's input.
// The name in to refers either to a struct field or to a map key.
func ToField(to string) *FieldMapping {
	mapping := new(FieldMapping)
	mapping.to = to
	return mapping
}

// MapFields creates a FieldMapping that maps a single predecessor field to a single successor field
// MapFields creates a FieldMapping that maps a single predecessor field to a single successor field.
// Field: either the field of a struct, or the key of a map.
func MapFields(from, to string) *FieldMapping {
return &FieldMapping{
from: from,
Expand Down Expand Up @@ -160,6 +165,16 @@ func assignOne[T any](dest T, taken any, to string) (T, error) {

toSet := reflect.ValueOf(taken)

if destValue.Kind() == reflect.Map {
key, err := checkAndExtractToMapKey(to, destValue, toSet)
if err != nil {
return dest, err
}

destValue.SetMapIndex(key, toSet)
return destValue.Interface().(T), nil
}

field, err := checkAndExtractToField(to, destValue, toSet)
if err != nil {
return dest, err
Expand All @@ -171,30 +186,44 @@ func assignOne[T any](dest T, taken any, to string) (T, error) {
}

func checkAndExtractFromField(fromField string, input reflect.Value) (reflect.Value, error) {
if input.Kind() == reflect.Ptr {
input = input.Elem()
}

if input.Kind() != reflect.Struct {
return reflect.Value{}, fmt.Errorf("mapping has from but input is not struct or struct ptr, type= %v", input.Type())
}

f := input.FieldByName(fromField)
if !f.IsValid() {
return reflect.Value{}, fmt.Errorf("mapping has from not found. field=%v, inputType=%v", fromField, input.Type())
return reflect.Value{}, fmt.Errorf("field mapping from a struct field, but field not found. field=%v, inputType=%v", fromField, input.Type())
}

if !f.CanInterface() {
return reflect.Value{}, fmt.Errorf("mapping has from not exported. field= %v, inputType=%v", fromField, input.Type())
return reflect.Value{}, fmt.Errorf("field mapping from a struct field, but field not exported. field= %v, inputType=%v", fromField, input.Type())
}

return f, nil
}

// checkAndExtractFromMapKey looks up fromMapKey in the map held by input and
// returns the stored value.
//
// It returns an error (rather than panicking inside reflect) when:
//   - input is the zero reflect.Value (e.g. produced from a nil interface),
//   - input is not a map,
//   - a string is not assignable to the map's key type,
//   - the key is not present in the map (this includes a nil map).
func checkAndExtractFromMapKey(fromMapKey string, input reflect.Value) (reflect.Value, error) {
	// Type() panics on the zero Value, so rule that out before touching it.
	if !input.IsValid() {
		return reflect.Value{}, fmt.Errorf("field mapping from a map key, but input is nil")
	}

	// Type().Key() panics for non-map types; check the kind first so the
	// function is safe regardless of what the caller has already verified.
	if input.Kind() != reflect.Map {
		return reflect.Value{}, fmt.Errorf("field mapping from a map key, but input is not a map, type=%v", input.Type())
	}

	if !reflect.TypeOf(fromMapKey).AssignableTo(input.Type().Key()) {
		return reflect.Value{}, fmt.Errorf("field mapping from a map key, but input is not a map with string key, type=%v", input.Type())
	}

	// MapIndex yields the zero Value when the key is absent or the map is nil.
	v := input.MapIndex(reflect.ValueOf(fromMapKey))
	if !v.IsValid() {
		return reflect.Value{}, fmt.Errorf("field mapping from a map key, but key not found in input. key=%s, inputType= %v", fromMapKey, input.Type())
	}

	return v, nil
}

func checkAndExtractFieldType(field string, typ reflect.Type) (reflect.Type, error) {
if len(field) == 0 {
return typ, nil
}

if typ.Kind() == reflect.Map {
if typ.Key() != strType {
return nil, fmt.Errorf("type[%v] is not a map with string key", typ)
}

return typ.Elem(), nil
}

for typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
Expand All @@ -215,31 +244,49 @@ func checkAndExtractFieldType(field string, typ reflect.Type) (reflect.Type, err
return f.Type, nil
}

// strType is the reflect.Type of the built-in string type; map key types are
// compared against it when a field mapping targets a map.
var strType = reflect.TypeOf("")

func checkAndExtractToField(toField string, output, toSet reflect.Value) (reflect.Value, error) {
for output.Kind() == reflect.Ptr {
output = output.Elem()
}

if output.Kind() != reflect.Struct {
return reflect.Value{}, fmt.Errorf("mapping has to but output is not a struct, type=%v", output.Type())
return reflect.Value{}, fmt.Errorf("field mapping to a struct field but output is not a struct, type=%v", output.Type())
}

field := output.FieldByName(toField)
if !field.IsValid() {
return reflect.Value{}, fmt.Errorf("mapping has to not found. field=%v, outputType=%v", toField, output.Type())
return reflect.Value{}, fmt.Errorf("field mapping to a struct field, but field not found. field=%v, outputType=%v", toField, output.Type())
}

if !field.CanSet() {
return reflect.Value{}, fmt.Errorf("mapping has to not exported. field=%v, outputType=%v", toField, output.Type())
return reflect.Value{}, fmt.Errorf("field mapping to a struct field, but field not exported. field=%v, outputType=%v", toField, output.Type())
}

if !toSet.Type().AssignableTo(field.Type()) {
return reflect.Value{}, fmt.Errorf("mapping to has a mismatched type. field=%s, from=%v, to=%v", toField, toSet.Type(), field.Type())
return reflect.Value{}, fmt.Errorf("field mapping to a struct field, but field has a mismatched type. field=%s, from=%v, to=%v", toField, toSet.Type(), field.Type())
}

return field, nil
}

// checkAndExtractToMapKey validates that toSet may be stored under the string
// key toMapKey in the map held by output, and returns the key as a
// reflect.Value ready to be passed to SetMapIndex.
//
// It returns an error (rather than panicking inside reflect) when:
//   - output is the zero reflect.Value or is not a map,
//   - a string is not assignable to the map's key type,
//   - toSet is the zero reflect.Value (e.g. reflect.ValueOf of a nil interface),
//   - toSet's type is not assignable to the map's element type.
func checkAndExtractToMapKey(toMapKey string, output, toSet reflect.Value) (reflect.Value, error) {
	if output.Kind() != reflect.Map {
		// Type() panics on the zero Value, so report that case without it.
		if !output.IsValid() {
			return reflect.Value{}, fmt.Errorf("field mapping to a map key but output is nil")
		}
		return reflect.Value{}, fmt.Errorf("field mapping to a map key but output is not a map, type=%v", output.Type())
	}

	mapType := output.Type()

	if !reflect.TypeOf(toMapKey).AssignableTo(mapType.Key()) {
		return reflect.Value{}, fmt.Errorf("field mapping to a map key but output is not a map with string key, type=%v", mapType)
	}

	// A zero Value has no type to check, and handing it to SetMapIndex would
	// delete the key instead of storing a value, so reject it explicitly.
	if !toSet.IsValid() {
		return reflect.Value{}, fmt.Errorf("field mapping to a map key but value to set is nil. key=%s, to=%v", toMapKey, mapType.Elem())
	}

	if !toSet.Type().AssignableTo(mapType.Elem()) {
		return reflect.Value{}, fmt.Errorf("field mapping to a map key but map value has a mismatched type. key=%s, from=%v, to=%v", toMapKey, toSet.Type(), mapType.Elem())
	}

	return reflect.ValueOf(toMapKey), nil
}

func fieldMap(mappings []*FieldMapping) func(any) (map[string]any, error) {
return func(input any) (map[string]any, error) {
result := make(map[string]any, len(mappings))
Expand All @@ -262,14 +309,26 @@ func streamFieldMap(mappings []*FieldMapping) func(streamReader) streamReader {
}
}

func takeOne(input any, from string) (any, error) {
func takeOne(input any, from string) (taken any, err error) {
if len(from) == 0 {
return input, nil
}

inputValue := reflect.ValueOf(input)

f, err := checkAndExtractFromField(from, inputValue)
var f reflect.Value
switch inputValue.Kind() {
case reflect.Map:
f, err = checkAndExtractFromMapKey(from, inputValue)
case reflect.Ptr:
inputValue = inputValue.Elem()
fallthrough
case reflect.Struct:
f, err = checkAndExtractFromField(from, inputValue)
default:
return reflect.Value{}, fmt.Errorf("field mapping from a field, but input is not struct, struct ptr or map, type= %v", inputValue.Type())
}

if err != nil {
return nil, err
}
Expand All @@ -295,11 +354,18 @@ func isToAll(mappings []*FieldMapping) bool {
return false
}

func validateStruct(t reflect.Type) bool {
for t.Kind() == reflect.Ptr {
func validateStructOrMap(t reflect.Type) bool {
switch t.Kind() {
case reflect.Map:
return true
case reflect.Ptr:
t = t.Elem()
fallthrough
case reflect.Struct:
return true
default:
return false
}
return t.Kind() != reflect.Struct
}

func validateFieldMapping(predecessorType reflect.Type, successorType reflect.Type, mappings []*FieldMapping) (*handlerPair, error) {
Expand All @@ -308,20 +374,25 @@ func validateFieldMapping(predecessorType reflect.Type, successorType reflect.Ty
// check if mapping is legal
if isFromAll(mappings) && isToAll(mappings) {
return nil, fmt.Errorf("invalid field mappings: from all fields to all, use common edge instead")
} else if !isToAll(mappings) && validateStruct(successorType) {
} else if !isToAll(mappings) && !validateStructOrMap(successorType) {
// if user has not provided a specific struct type, graph cannot construct any struct in the runtime
return nil, fmt.Errorf("static check fail: upstream input type should be struct, actual: %v", successorType)
} else if !isFromAll(mappings) && validateStruct(predecessorType) {
return nil, fmt.Errorf("static check fail: successor input type should be struct or map, actual: %v", successorType)
} else if !isFromAll(mappings) && !validateStructOrMap(predecessorType) {
// TODO: should forbid?
return nil, fmt.Errorf("static check fail: downstream output type should be struct, actual: %v", predecessorType)
return nil, fmt.Errorf("static check fail: predecessor output type should be struct or map, actual: %v", predecessorType)
}

var (
predecessorFieldType, successorFieldType reflect.Type
err error
)

for _, mapping := range mappings {
predecessorFieldType, err := checkAndExtractFieldType(mapping.from, predecessorType)
predecessorFieldType, err = checkAndExtractFieldType(mapping.from, predecessorType)
if err != nil {
return nil, fmt.Errorf("static check failed for mapping %s: %w", mapping, err)
}
successorFieldType, err := checkAndExtractFieldType(mapping.to, successorType)
successorFieldType, err = checkAndExtractFieldType(mapping.to, successorType)
if err != nil {
return nil, fmt.Errorf("static check failed for mapping %s: %w", mapping, err)
}
Expand Down
35 changes: 29 additions & 6 deletions compose/workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,29 @@ func TestWorkflow(t *testing.T) {
}
}

// TestWorkflowWithMap builds a workflow whose input and output are both
// map[string]any and checks that field mappings work map-to-map,
// map-to-struct-field, and struct-field-to-map.
func TestWorkflowWithMap(t *testing.T) {
	type structA struct {
		F1 any
	}

	ctx := context.Background()

	// Both lambdas simply hand their input back unchanged.
	echoMap := func(_ context.Context, in map[string]any) (map[string]any, error) {
		return in, nil
	}
	echoStruct := func(_ context.Context, in *structA) (*structA, error) {
		return in, nil
	}

	wf := NewWorkflow[map[string]any, map[string]any]()

	// START map key -> lambda1 map key.
	mapNode := wf.AddLambdaNode("lambda1", InvokableLambda(echoMap))
	mapNode.AddInput(START, MapFields("map_key", "lambda1_key"))

	// START map key -> lambda2 struct field.
	structNode := wf.AddLambdaNode("lambda2", InvokableLambda(echoStruct))
	structNode.AddInput(START, MapFields("map_key", "F1"))

	// Both nodes feed distinct keys of the END map.
	wf.AddEnd("lambda1", MapFields("lambda1_key", "end_lambda1"))
	wf.AddEnd("lambda2", MapFields("F1", "end_lambda2"))

	runnable, err := wf.Compile(ctx)
	assert.NoError(t, err)

	input := map[string]any{"map_key": "value"}
	got, err := runnable.Invoke(ctx, input)
	assert.NoError(t, err)

	want := map[string]any{"end_lambda1": "value", "end_lambda2": "value"}
	assert.Equal(t, want, got)
}

func TestWorkflowCompile(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
Expand All @@ -260,32 +283,32 @@ func TestWorkflowCompile(t *testing.T) {
assert.ErrorContains(t, err, "mismatch")
})

t.Run("upstream not struct/struct ptr, mapping has FromField", func(t *testing.T) {
t.Run("predecessor's output not struct/struct ptr/map, mapping has FromField", func(t *testing.T) {
w := NewWorkflow[[]*schema.Document, []string]()

w.AddIndexerNode("indexer", indexer.NewMockIndexer(ctrl)).AddInput(START, FromField("F1"))
w.AddEnd("indexer")
_, err := w.Compile(ctx)
assert.ErrorContains(t, err, "downstream output type should be struct")
assert.ErrorContains(t, err, "predecessor output type should be struct")
})

t.Run("downstream not struct/struct ptr, mapping has ToField", func(t *testing.T) {
t.Run("successor's input not struct/struct ptr/map, mapping has ToField", func(t *testing.T) {
w := NewWorkflow[[]string, [][]float64]()
w.AddEmbeddingNode("embedder", embedding.NewMockEmbedder(ctrl)).AddInput(START, ToField("F1"))
w.AddEnd("embedder")
_, err := w.Compile(ctx)
assert.ErrorContains(t, err, "upstream input type should be struct")
assert.ErrorContains(t, err, "successor input type should be struct")
})

t.Run("map to non existing field in upstream", func(t *testing.T) {
t.Run("map to non existing field in predecessor", func(t *testing.T) {
w := NewWorkflow[*schema.Message, []*schema.Message]()
w.AddToolsNode("tools_node", &ToolsNode{}).AddInput(START, FromField("non_exist"))
w.AddEnd("tools_node")
_, err := w.Compile(ctx)
assert.ErrorContains(t, err, "type[schema.Message] has no field[non_exist]")
})

t.Run("map to not exported field in downstream", func(t *testing.T) {
t.Run("map to not exported field in successor", func(t *testing.T) {
w := NewWorkflow[string, *FieldMapping]()
w.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return input, nil
Expand Down
Loading