diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..ce85d3b8c --- /dev/null +++ b/.dockerignore @@ -0,0 +1,2 @@ +vendor/* +bin/* diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..ff8a9b166 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.idea +vendor +bin +.DS_Store diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 000000000..a414f33f7 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,30 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +run: + skip-dirs: + - pkg/client + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gas + - goconst + - goimports + - golint + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - staticcheck + - structcheck + - typecheck + - unconvert + - unparam + - unused + - varcheck diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000..9ec4f2c9c --- /dev/null +++ b/.travis.yml @@ -0,0 +1,26 @@ +sudo: required +language: go +go: + - "1.12" +services: + - docker +jobs: + include: + - if: fork = true + stage: test + name: docker build + install: true + script: make docker_build + - if: fork = false + stage: test + name: docker build and push + install: true + script: make dockerhub_push + - stage: test + install: make install + name: lint + script: make lint + - stage: test + name: unit tests + install: make install + script: make test_unit diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..803d8a77f --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,3 @@ +This project is governed by [Lyft's code of +conduct](https://github.com/lyft/code-of-conduct). All contributors +and participants agree to abide by its terms. 
diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..2ad29f2ff --- /dev/null +++ b/Dockerfile @@ -0,0 +1,33 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +# Using go1.10.4 +FROM golang:1.10.4-alpine3.8 as builder +RUN apk add git openssh-client make curl dep + +# COPY only the dep files for efficient caching +COPY Gopkg.* /go/src/github.com/lyft/flytepropeller/ +WORKDIR /go/src/github.com/lyft/flytepropeller + +# Pull dependencies +RUN dep ensure -vendor-only + +# COPY the rest of the source code +COPY . /go/src/github.com/lyft/flytepropeller/ + +# This 'linux_compile' target should compile binaries to the /artifacts directory +# The main entrypoint should be compiled to /artifacts/flytepropeller +RUN make linux_compile + +# update the PATH to include the /artifacts directory +ENV PATH="/artifacts:${PATH}" + +# This will eventually move to centurylink/ca-certs:latest for minimum possible image size +FROM alpine:3.8 +COPY --from=builder /artifacts /bin + +RUN apk --update add ca-certificates + +CMD ["flytepropeller"] diff --git a/Gopkg.lock b/Gopkg.lock new file mode 100644 index 000000000..08a0aa8c7 --- /dev/null +++ b/Gopkg.lock @@ -0,0 +1,1433 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
+ + +[[projects]] + digest = "1:80e5d0810f1448259385b08f381852a83f87b6c958d8500e621db821e15c3771" + name = "cloud.google.com/go" + packages = ["compute/metadata"] + pruneopts = "" + revision = "cdaaf98f9226c39dc162b8e55083b2fbc67b4674" + version = "v0.43.0" + +[[projects]] + digest = "1:6158256042564abf0da300ea7cb016f79ddaf24fdda2cc06c9712b0c2e06dd2a" + name = "contrib.go.opencensus.io/exporter/ocagent" + packages = ["."] + pruneopts = "" + revision = "dcb33c7f3b7cfe67e8a2cea10207ede1b7c40764" + version = "v0.4.12" + +[[projects]] + digest = "1:9a11be778d5fcb8e4873e64a097dfd2862d8665d9e2d969b90810d5272e51acb" + name = "github.com/Azure/azure-sdk-for-go" + packages = ["storage"] + pruneopts = "" + revision = "2d49bb8f2cee530cc16f1f1a9f0aae763dee257d" + version = "v10.2.1-beta" + +[[projects]] + digest = "1:5cb9540799639936e705a6ac54cfb6744b598519485fb357acb6e3285f43fbfb" + name = "github.com/Azure/go-autorest" + packages = [ + "autorest", + "autorest/adal", + "autorest/azure", + "autorest/date", + "logger", + "tracing", + ] + pruneopts = "" + revision = "7166fb346dbf8978ad28211a1937b20fdabc08c8" + version = "v12.4.2" + +[[projects]] + digest = "1:558b53577dc0c9fde49b08405d706b202bcac3064320e9be53a75fc866280ee3" + name = "github.com/DiSiqueira/GoTree" + packages = ["."] + pruneopts = "" + revision = "53a8e837f2952215f256fc9acf4ecb2045b056fb" + version = "2.0.3" + +[[projects]] + digest = "1:e1549ae10031ac55dd7d26ac4d480130ddbdf97f9a26ebbedff089aa0335798f" + name = "github.com/GoogleCloudPlatform/spark-on-k8s-operator" + packages = [ + "pkg/apis/sparkoperator.k8s.io", + "pkg/apis/sparkoperator.k8s.io/v1beta1", + ] + pruneopts = "" + revision = "5306d013b4dbd6a9c75879c1643c7fcb237560ec" + source = "https://github.com/lyft/spark-on-k8s-operator" + version = "v0.1.3" + +[[projects]] + digest = "1:60942d250d0e06d3722ddc8e22bc52f8cef7961ba6d8d3e95327a32b6b024a7b" + name = "github.com/appscode/jsonpatch" + packages = ["."] + pruneopts = "" + revision = 
"7c0e3b262f30165a8ec3d0b4c6059fd92703bfb2" + version = "1.0.0" + +[[projects]] + digest = "1:cfe39a015adcf9cc2bce0e8bd38ecf041cb516b8ab7a2ecb11b1c84a4b8acabf" + name = "github.com/aws/aws-sdk-go" + packages = [ + "aws", + "aws/awserr", + "aws/awsutil", + "aws/client", + "aws/client/metadata", + "aws/corehandlers", + "aws/credentials", + "aws/credentials/ec2rolecreds", + "aws/credentials/endpointcreds", + "aws/credentials/processcreds", + "aws/credentials/stscreds", + "aws/csm", + "aws/defaults", + "aws/ec2metadata", + "aws/endpoints", + "aws/request", + "aws/session", + "aws/signer/v4", + "internal/ini", + "internal/s3err", + "internal/sdkio", + "internal/sdkrand", + "internal/sdkuri", + "internal/shareddefaults", + "private/protocol", + "private/protocol/eventstream", + "private/protocol/eventstream/eventstreamapi", + "private/protocol/json/jsonutil", + "private/protocol/query", + "private/protocol/query/queryutil", + "private/protocol/rest", + "private/protocol/restxml", + "private/protocol/xml/xmlutil", + "service/s3", + "service/s3/s3iface", + "service/s3/s3manager", + "service/sts", + "service/sts/stsiface", + ] + pruneopts = "" + revision = "14379de571db1ac1b08f2f723a1acc1810c4dd0d" + version = "v1.22.2" + +[[projects]] + branch = "master" + digest = "1:d700667a9f768e1c6db34b091ec28b709dd8de6a62a61cd9a3ceb39d442154a7" + name = "github.com/benlaurie/objecthash" + packages = ["go/objecthash"] + pruneopts = "" + revision = "d1e3d6079fc16f8f542183fb5b2fdc11d9f00866" + +[[projects]] + digest = "1:ac2a05be7167c495fe8aaf8aaf62ecf81e78d2180ecb04e16778dc6c185c96a5" + name = "github.com/beorn7/perks" + packages = ["quantile"] + pruneopts = "" + revision = "37c8de3658fcb183f997c4e13e8337516ab753e6" + version = "v1.0.1" + +[[projects]] + digest = "1:ad70cf78ff17abf96d92a6082f4d3241fef8f149118f87c3a267ed47a08be603" + name = "github.com/census-instrumentation/opencensus-proto" + packages = [ + "gen-go/agent/common/v1", + "gen-go/agent/metrics/v1", + 
"gen-go/agent/trace/v1", + "gen-go/metrics/v1", + "gen-go/resource/v1", + "gen-go/trace/v1", + ] + pruneopts = "" + revision = "d89fa54de508111353cb0b06403c00569be780d8" + version = "v0.2.1" + +[[projects]] + digest = "1:f6485831252319cd6ca29fc170adecf1eb81bf1e805f62f44eb48564ce2485fe" + name = "github.com/cespare/xxhash" + packages = ["."] + pruneopts = "" + revision = "3b82fb7d186719faeedd0c2864f868c74fbf79a1" + version = "v2.0.0" + +[[projects]] + digest = "1:193f6d32d751f26540aa8eeedc114ce0a51f9e77b6c22dda3a4db4e5f65aec66" + name = "github.com/coocood/freecache" + packages = ["."] + pruneopts = "" + revision = "3c79a0a23c1940ab4479332fb3e0127265650ce3" + version = "v1.1.0" + +[[projects]] + digest = "1:0deddd908b6b4b768cfc272c16ee61e7088a60f7fe2f06c547bd3d8e1f8b8e77" + name = "github.com/davecgh/go-spew" + packages = ["spew"] + pruneopts = "" + revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" + version = "v1.1.1" + +[[projects]] + digest = "1:6098222470fe0172157ce9bbef5d2200df4edde17ee649c5d6e48330e4afa4c6" + name = "github.com/dgrijalva/jwt-go" + packages = ["."] + pruneopts = "" + revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e" + version = "v3.2.0" + +[[projects]] + digest = "1:46ddeb9dd35d875ac7568c4dc1fc96ce424e034bdbb984239d8ffc151398ec01" + name = "github.com/evanphx/json-patch" + packages = ["."] + pruneopts = "" + revision = "026c730a0dcc5d11f93f1cf1cc65b01247ea7b6f" + version = "v4.5.0" + +[[projects]] + digest = "1:e988ed0ca0d81f4d28772760c02ee95084961311291bdfefc1b04617c178b722" + name = "github.com/fatih/color" + packages = ["."] + pruneopts = "" + revision = "5b77d2a35fb0ede96d138fc9a99f5c9b6aef11b4" + version = "v1.7.0" + +[[projects]] + branch = "master" + digest = "1:135223bf2c128b2158178ee48779ac9983b003634864d46b73e913c95f7a847e" + name = "github.com/fsnotify/fsnotify" + packages = ["."] + pruneopts = "" + revision = "1485a34d5d5723fea214f5710708e19a831720e4" + +[[projects]] + digest = 
"1:b13707423743d41665fd23f0c36b2f37bb49c30e94adb813319c44188a51ba22" + name = "github.com/ghodss/yaml" + packages = ["."] + pruneopts = "" + revision = "0ca9ea5df5451ffdf184b4428c902747c2c11cd7" + version = "v1.0.0" + +[[projects]] + digest = "1:65587005c6fa4293c0b8a2e457e689df7fda48cc5e1f5449ea2c1e7784551558" + name = "github.com/go-logr/logr" + packages = ["."] + pruneopts = "" + revision = "9fb12b3b21c5415d16ac18dc5cd42c1cfdd40c4e" + version = "v0.1.0" + +[[projects]] + digest = "1:d81dfed1aa731d8e4a45d87154ec15ef18da2aa80fa9a2f95bec38577a244a99" + name = "github.com/go-logr/zapr" + packages = ["."] + pruneopts = "" + revision = "03f06a783fbb7dfaf3f629c7825480e43a7105e6" + version = "v0.1.1" + +[[projects]] + digest = "1:c2db84082861ca42d0b00580d28f4b31aceec477a00a38e1a057fb3da75c8adc" + name = "github.com/go-redis/redis" + packages = [ + ".", + "internal", + "internal/consistenthash", + "internal/hashtag", + "internal/pool", + "internal/proto", + "internal/util", + ] + pruneopts = "" + revision = "75795aa4236dc7341eefac3bbe945e68c99ef9df" + version = "v6.15.3" + +[[projects]] + digest = "1:fd53b471edb4c28c7d297f617f4da0d33402755f58d6301e7ca1197ef0a90937" + name = "github.com/gogo/protobuf" + packages = [ + "proto", + "sortkeys", + ] + pruneopts = "" + revision = "ba06b47c162d49f2af050fb4c75bcbc86a159d5c" + version = "v1.2.1" + +[[projects]] + branch = "master" + digest = "1:e1822d37be8e11e101357a27170527b1056c99182407f270e080f76409adbd9a" + name = "github.com/golang/groupcache" + packages = ["lru"] + pruneopts = "" + revision = "869f871628b6baa9cfbc11732cdf6546b17c1298" + +[[projects]] + digest = "1:b852d2b62be24e445fcdbad9ce3015b44c207815d631230dfce3f14e7803f5bf" + name = "github.com/golang/protobuf" + packages = [ + "jsonpb", + "proto", + "protoc-gen-go/descriptor", + "protoc-gen-go/generator", + "protoc-gen-go/generator/internal/remap", + "protoc-gen-go/plugin", + "ptypes", + "ptypes/any", + "ptypes/duration", + "ptypes/struct", + "ptypes/timestamp", + 
"ptypes/wrappers", + ] + pruneopts = "" + revision = "6c65a5562fc06764971b7c5d05c76c75e84bdbf7" + version = "v1.3.2" + +[[projects]] + digest = "1:1e5b1e14524ed08301977b7b8e10c719ed853cbf3f24ecb66fae783a46f207a6" + name = "github.com/google/btree" + packages = ["."] + pruneopts = "" + revision = "4030bb1f1f0c35b30ca7009e9ebd06849dd45306" + version = "v1.0.0" + +[[projects]] + digest = "1:8d4a577a9643f713c25a32151c0f26af7228b4b97a219b5ddb7fd38d16f6e673" + name = "github.com/google/gofuzz" + packages = ["."] + pruneopts = "" + revision = "f140a6486e521aad38f5917de355cbf147cc0496" + version = "v1.0.0" + +[[projects]] + digest = "1:ad92aa49f34cbc3546063c7eb2cabb55ee2278b72842eda80e2a20a8a06a8d73" + name = "github.com/google/uuid" + packages = ["."] + pruneopts = "" + revision = "0cd6bf5da1e1c83f8b45653022c74f71af0538a4" + version = "v1.1.1" + +[[projects]] + digest = "1:5facc3828b6a56f9aec988433ea33fb4407a89460952ed75be5347cec07318c0" + name = "github.com/googleapis/gnostic" + packages = [ + "OpenAPIv2", + "compiler", + "extensions", + ] + pruneopts = "" + revision = "e73c7ec21d36ddb0711cb36d1502d18363b5c2c9" + version = "v0.3.0" + +[[projects]] + digest = "1:1ea91d049b6a609f628ecdfda32e85f445a0d3671980dcbf7cbe1bbd7ee6aabc" + name = "github.com/graymeta/stow" + packages = [ + ".", + "azure", + "google", + "local", + "oracle", + "s3", + "swift", + ] + pruneopts = "" + revision = "903027f87de7054953efcdb8ba70d5dc02df38c7" + +[[projects]] + branch = "master" + digest = "1:e1fd67b5695fb12f54f979606c5d650a5aa72ef242f8e71072bfd4f7b5a141a0" + name = "github.com/gregjones/httpcache" + packages = [ + ".", + "diskcache", + ] + pruneopts = "" + revision = "901d90724c7919163f472a9812253fb26761123d" + +[[projects]] + digest = "1:9a0b2dd1f882668a3d7fbcd424eed269c383a16f1faa3a03d14e0dd5fba571b1" + name = "github.com/grpc-ecosystem/go-grpc-middleware" + packages = [ + ".", + "retry", + "util/backoffutils", + "util/metautils", + ] + pruneopts = "" + revision = 
"c250d6563d4d4c20252cd865923440e829844f4e" + version = "v1.0.0" + +[[projects]] + digest = "1:e24dc5ef44694848785de507f439a24e9e6d96d7b43b8cf3d6cfa857aa1e2186" + name = "github.com/grpc-ecosystem/go-grpc-prometheus" + packages = ["."] + pruneopts = "" + revision = "c225b8c3b01faf2899099b768856a9e916e5087b" + version = "v1.2.0" + +[[projects]] + digest = "1:4ab82898193e99be9d4f1f1eb4ca3b1113ab6b7b2ff4605198ae305de864f05e" + name = "github.com/grpc-ecosystem/grpc-gateway" + packages = [ + "internal", + "protoc-gen-swagger/options", + "runtime", + "utilities", + ] + pruneopts = "" + revision = "ad529a448ba494a88058f9e5be0988713174ac86" + version = "v1.9.5" + +[[projects]] + digest = "1:7f6f07500a0b7d3766b00fa466040b97f2f5b5f3eef2ecabfe516e703b05119a" + name = "github.com/hashicorp/golang-lru" + packages = [ + ".", + "simplelru", + ] + pruneopts = "" + revision = "7f827b33c0f158ec5dfbba01bb0b14a4541fd81d" + version = "v0.5.3" + +[[projects]] + digest = "1:d14365c51dd1d34d5c79833ec91413bfbb166be978724f15701e17080dc06dec" + name = "github.com/hashicorp/hcl" + packages = [ + ".", + "hcl/ast", + "hcl/parser", + "hcl/printer", + "hcl/scanner", + "hcl/strconv", + "hcl/token", + "json/parser", + "json/scanner", + "json/token", + ] + pruneopts = "" + revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241" + version = "v1.0.0" + +[[projects]] + digest = "1:31bfd110d31505e9ffbc9478e31773bf05bf02adcaeb9b139af42684f9294c13" + name = "github.com/imdario/mergo" + packages = ["."] + pruneopts = "" + revision = "7c29201646fa3de8506f701213473dd407f19646" + version = "v0.3.7" + +[[projects]] + digest = "1:870d441fe217b8e689d7949fef6e43efbc787e50f200cb1e70dbca9204a1d6be" + name = "github.com/inconshreveable/mousetrap" + packages = ["."] + pruneopts = "" + revision = "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75" + version = "v1.0" + +[[projects]] + digest = "1:13fe471d0ed891e8544eddfeeb0471fd3c9f2015609a1c000aefdedf52a19d40" + name = "github.com/jmespath/go-jmespath" + packages = ["."] + 
pruneopts = "" + revision = "c2b33e84" + +[[projects]] + digest = "1:e716a02584d94519e2ccf7ac461c4028da736d41a58c1ed95e641c1603bdb056" + name = "github.com/json-iterator/go" + packages = ["."] + pruneopts = "" + revision = "27518f6661eba504be5a7a9a9f6d9460d892ade3" + version = "v1.1.7" + +[[projects]] + digest = "1:0f51cee70b0d254dbc93c22666ea2abf211af81c1701a96d04e2284b408621db" + name = "github.com/konsorten/go-windows-terminal-sequences" + packages = ["."] + pruneopts = "" + revision = "f55edac94c9bbba5d6182a4be46d86a2c9b5b50e" + version = "v1.0.2" + +[[projects]] + digest = "1:2b5f0e6bc8fb862fed5bccf9fbb1ab819c8b3f8a21e813fe442c06aec3bb3e86" + name = "github.com/lyft/flyteidl" + packages = [ + "clients/go/admin", + "clients/go/admin/mocks", + "clients/go/coreutils", + "clients/go/coreutils/logs", + "clients/go/datacatalog/mocks", + "clients/go/events", + "clients/go/events/errors", + "gen/pb-go/flyteidl/admin", + "gen/pb-go/flyteidl/core", + "gen/pb-go/flyteidl/datacatalog", + "gen/pb-go/flyteidl/event", + "gen/pb-go/flyteidl/plugins", + "gen/pb-go/flyteidl/service", + ] + pruneopts = "" + revision = "793b09d190148236f41ad8160b5cec9a3325c16f" + source = "https://github.com/lyft/flyteidl" + version = "v0.1.0" + +[[projects]] + digest = "1:500471ee50c4141d3523c79615cc90529b3152f8aa5924b63122df6bf201a7a0" + name = "github.com/lyft/flyteplugins" + packages = [ + "go/tasks", + "go/tasks/v1", + "go/tasks/v1/config", + "go/tasks/v1/errors", + "go/tasks/v1/events", + "go/tasks/v1/flytek8s", + "go/tasks/v1/flytek8s/config", + "go/tasks/v1/k8splugins", + "go/tasks/v1/logs", + "go/tasks/v1/qubole", + "go/tasks/v1/qubole/client", + "go/tasks/v1/qubole/config", + "go/tasks/v1/qubole/mocks", + "go/tasks/v1/resourcemanager", + "go/tasks/v1/types", + "go/tasks/v1/types/mocks", + "go/tasks/v1/utils", + ] + pruneopts = "" + revision = "8c85a7c9f19de4df4767de329c56a7f09d0a7bbc" + source = "https://github.com/lyft/flyteplugins" + version = "v0.1.0" + +[[projects]] + digest = 
"1:77615fd7bcfd377c5e898c41d33be827f108166691cb91257d27113ee5d08650" + name = "github.com/lyft/flytestdlib" + packages = [ + "atomic", + "config", + "config/files", + "config/viper", + "contextutils", + "errors", + "ioutils", + "logger", + "pbhash", + "profutils", + "promutils", + "promutils/labeled", + "sets", + "storage", + "utils", + "version", + "yamlutils", + ] + pruneopts = "" + revision = "7292f20ec17b42f104fd61d7f0120e17bcacf751" + source = "https://github.com/lyft/flytestdlib" + version = "v0.2.16" + +[[projects]] + digest = "1:ae39921edb7f801f7ce1b6b5484f9715a1dd2b52cb645daef095cd10fd6ee774" + name = "github.com/magiconair/properties" + packages = [ + ".", + "assert", + ] + pruneopts = "" + revision = "de8848e004dd33dc07a2947b3d76f618a7fc7ef1" + version = "v1.8.1" + +[[projects]] + digest = "1:9ea83adf8e96d6304f394d40436f2eb44c1dc3250d223b74088cc253a6cd0a1c" + name = "github.com/mattn/go-colorable" + packages = ["."] + pruneopts = "" + revision = "167de6bfdfba052fa6b2d3664c8f5272e23c9072" + version = "v0.0.9" + +[[projects]] + digest = "1:dbfae9da5a674236b914e486086671145b37b5e3880a38da906665aede3c9eab" + name = "github.com/mattn/go-isatty" + packages = ["."] + pruneopts = "" + revision = "1311e847b0cb909da63b5fecfb5370aa66236465" + version = "v0.0.8" + +[[projects]] + digest = "1:63722a4b1e1717be7b98fc686e0b30d5e7f734b9e93d7dee86293b6deab7ea28" + name = "github.com/matttproud/golang_protobuf_extensions" + packages = ["pbutil"] + pruneopts = "" + revision = "c12348ce28de40eed0136aa2b644d0ee0650e56c" + version = "v1.0.1" + +[[projects]] + digest = "1:bcc46a0fbd9e933087bef394871256b5c60269575bb661935874729c65bbbf60" + name = "github.com/mitchellh/mapstructure" + packages = ["."] + pruneopts = "" + revision = "3536a929edddb9a5b34bd6861dc4a9647cb459fe" + version = "v1.1.2" + +[[projects]] + digest = "1:0c0ff2a89c1bb0d01887e1dac043ad7efbf3ec77482ef058ac423d13497e16fd" + name = "github.com/modern-go/concurrent" + packages = ["."] + pruneopts = "" + revision = 
"bacd9c7ef1dd9b15be4a9909b8ac7a4e313eec94" + version = "1.0.3" + +[[projects]] + digest = "1:e32bdbdb7c377a07a9a46378290059822efdce5c8d96fe71940d87cb4f918855" + name = "github.com/modern-go/reflect2" + packages = ["."] + pruneopts = "" + revision = "4b7aa43c6742a2c18fdef89dd197aaae7dac7ccd" + version = "1.0.1" + +[[projects]] + branch = "master" + digest = "1:b6c101f6c8ab09c631e969c30d3a4b42aeca82580499253bad77cb2426d4fc27" + name = "github.com/ncw/swift" + packages = ["."] + pruneopts = "" + revision = "a24ef33bc9b7e59ae4bed9e87a51d7bc76122731" + +[[projects]] + digest = "1:c1a07a723fa656d4ba5ac489fcb4dfa3aef0fec6b34e415f0002dfc5ee2ba872" + name = "github.com/operator-framework/operator-sdk" + packages = ["pkg/util/k8sutil"] + pruneopts = "" + revision = "e5a0ab096e1a7c0e6b937d2b41707eccb82c3c77" + version = "v0.0.7" + +[[projects]] + digest = "1:a5484d4fa43127138ae6e7b2299a6a52ae006c7f803d98d717f60abf3e97192e" + name = "github.com/pborman/uuid" + packages = ["."] + pruneopts = "" + revision = "adf5a7427709b9deb95d29d3fa8a2bf9cfd388f1" + version = "v1.2" + +[[projects]] + digest = "1:3d2c33720d4255686b9f4a7e4d3b94938ee36063f14705c5eb0f73347ed4c496" + name = "github.com/pelletier/go-toml" + packages = ["."] + pruneopts = "" + revision = "728039f679cbcd4f6a54e080d2219a4c4928c546" + version = "v1.4.0" + +[[projects]] + branch = "master" + digest = "1:5f0faa008e8ff4221b55a1a5057c8b02cb2fd68da6a65c9e31c82b72cbc836d0" + name = "github.com/petar/GoLLRB" + packages = ["llrb"] + pruneopts = "" + revision = "33fb24c13b99c46c93183c291836c573ac382536" + +[[projects]] + digest = "1:4709c61d984ef9ba99b037b047546d8a576ae984fb49486e48d99658aa750cd5" + name = "github.com/peterbourgon/diskv" + packages = ["."] + pruneopts = "" + revision = "0be1b92a6df0e4f5cb0a5d15fb7f643d0ad93ce6" + version = "v3.0.0" + +[[projects]] + digest = "1:1d7e1867c49a6dd9856598ef7c3123604ea3daabf5b83f303ff457bcbc410b1d" + name = "github.com/pkg/errors" + packages = ["."] + pruneopts = "" + revision = 
"ba968bfe8b2f7e042a574c888954fccecfa385b4" + version = "v0.8.1" + +[[projects]] + digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + pruneopts = "" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + digest = "1:c826496cad27bd9a7644a01230a79d472b4093dd33587236e8f8369bb1d8534e" + name = "github.com/prometheus/client_golang" + packages = [ + "prometheus", + "prometheus/internal", + "prometheus/promhttp", + ] + pruneopts = "" + revision = "2641b987480bca71fb39738eb8c8b0d577cb1d76" + version = "v0.9.4" + +[[projects]] + branch = "master" + digest = "1:cd67319ee7536399990c4b00fae07c3413035a53193c644549a676091507cadc" + name = "github.com/prometheus/client_model" + packages = ["go"] + pruneopts = "" + revision = "fd36f4220a901265f90734c3183c5f0c91daa0b8" + +[[projects]] + digest = "1:0f2cee44695a3208fe5d6926076641499c72304e6f015348c9ab2df90a202cdf" + name = "github.com/prometheus/common" + packages = [ + "expfmt", + "internal/bitbucket.org/ww/goautoneg", + "model", + ] + pruneopts = "" + revision = "31bed53e4047fd6c510e43a941f90cb31be0972a" + version = "v0.6.0" + +[[projects]] + digest = "1:9b33e539d6bf6e4453668a847392d1e9e6345225ea1426f9341212c652bcbee4" + name = "github.com/prometheus/procfs" + packages = [ + ".", + "internal/fs", + ] + pruneopts = "" + revision = "3f98efb27840a48a7a2898ec80be07674d19f9c8" + version = "v0.0.3" + +[[projects]] + digest = "1:7f569d906bdd20d906b606415b7d794f798f91a62fcfb6a4daa6d50690fb7a3f" + name = "github.com/satori/uuid" + packages = ["."] + pruneopts = "" + revision = "f58768cc1a7a7e77a3bd49e98cdd21419399b6a3" + version = "v1.2.0" + +[[projects]] + digest = "1:1a405cddcf3368445051fb70ab465ae99da56ad7be8d8ca7fc52159d1c2d873c" + name = "github.com/sirupsen/logrus" + packages = ["."] + pruneopts = "" + revision = "839c75faf7f98a33d445d181f3018b5c3409a45e" + version = "v1.4.2" + +[[projects]] + digest = 
"1:956f655c87b7255c6b1ae6c203ebb0af98cf2a13ef2507e34c9bf1c0332ac0f5" + name = "github.com/spf13/afero" + packages = [ + ".", + "mem", + ] + pruneopts = "" + revision = "588a75ec4f32903aa5e39a2619ba6a4631e28424" + version = "v1.2.2" + +[[projects]] + digest = "1:ae3493c780092be9d576a1f746ab967293ec165e8473425631f06658b6212afc" + name = "github.com/spf13/cast" + packages = ["."] + pruneopts = "" + revision = "8c9545af88b134710ab1cd196795e7f2388358d7" + version = "v1.3.0" + +[[projects]] + digest = "1:0c63b3c7ad6d825a898f28cb854252a3b29d37700c68a117a977263f5ec94efe" + name = "github.com/spf13/cobra" + packages = ["."] + pruneopts = "" + revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5" + version = "v0.0.5" + +[[projects]] + digest = "1:cc15ae4fbdb02ce31f3392361a70ac041f4f02e0485de8ffac92bd8033e3d26e" + name = "github.com/spf13/jwalterweatherman" + packages = ["."] + pruneopts = "" + revision = "94f6ae3ed3bceceafa716478c5fbf8d29ca601a1" + version = "v1.1.0" + +[[projects]] + digest = "1:cbaf13cdbfef0e4734ed8a7504f57fe893d471d62a35b982bf6fb3f036449a66" + name = "github.com/spf13/pflag" + packages = ["."] + pruneopts = "" + revision = "298182f68c66c05229eb03ac171abe6e309ee79a" + version = "v1.0.3" + +[[projects]] + digest = "1:c25a789c738f7cc8ec7f34026badd4e117853f329334a5aa45cf5d0727d7d442" + name = "github.com/spf13/viper" + packages = ["."] + pruneopts = "" + revision = "ae103d7e593e371c69e832d5eb3347e2b80cbbc9" + +[[projects]] + digest = "1:711eebe744c0151a9d09af2315f0bb729b2ec7637ef4c410fa90a18ef74b65b6" + name = "github.com/stretchr/objx" + packages = ["."] + pruneopts = "" + revision = "477a77ecc69700c7cdeb1fa9e129548e1c1c393c" + version = "v0.1.1" + +[[projects]] + digest = "1:381bcbeb112a51493d9d998bbba207a529c73dbb49b3fd789e48c63fac1f192c" + name = "github.com/stretchr/testify" + packages = [ + "assert", + "mock", + ] + pruneopts = "" + revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" + version = "v1.3.0" + +[[projects]] + digest = 
"1:98f63c8942146f9bf4b3925db1d96637b86c1d83693a894a244eae54aa53bb40" + name = "go.opencensus.io" + packages = [ + ".", + "exemplar", + "internal", + "internal/tagencoding", + "plugin/ocgrpc", + "plugin/ochttp", + "plugin/ochttp/propagation/b3", + "plugin/ochttp/propagation/tracecontext", + "resource", + "stats", + "stats/internal", + "stats/view", + "tag", + "trace", + "trace/internal", + "trace/propagation", + "trace/tracestate", + ] + pruneopts = "" + revision = "aab39bd6a98b853ab66c8a564f5d6cfcad59ce8a" + +[[projects]] + digest = "1:e6ff7840319b6fda979a918a8801005ec2049abca62af19211d96971d8ec3327" + name = "go.uber.org/atomic" + packages = ["."] + pruneopts = "" + revision = "df976f2515e274675050de7b3f42545de80594fd" + version = "v1.4.0" + +[[projects]] + digest = "1:22c7effcb4da0eacb2bb1940ee173fac010e9ef3c691f5de4b524d538bd980f5" + name = "go.uber.org/multierr" + packages = ["."] + pruneopts = "" + revision = "3c4937480c32f4c13a875a1829af76c98ca3d40a" + version = "v1.1.0" + +[[projects]] + digest = "1:984e93aca9088b440b894df41f2043b6a3db8f9cf30767032770bfc4796993b0" + name = "go.uber.org/zap" + packages = [ + ".", + "buffer", + "internal/bufferpool", + "internal/color", + "internal/exit", + "zapcore", + ] + pruneopts = "" + revision = "27376062155ad36be76b0f12cf1572a221d3a48c" + version = "v1.10.0" + +[[projects]] + branch = "master" + digest = "1:086760278d762dbb0e9a26e09b57f04c89178c86467d8d94fae47d64c222f328" + name = "golang.org/x/crypto" + packages = ["ssh/terminal"] + pruneopts = "" + revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285" + +[[projects]] + branch = "master" + digest = "1:955694a7c42527d7fb188505a22f10b3e158c6c2cf31fe64b1e62c9ab7b18401" + name = "golang.org/x/net" + packages = [ + "context", + "context/ctxhttp", + "http/httpguts", + "http2", + "http2/hpack", + "idna", + "internal/timeseries", + "trace", + ] + pruneopts = "" + revision = "ca1201d0de80cfde86cb01aea620983605dfe99b" + +[[projects]] + branch = "master" + digest = 
"1:01bdbbc604dcd5afb6f66a717f69ad45e9643c72d5bc11678d44ffa5c50f9e42" + name = "golang.org/x/oauth2" + packages = [ + ".", + "google", + "internal", + "jws", + "jwt", + ] + pruneopts = "" + revision = "0f29369cfe4552d0e4bcddc57cc75f4d7e672a33" + +[[projects]] + branch = "master" + digest = "1:9f6efefb4e401a4f699a295d14518871368eb89403f2dd23ec11dfcd2c0836ba" + name = "golang.org/x/sync" + packages = ["semaphore"] + pruneopts = "" + revision = "112230192c580c3556b8cee6403af37a4fc5f28c" + +[[projects]] + branch = "master" + digest = "1:0b5c2207c72f2d13995040f176feb6e3f453d6b01af2b9d57df76b05ded2e926" + name = "golang.org/x/sys" + packages = [ + "unix", + "windows", + ] + pruneopts = "" + revision = "51ab0e2deafac1f46c46ad59cf0921be2f180c3d" + +[[projects]] + digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" + name = "golang.org/x/text" + packages = [ + "collate", + "collate/build", + "internal/colltab", + "internal/gen", + "internal/language", + "internal/language/compact", + "internal/tag", + "internal/triegen", + "internal/ucd", + "language", + "secure/bidirule", + "transform", + "unicode/bidi", + "unicode/cldr", + "unicode/norm", + "unicode/rangetable", + ] + pruneopts = "" + revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475" + version = "v0.3.2" + +[[projects]] + branch = "master" + digest = "1:9522af4be529c108010f95b05f1022cb872f2b9ff8b101080f554245673466e1" + name = "golang.org/x/time" + packages = ["rate"] + pruneopts = "" + revision = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef" + +[[projects]] + branch = "master" + digest = "1:3f52587092bc722a3c3843989e6b88ec26924dc4b7b9c971095b7e93a11e0eff" + name = "golang.org/x/tools" + packages = [ + "go/ast/astutil", + "go/gcexportdata", + "go/internal/gcimporter", + "go/internal/packagesdriver", + "go/packages", + "go/types/typeutil", + "imports", + "internal/fastwalk", + "internal/gopathwalk", + "internal/imports", + "internal/module", + "internal/semver", + ] + pruneopts = "" + revision = 
"e713427fea3f98cb070e72a058c557a1a560cf22" + +[[projects]] + branch = "master" + digest = "1:f77558501305be5977ac30110f9820d21c5f1a89328667dc82db0bd9ebaab4c4" + name = "google.golang.org/api" + packages = [ + "gensupport", + "googleapi", + "googleapi/internal/uritemplates", + "googleapi/transport", + "internal", + "option", + "storage/v1", + "support/bundler", + "transport/http", + "transport/http/internal/propagation", + ] + pruneopts = "" + revision = "6f3912904777a209e099b9dbda3ed7bcb4e25ad7" + +[[projects]] + digest = "1:47f391ee443f578f01168347818cb234ed819521e49e4d2c8dd2fb80d48ee41a" + name = "google.golang.org/appengine" + packages = [ + ".", + "internal", + "internal/app_identity", + "internal/base", + "internal/datastore", + "internal/log", + "internal/modules", + "internal/remote_api", + "internal/urlfetch", + "urlfetch", + ] + pruneopts = "" + revision = "b2f4a3cf3c67576a2ee09e1fe62656a5086ce880" + version = "v1.6.1" + +[[projects]] + branch = "master" + digest = "1:95b0a53d4d31736b2483a8c41667b2bd83f303706106f81bd2f54e3f9c24eaf4" + name = "google.golang.org/genproto" + packages = [ + "googleapis/api/annotations", + "googleapis/api/httpbody", + "googleapis/rpc/status", + "protobuf/field_mask", + ] + pruneopts = "" + revision = "fa694d86fc64c7654a660f8908de4e879866748d" + +[[projects]] + digest = "1:425ee670b3e8b6562e31754021a82d78aa46b9281247827376616c8aa78f4687" + name = "google.golang.org/grpc" + packages = [ + ".", + "balancer", + "balancer/base", + "balancer/roundrobin", + "binarylog/grpc_binarylog_v1", + "codes", + "connectivity", + "credentials", + "credentials/internal", + "encoding", + "encoding/proto", + "grpclog", + "internal", + "internal/backoff", + "internal/balancerload", + "internal/binarylog", + "internal/channelz", + "internal/envconfig", + "internal/grpcrand", + "internal/grpcsync", + "internal/syscall", + "internal/transport", + "keepalive", + "metadata", + "naming", + "peer", + "resolver", + "resolver/dns", + "resolver/passthrough", + 
"serviceconfig", + "stats", + "status", + "tap", + ] + pruneopts = "" + revision = "045159ad57f3781d409358e3ade910a018c16b30" + version = "v1.22.1" + +[[projects]] + digest = "1:75fb3fcfc73a8c723efde7777b40e8e8ff9babf30d8c56160d01beffea8a95a6" + name = "gopkg.in/inf.v0" + packages = ["."] + pruneopts = "" + revision = "d2d2541c53f18d2a059457998ce2876cc8e67cbf" + version = "v0.9.1" + +[[projects]] + digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d" + name = "gopkg.in/yaml.v2" + packages = ["."] + pruneopts = "" + revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" + version = "v2.2.2" + +[[projects]] + digest = "1:73ee122857f257aa507ebae097783fe08ad8af49398e5b3876787325411f1a4b" + name = "k8s.io/api" + packages = [ + "admission/v1beta1", + "admissionregistration/v1alpha1", + "admissionregistration/v1beta1", + "apps/v1", + "apps/v1beta1", + "apps/v1beta2", + "auditregistration/v1alpha1", + "authentication/v1", + "authentication/v1beta1", + "authorization/v1", + "authorization/v1beta1", + "autoscaling/v1", + "autoscaling/v2beta1", + "autoscaling/v2beta2", + "batch/v1", + "batch/v1beta1", + "batch/v2alpha1", + "certificates/v1beta1", + "coordination/v1beta1", + "core/v1", + "events/v1beta1", + "extensions/v1beta1", + "networking/v1", + "policy/v1beta1", + "rbac/v1", + "rbac/v1alpha1", + "rbac/v1beta1", + "scheduling/v1alpha1", + "scheduling/v1beta1", + "settings/v1alpha1", + "storage/v1", + "storage/v1alpha1", + "storage/v1beta1", + ] + pruneopts = "" + revision = "27b77cf22008a0bf6e510b2a500b8885805a1b68" + source = "https://github.com/lyft/api" + +[[projects]] + digest = "1:a3bee4b1e4013573fc15631b51a7b7e0d580497e6fec63dc3724b370e624569f" + name = "k8s.io/apimachinery" + packages = [ + "pkg/api/errors", + "pkg/api/meta", + "pkg/api/resource", + "pkg/apis/meta/internalversion", + "pkg/apis/meta/v1", + "pkg/apis/meta/v1/unstructured", + "pkg/apis/meta/v1beta1", + "pkg/conversion", + "pkg/conversion/queryparams", + "pkg/fields", + 
"pkg/labels", + "pkg/runtime", + "pkg/runtime/schema", + "pkg/runtime/serializer", + "pkg/runtime/serializer/json", + "pkg/runtime/serializer/protobuf", + "pkg/runtime/serializer/recognizer", + "pkg/runtime/serializer/streaming", + "pkg/runtime/serializer/versioning", + "pkg/selection", + "pkg/types", + "pkg/util/cache", + "pkg/util/clock", + "pkg/util/diff", + "pkg/util/errors", + "pkg/util/framer", + "pkg/util/intstr", + "pkg/util/json", + "pkg/util/mergepatch", + "pkg/util/naming", + "pkg/util/net", + "pkg/util/rand", + "pkg/util/runtime", + "pkg/util/sets", + "pkg/util/strategicpatch", + "pkg/util/uuid", + "pkg/util/validation", + "pkg/util/validation/field", + "pkg/util/wait", + "pkg/util/yaml", + "pkg/version", + "pkg/watch", + "third_party/forked/golang/json", + "third_party/forked/golang/reflect", + ] + pruneopts = "" + revision = "695912cabb3a9c08353fbe628115867d47f56b1f" + source = "https://github.com/lyft/apimachinery" + +[[projects]] + digest = "1:8dbb2adde0a196cc682fcff9f26d5d7f407a8639c8ee2936ac5da514582e1f65" + name = "k8s.io/client-go" + packages = [ + "discovery", + "discovery/fake", + "dynamic", + "kubernetes", + "kubernetes/scheme", + "kubernetes/typed/admissionregistration/v1alpha1", + "kubernetes/typed/admissionregistration/v1beta1", + "kubernetes/typed/apps/v1", + "kubernetes/typed/apps/v1beta1", + "kubernetes/typed/apps/v1beta2", + "kubernetes/typed/auditregistration/v1alpha1", + "kubernetes/typed/authentication/v1", + "kubernetes/typed/authentication/v1beta1", + "kubernetes/typed/authorization/v1", + "kubernetes/typed/authorization/v1beta1", + "kubernetes/typed/autoscaling/v1", + "kubernetes/typed/autoscaling/v2beta1", + "kubernetes/typed/autoscaling/v2beta2", + "kubernetes/typed/batch/v1", + "kubernetes/typed/batch/v1beta1", + "kubernetes/typed/batch/v2alpha1", + "kubernetes/typed/certificates/v1beta1", + "kubernetes/typed/coordination/v1beta1", + "kubernetes/typed/core/v1", + "kubernetes/typed/events/v1beta1", + 
"kubernetes/typed/extensions/v1beta1", + "kubernetes/typed/networking/v1", + "kubernetes/typed/policy/v1beta1", + "kubernetes/typed/rbac/v1", + "kubernetes/typed/rbac/v1alpha1", + "kubernetes/typed/rbac/v1beta1", + "kubernetes/typed/scheduling/v1alpha1", + "kubernetes/typed/scheduling/v1beta1", + "kubernetes/typed/settings/v1alpha1", + "kubernetes/typed/storage/v1", + "kubernetes/typed/storage/v1alpha1", + "kubernetes/typed/storage/v1beta1", + "pkg/apis/clientauthentication", + "pkg/apis/clientauthentication/v1alpha1", + "pkg/apis/clientauthentication/v1beta1", + "pkg/version", + "plugin/pkg/client/auth/exec", + "rest", + "rest/watch", + "restmapper", + "testing", + "tools/auth", + "tools/cache", + "tools/clientcmd", + "tools/clientcmd/api", + "tools/clientcmd/api/latest", + "tools/clientcmd/api/v1", + "tools/leaderelection", + "tools/leaderelection/resourcelock", + "tools/metrics", + "tools/pager", + "tools/record", + "tools/reference", + "transport", + "util/buffer", + "util/cert", + "util/connrotation", + "util/flowcontrol", + "util/homedir", + "util/integer", + "util/retry", + "util/workqueue", + ] + pruneopts = "" + revision = "7621a5ebb88b1e49ce7e7837ae8e99ca030a3c13" + version = "kubernetes-1.13.5" + +[[projects]] + digest = "1:d809e6c8dfa3448ae10f5624eff4ed1ebdc906755e7cea294c44e8b7ac0b077a" + name = "k8s.io/code-generator" + packages = [ + "cmd/client-gen", + "cmd/client-gen/args", + "cmd/client-gen/generators", + "cmd/client-gen/generators/fake", + "cmd/client-gen/generators/scheme", + "cmd/client-gen/generators/util", + "cmd/client-gen/path", + "cmd/client-gen/types", + "cmd/conversion-gen", + "cmd/conversion-gen/args", + "cmd/conversion-gen/generators", + "cmd/deepcopy-gen", + "cmd/deepcopy-gen/args", + "cmd/defaulter-gen", + "cmd/defaulter-gen/args", + "cmd/informer-gen", + "cmd/informer-gen/args", + "cmd/informer-gen/generators", + "cmd/lister-gen", + "cmd/lister-gen/args", + "cmd/lister-gen/generators", + "pkg/util", + ] + pruneopts = "" + revision = 
"c2090bec4d9b1fb25de3812f868accc2bc9ecbae" + version = "kubernetes-1.13.5" + +[[projects]] + branch = "master" + digest = "1:6a2a63e09a59caff3fd2d36d69b7b92c2fe7cf783390f0b7349fb330820f9a8e" + name = "k8s.io/gengo" + packages = [ + "args", + "examples/deepcopy-gen/generators", + "examples/defaulter-gen/generators", + "examples/set-gen/sets", + "generator", + "namer", + "parser", + "types", + ] + pruneopts = "" + revision = "e17681d19d3ac4837a019ece36c2a0ec31ffe985" + +[[projects]] + digest = "1:3063061b6514ad2666c4fa292451685884cacf77c803e1b10b4a4fa23f7787fb" + name = "k8s.io/klog" + packages = ["."] + pruneopts = "" + revision = "3ca30a56d8a775276f9cdae009ba326fdc05af7f" + version = "v0.4.0" + +[[projects]] + branch = "master" + digest = "1:3176cac3365c8442ab92d465e69e05071b0dbc0d715e66b76059b04611811dff" + name = "k8s.io/kube-openapi" + packages = ["pkg/util/proto"] + pruneopts = "" + revision = "5e22f3d471e6f24ca20becfdffdc6206c7cecac8" + +[[projects]] + digest = "1:77629c3c036454b4623e99e20f5591b9551dd81d92db616384af92435b52e9b6" + name = "sigs.k8s.io/controller-runtime" + packages = [ + "pkg/cache", + "pkg/cache/informertest", + "pkg/cache/internal", + "pkg/client", + "pkg/client/apiutil", + "pkg/client/config", + "pkg/client/fake", + "pkg/controller/controllertest", + "pkg/event", + "pkg/handler", + "pkg/internal/objectutil", + "pkg/internal/recorder", + "pkg/leaderelection", + "pkg/manager", + "pkg/metrics", + "pkg/patch", + "pkg/predicate", + "pkg/reconcile", + "pkg/recorder", + "pkg/runtime/inject", + "pkg/runtime/log", + "pkg/source", + "pkg/source/internal", + "pkg/webhook/admission", + "pkg/webhook/admission/types", + "pkg/webhook/internal/metrics", + "pkg/webhook/types", + ] + pruneopts = "" + revision = "f1eaba5087d69cebb154c6a48193e6667f5b512c" + version = "v0.1.12" + +[[projects]] + digest = "1:321081b4a44256715f2b68411d8eda9a17f17ebfe6f0cc61d2cc52d11c08acfa" + name = "sigs.k8s.io/yaml" + packages = ["."] + pruneopts = "" + revision = 
"fd68e9863619f6ec2fdd8625fe1f02e7c877e480" + version = "v1.1.0" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + input-imports = [ + "github.com/DiSiqueira/GoTree", + "github.com/fatih/color", + "github.com/ghodss/yaml", + "github.com/golang/protobuf/jsonpb", + "github.com/golang/protobuf/proto", + "github.com/golang/protobuf/ptypes", + "github.com/golang/protobuf/ptypes/struct", + "github.com/golang/protobuf/ptypes/timestamp", + "github.com/grpc-ecosystem/go-grpc-middleware/retry", + "github.com/lyft/flyteidl/clients/go/admin", + "github.com/lyft/flyteidl/clients/go/admin/mocks", + "github.com/lyft/flyteidl/clients/go/coreutils", + "github.com/lyft/flyteidl/clients/go/datacatalog/mocks", + "github.com/lyft/flyteidl/clients/go/events", + "github.com/lyft/flyteidl/clients/go/events/errors", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/datacatalog", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/service", + "github.com/lyft/flyteplugins/go/tasks", + "github.com/lyft/flyteplugins/go/tasks/v1", + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s", + "github.com/lyft/flyteplugins/go/tasks/v1/types", + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks", + "github.com/lyft/flytestdlib/config", + "github.com/lyft/flytestdlib/config/viper", + "github.com/lyft/flytestdlib/contextutils", + "github.com/lyft/flytestdlib/logger", + "github.com/lyft/flytestdlib/pbhash", + "github.com/lyft/flytestdlib/profutils", + "github.com/lyft/flytestdlib/promutils", + "github.com/lyft/flytestdlib/promutils/labeled", + "github.com/lyft/flytestdlib/storage", + "github.com/lyft/flytestdlib/utils", + "github.com/lyft/flytestdlib/version", + "github.com/lyft/flytestdlib/yamlutils", + "github.com/magiconair/properties/assert", + "github.com/mitchellh/mapstructure", + 
"github.com/operator-framework/operator-sdk/pkg/util/k8sutil", + "github.com/pkg/errors", + "github.com/prometheus/client_golang/prometheus", + "github.com/spf13/cobra", + "github.com/spf13/pflag", + "github.com/stretchr/testify/assert", + "github.com/stretchr/testify/mock", + "golang.org/x/time/rate", + "google.golang.org/grpc", + "google.golang.org/grpc/codes", + "google.golang.org/grpc/status", + "k8s.io/api/batch/v1", + "k8s.io/api/core/v1", + "k8s.io/apimachinery/pkg/api/errors", + "k8s.io/apimachinery/pkg/api/resource", + "k8s.io/apimachinery/pkg/apis/meta/v1", + "k8s.io/apimachinery/pkg/labels", + "k8s.io/apimachinery/pkg/runtime", + "k8s.io/apimachinery/pkg/runtime/schema", + "k8s.io/apimachinery/pkg/runtime/serializer", + "k8s.io/apimachinery/pkg/types", + "k8s.io/apimachinery/pkg/util/clock", + "k8s.io/apimachinery/pkg/util/rand", + "k8s.io/apimachinery/pkg/util/runtime", + "k8s.io/apimachinery/pkg/util/sets", + "k8s.io/apimachinery/pkg/util/wait", + "k8s.io/apimachinery/pkg/watch", + "k8s.io/client-go/discovery", + "k8s.io/client-go/discovery/fake", + "k8s.io/client-go/kubernetes", + "k8s.io/client-go/kubernetes/scheme", + "k8s.io/client-go/kubernetes/typed/core/v1", + "k8s.io/client-go/rest", + "k8s.io/client-go/testing", + "k8s.io/client-go/tools/cache", + "k8s.io/client-go/tools/clientcmd", + "k8s.io/client-go/tools/leaderelection", + "k8s.io/client-go/tools/leaderelection/resourcelock", + "k8s.io/client-go/tools/record", + "k8s.io/client-go/util/flowcontrol", + "k8s.io/client-go/util/workqueue", + "k8s.io/code-generator/cmd/client-gen", + "k8s.io/code-generator/cmd/conversion-gen", + "k8s.io/code-generator/cmd/deepcopy-gen", + "k8s.io/code-generator/cmd/defaulter-gen", + "k8s.io/code-generator/cmd/informer-gen", + "k8s.io/code-generator/cmd/lister-gen", + "k8s.io/gengo/args", + "sigs.k8s.io/controller-runtime/pkg/cache", + "sigs.k8s.io/controller-runtime/pkg/cache/informertest", + "sigs.k8s.io/controller-runtime/pkg/client", + 
"sigs.k8s.io/controller-runtime/pkg/client/fake", + "sigs.k8s.io/controller-runtime/pkg/manager", + "sigs.k8s.io/controller-runtime/pkg/runtime/inject", + ] + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml new file mode 100644 index 000000000..1a0e8974e --- /dev/null +++ b/Gopkg.toml @@ -0,0 +1,102 @@ +# Gopkg.toml example +# +# Refer to https://golang.github.io/dep/docs/Gopkg.toml.html +# for detailed Gopkg.toml documentation. +# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" +# +# [prune] +# non-go = false +# go-tests = true +# unused-packages = true + +required = [ + "k8s.io/code-generator/cmd/defaulter-gen", + "k8s.io/code-generator/cmd/deepcopy-gen", + "k8s.io/code-generator/cmd/conversion-gen", + "k8s.io/code-generator/cmd/client-gen", + "k8s.io/code-generator/cmd/lister-gen", + "k8s.io/code-generator/cmd/informer-gen", + "k8s.io/gengo/args", +] + +[[constraint]] + name = "github.com/fatih/color" + version = "1.7.0" + +[[override]] + name = "contrib.go.opencensus.io/exporter/ocagent" + version = "0.4.x" + +[[constraint]] + name = "github.com/golang/protobuf" + version = "1.1.0" + +[[constraint]] + name = "github.com/lyft/flyteidl" + source = "https://github.com/lyft/flyteidl" + version = "^0.1.x" + +[[constraint]] + name = "github.com/lyft/flyteplugins" + source = "https://github.com/lyft/flyteplugins" + version = "^0.1.0" + +[[override]] + name = "github.com/lyft/flytestdlib" + source = "https://github.com/lyft/flytestdlib" + version = "^0.2.16" + +# Spark has a dependency on 1.11.2, so we cannot upgrade yet +[[override]] + name = "k8s.io/api" + revision = 
"27b77cf22008a0bf6e510b2a500b8885805a1b68" + source = "https://github.com/lyft/api" + +[[override]] + name = "k8s.io/apimachinery" + source = "https://github.com/lyft/apimachinery" + revision = "695912cabb3a9c08353fbe628115867d47f56b1f" + +[[override]] + name = "k8s.io/client-go" + version = "kubernetes-1.13.5" + +[[constraint]] + name = "github.com/DiSiqueira/GoTree" + version = "2.0.3" + +[[override]] + name = "k8s.io/code-generator" + # revision = "6702109cc68eb6fe6350b83e14407c8d7309fd1a" + version = "kubernetes-1.13.5" + +[[override]] + name = "github.com/graymeta/stow" + revision = "903027f87de7054953efcdb8ba70d5dc02df38c7" + +[[override]] + name = "github.com/json-iterator/go" + version = "^1.1.5" + +[[override]] + name = "sigs.k8s.io/controller-runtime" + version = "=v0.1.12" + +[[override]] + branch = "master" + name = "golang.org/x/net" diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..bed437514 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2019 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..30db27782 --- /dev/null +++ b/Makefile @@ -0,0 +1,44 @@ +export REPOSITORY=flytepropeller +include boilerplate/lyft/docker_build/Makefile +include boilerplate/lyft/golang_test_targets/Makefile + +.PHONY: update_boilerplate +update_boilerplate: + @boilerplate/update.sh + +.PHONY: linux_compile +linux_compile: + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/flytepropeller ./cmd/controller/main.go + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/kubectl-flyte ./cmd/kubectl-flyte/main.go + +.PHONY: compile +compile: + mkdir -p ./bin + go build -o bin/flytepropeller ./cmd/controller/main.go + go build -o bin/kubectl-flyte ./cmd/kubectl-flyte/main.go && cp bin/kubectl-flyte ${GOPATH}/bin + +cross_compile: + @glide install + @mkdir -p ./bin/cross + GOOS=linux GOARCH=amd64 go build -o bin/cross/flytepropeller ./cmd/controller/main.go + GOOS=linux GOARCH=amd64 go build -o bin/cross/kubectl-flyte ./cmd/kubectl-flyte/main.go + +op_code_generate: + @RESOURCE_NAME=flyteworkflow OPERATOR_PKG=github.com/lyft/flytepropeller ./hack/update-codegen.sh + +benchmark: + mkdir -p ./bin/benchmark + @go test -run=^$ -bench=. -cpuprofile=cpu.out -memprofile=mem.out ./pkg/controller/nodes/. 
&& mv *.out ./bin/benchmark/ && mv *.test ./bin/benchmark/ + +# server starts the service in development mode +.PHONY: server +server: + @go run ./cmd/controller/main.go -logtostderr --kubeconfig=$(HOME)/.kube/config + +clean: + rm -rf bin + +# Generate golden files. Add test packages that generate golden files here. +golden: + go test ./cmd/kubectl-flyte/cmd -update + go test ./pkg/compiler/test -update diff --git a/NOTICE b/NOTICE new file mode 100644 index 000000000..dba5e33d0 --- /dev/null +++ b/NOTICE @@ -0,0 +1,5 @@ +flytepropeller +Copyright 2019 Lyft Inc. + +This product includes software developed at Lyft Inc. +This product includes software derived from https://github.com/kubernetes/sample-controller diff --git a/README.rst b/README.rst new file mode 100644 index 000000000..de0f1ac2c --- /dev/null +++ b/README.rst @@ -0,0 +1,129 @@ +Flyte Propeller +=============== + + +.. +.. image:: https://img.shields.io/github/release/lyft/flytepropeller.svg +:target: https://github.com/lyft/flytepropeller/releases/latest + +.. image:: https://godoc.org/github.com/lyft/flytepropeller?status.svg +:target: https://godoc.org/github.com/lyft/flytepropeller) + +.. image:: https://img.shields.io/badge/LICENSE-Apache2.0-ff69b4.svg +:target: http://www.apache.org/licenses/LICENSE-2.0.html) + +.. image:: https://img.shields.io/codecov/c/github/lyft/flytepropeller.svg +:target: https://codecov.io/gh/lyft/flytepropeller + +.. image:: https://goreportcard.com/badge/github.com/lyft/flytepropeller +:target: https://goreportcard.com/report/github.com/lyft/flytepropeller + +.. image:: https://img.shields.io/github/commit-activity/w/lyft/flytepropeller.svg?style=plastic + +.. 
image:: https://img.shields.io/github/commits-since/lyft/flytepropeller/latest.svg?style=plastic + + +Kubernetes operator that executes Flyte graphs natively on Kubernetes + +Getting Started +=============== +kubectl-flyte tool +------------------ +kubectl-flyte is a command line tool that can be used as an extension to kubectl. It is a separate binary that is built from the propeller repo. + +Install +------- +This command will install kubectl-flyte and flytepropeller to `~/go/bin` +.. code-block:: make + + $make compile + +Use +--- +Two ways to execute the command, either standalone *kubectl-flyte* or as a subcommand of *kubectl* + +.. code-block:: command + + $ kubectl-flyte --help + OR + $ kubectl flyte --help + Flyte is a serverless workflow processing platform built for native execution on K8s. + It is extensible and flexible to allow adding new operators and comes with many operators built in + + Usage: + kubectl-flyte [flags] + kubectl-flyte [command] + + Available Commands: + compile Compile a workflow from core proto-buffer files and output a closure. + config Runs various config commands, look at the help of this command to get a list of available commands.. + create Creates a new workflow from proto-buffer files. + delete delete a workflow + get Gets a single workflow or lists all workflows currently in execution + help Help about any command + visualize Get GraphViz dot-formatted output. + + +Observing running workflows +--------------------------- + +To retrieve all workflows in a namespace use the --namespace option, --namespace = "" implies all namespaces. + +.. code-block:: command + + $ kubectl-flyte get --namespace flytekit-development + workflows + ├── flytekit-development/flytekit-development-f01c74085110840b8827 [ExecId: ... ] (2m34s Succeeded) - Time SinceCreation(30h1m39.683602s) + ... 
+ Found 19 workflows + Success: 19, Failed: 0, Running: 0, Waiting: 0 + + +To retrieve a specific workflow, namespace can either be provided in the format namespace/name or using the --namespace argument + +.. code-block:: command + + $ kubectl-flyte get flytekit-development/flytekit-development-ff806e973581f4508bf1 + Workflow + └── flytekit-development/flytekit-development-ff806e973581f4508bf1 [ExecId: project:"flytekit" domain:"development" name:"ff806e973581f4508bf1" ] (2m32s Succeeded ) + ├── start-node start 0s Succeeded + ├── c task 0s Succeeded + ├── b task 0s Succeeded + ├── a task 0s Succeeded + └── end-node end 0s Succeeded + +Deleting workflows +------------------ +To delete a specific workflow + +.. code-block:: command + + $ kubectl-flyte delete --namespace flytekit-development flytekit-development-ff806e973581f4508bf1 + +To delete all completed workflows - they have to be either success/failed with a special isCompleted label set on them. The Label is set `here ` + +.. code-block:: command + + $ kubectl-flyte delete --namespace flytekit-development --all-completed + +Running propeller locally +------------------------- +use the config.yaml in root found `here `. Cd into this folder and then run + +.. code-block:: command + + $ flytepropeller --logtostderr + +Following dependencies need to be met +1. Blob store (you can forward minio port to localhost) +2. Admin Service endpoint (can be forwarded) OR *Disable* events to admin and launchplans +3. access to kubeconfig and kubeapi + +Making changes to CRD +===================== +*Remember* changes to CRD should be carefully done, they should be backwards compatible or else you should use proper +operator versioning system. Once you do the changes, remember to execute + +.. 
code-block:: make + + $make op_code_generate diff --git a/boilerplate/lyft/docker_build/Makefile b/boilerplate/lyft/docker_build/Makefile new file mode 100644 index 000000000..4019dab83 --- /dev/null +++ b/boilerplate/lyft/docker_build/Makefile @@ -0,0 +1,12 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +.PHONY: docker_build +docker_build: + IMAGE_NAME=$$REPOSITORY ./boilerplate/lyft/docker_build/docker_build.sh + +.PHONY: dockerhub_push +dockerhub_push: + IMAGE_NAME=lyft/$$REPOSITORY REGISTRY=docker.io ./boilerplate/lyft/docker_build/docker_build.sh diff --git a/boilerplate/lyft/docker_build/Readme.rst b/boilerplate/lyft/docker_build/Readme.rst new file mode 100644 index 000000000..bb6af9b49 --- /dev/null +++ b/boilerplate/lyft/docker_build/Readme.rst @@ -0,0 +1,23 @@ +Docker Build and Push +~~~~~~~~~~~~~~~~~~~~~ + +Provides a ``make docker_build`` target that builds your image locally. + +Provides a ``make dockerhub_push`` target that pushes your final image to Dockerhub. + +The Dockerhub image will be tagged ``:`` + +If git head has a git tag, the Dockerhub image will also be tagged ``:``. + +**To Enable:** + +Add ``lyft/docker_build`` to your ``boilerplate/update.cfg`` file. 
+ +Add ``include boilerplate/lyft/docker_build/Makefile`` in your main ``Makefile`` _after_ your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/docker_build/Makefile + +(this ensures the extra Make targets get included in your main Makefile) diff --git a/boilerplate/lyft/docker_build/docker_build.sh b/boilerplate/lyft/docker_build/docker_build.sh new file mode 100755 index 000000000..f504c100c --- /dev/null +++ b/boilerplate/lyft/docker_build/docker_build.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +echo "" +echo "------------------------------------" +echo " DOCKER BUILD" +echo "------------------------------------" +echo "" + +if [ -n "$REGISTRY" ]; then + # Do not push if there are unstaged git changes + CHANGED=$(git status --porcelain) + if [ -n "$CHANGED" ]; then + echo "Please commit git changes before pushing to a registry" + exit 1 + fi +fi + + +GIT_SHA=$(git rev-parse HEAD) + +IMAGE_TAG_SUFFIX="" +# for intermediate build phases, append -$BUILD_PHASE to all image tags +if [ -n "$BUILD_PHASE" ]; then + IMAGE_TAG_SUFFIX="-${BUILD_PHASE}" +fi + +IMAGE_TAG_WITH_SHA="${IMAGE_NAME}:${GIT_SHA}${IMAGE_TAG_SUFFIX}" + +RELEASE_SEMVER=$(git describe --tags --exact-match "$GIT_SHA" 2>/dev/null) || true +if [ -n "$RELEASE_SEMVER" ]; then + IMAGE_TAG_WITH_SEMVER="${IMAGE_NAME}:${RELEASE_SEMVER}${IMAGE_TAG_SUFFIX}" +fi + +# build the image +# passing no build phase will build the final image +docker build -t "$IMAGE_TAG_WITH_SHA" --target=${BUILD_PHASE} . +echo "${IMAGE_TAG_WITH_SHA} built locally." 
+ +# if REGISTRY specified, push the images to the remote registy +if [ -n "$REGISTRY" ]; then + + if [ -n "${DOCKER_REGISTRY_PASSWORD}" ]; then + docker login --username="$DOCKER_REGISTRY_USERNAME" --password="$DOCKER_REGISTRY_PASSWORD" + fi + + docker tag "$IMAGE_TAG_WITH_SHA" "${REGISTRY}/${IMAGE_TAG_WITH_SHA}" + + docker push "${REGISTRY}/${IMAGE_TAG_WITH_SHA}" + echo "${REGISTRY}/${IMAGE_TAG_WITH_SHA} pushed to remote." + + # If the current commit has a semver tag, also push the images with the semver tag + if [ -n "$RELEASE_SEMVER" ]; then + + docker tag "$IMAGE_TAG_WITH_SHA" "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER}" + + docker push "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER}" + echo "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER} pushed to remote." + + fi +fi diff --git a/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate b/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate new file mode 100644 index 000000000..5e7b984a1 --- /dev/null +++ b/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate @@ -0,0 +1,33 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +# Using go1.10.4 +FROM golang:1.10.4-alpine3.8 as builder +RUN apk add git openssh-client make curl dep + +# COPY only the dep files for efficient caching +COPY Gopkg.* /go/src/github.com/lyft/{{REPOSITORY}}/ +WORKDIR /go/src/github.com/lyft/{{REPOSITORY}} + +# Pull dependencies +RUN dep ensure -vendor-only + +# COPY the rest of the source code +COPY . 
/go/src/github.com/lyft/{{REPOSITORY}}/ + +# This 'linux_compile' target should compile binaries to the /artifacts directory +# The main entrypoint should be compiled to /artifacts/{{REPOSITORY}} +RUN make linux_compile + +# update the PATH to include the /artifacts directory +ENV PATH="/artifacts:${PATH}" + +# This will eventually move to centurylink/ca-certs:latest for minimum possible image size +FROM alpine:3.8 +COPY --from=builder /artifacts /bin + +RUN apk --update add ca-certificates + +CMD ["{{REPOSITORY}}"] diff --git a/boilerplate/lyft/golang_dockerfile/Readme.rst b/boilerplate/lyft/golang_dockerfile/Readme.rst new file mode 100644 index 000000000..f801ef98d --- /dev/null +++ b/boilerplate/lyft/golang_dockerfile/Readme.rst @@ -0,0 +1,16 @@ +Golang Dockerfile +~~~~~~~~~~~~~~~~~ + +Provides a Dockerfile that produces a small image. + +**To Enable:** + +Add ``lyft/golang_dockerfile`` to your ``boilerplate/update.cfg`` file. + +Create and configure a ``make linux_compile`` target that compiles your go binaries to the ``/artifacts`` directory :: + + .PHONY: linux_compile + linux_compile: + RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts {{ packages }} + +All binaries compiled to ``/artifacts`` will be available at ``/bin`` in your final image. diff --git a/boilerplate/lyft/golang_dockerfile/update.sh b/boilerplate/lyft/golang_dockerfile/update.sh new file mode 100755 index 000000000..7d8466326 --- /dev/null +++ b/boilerplate/lyft/golang_dockerfile/update.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +echo " - generating Dockerfile in root directory." 
+sed -e "s/{{REPOSITORY}}/${REPOSITORY}/g" ${DIR}/Dockerfile.GoTemplate > ${DIR}/../../../Dockerfile diff --git a/boilerplate/lyft/golang_test_targets/Makefile b/boilerplate/lyft/golang_test_targets/Makefile new file mode 100644 index 000000000..6c1e527fd --- /dev/null +++ b/boilerplate/lyft/golang_test_targets/Makefile @@ -0,0 +1,38 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +DEP_SHA=1f7c19e5f52f49ffb9f956f64c010be14683468b + +.PHONY: lint +lint: #lints the package for common code smells + which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.16.0 + golangci-lint run --exclude deprecated + +# If code is failing goimports linter, this will fix. +# skips 'vendor' +.PHONY: goimports +goimports: + @boilerplate/lyft/golang_test_targets/goimports + +.PHONY: install +install: #download dependencies (including test deps) for the package + which dep || (curl "https://raw.githubusercontent.com/golang/dep/${DEP_SHA}/install.sh" | sh) + dep ensure + +.PHONY: test_unit +test_unit: + go test -cover ./... -race + +.PHONY: test_benchmark +test_benchmark: + go test -bench . ./... + +.PHONY: test_unit_cover +test_unit_cover: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -func /tmp/cover.out + +.PHONY: test_unit_visual +test_unit_visual: + go test ./... 
-coverprofile /tmp/cover.out -covermode=count; go tool cover -html=/tmp/cover.out diff --git a/boilerplate/lyft/golang_test_targets/Readme.rst b/boilerplate/lyft/golang_test_targets/Readme.rst new file mode 100644 index 000000000..acc5744f5 --- /dev/null +++ b/boilerplate/lyft/golang_test_targets/Readme.rst @@ -0,0 +1,31 @@ +Golang Test Targets +~~~~~~~~~~~~~~~~~~~ + +Provides an ``install`` make target that uses ``dep`` install golang dependencies. + +Provides a ``lint`` make target that uses golangci to lint your code. + +Provides a ``test_unit`` target for unit tests. + +Provides a ``test_unit_cover`` target for analysing coverage of unit tests, which will output the coverage of each function and total statement coverage. + +Provides a ``test_unit_visual`` target for visualizing coverage of unit tests through an interactive html code heat map. + +Provides a ``test_benchmark`` target for benchmark tests. + +**To Enable:** + +Add ``lyft/golang_test_targets`` to your ``boilerplate/update.cfg`` file. + +Make sure you're using ``dep`` for dependency management. + +Provide a ``.golangci`` configuration (the lint target requires it). + +Add ``include boilerplate/lyft/golang_test_targets/Makefile`` in your main ``Makefile`` _after_ your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/golang_test_targets/Makefile + +(this ensures the extra make targets get included in your main Makefile) diff --git a/boilerplate/lyft/golang_test_targets/goimports b/boilerplate/lyft/golang_test_targets/goimports new file mode 100755 index 000000000..160525a8c --- /dev/null +++ b/boilerplate/lyft/golang_test_targets/goimports @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +goimports -w $(find . 
-type f -name '*.go' -not -path "./vendor/*" -not -path "./pkg/client/*") diff --git a/boilerplate/lyft/golangci_file/.golangci.yml b/boilerplate/lyft/golangci_file/.golangci.yml new file mode 100644 index 000000000..a414f33f7 --- /dev/null +++ b/boilerplate/lyft/golangci_file/.golangci.yml @@ -0,0 +1,30 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +run: + skip-dirs: + - pkg/client + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gas + - goconst + - goimports + - golint + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - staticcheck + - structcheck + - typecheck + - unconvert + - unparam + - unused + - varcheck diff --git a/boilerplate/lyft/golangci_file/Readme.rst b/boilerplate/lyft/golangci_file/Readme.rst new file mode 100644 index 000000000..ba5d2b61c --- /dev/null +++ b/boilerplate/lyft/golangci_file/Readme.rst @@ -0,0 +1,8 @@ +GolangCI File +~~~~~~~~~~~~~ + +Provides a ``.golangci`` file with the linters we've agreed upon. + +**To Enable:** + +Add ``lyft/golangci_file`` to your ``boilerplate/update.cfg`` file. diff --git a/boilerplate/lyft/golangci_file/update.sh b/boilerplate/lyft/golangci_file/update.sh new file mode 100755 index 000000000..9e9e6c1f4 --- /dev/null +++ b/boilerplate/lyft/golangci_file/update.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +# Clone the .golangci file +echo " - copying ${DIR}/.golangci to the root directory." 
+cp ${DIR}/.golangci.yml ${DIR}/../../../.golangci.yml diff --git a/boilerplate/update.cfg b/boilerplate/update.cfg new file mode 100644 index 000000000..5417c8046 --- /dev/null +++ b/boilerplate/update.cfg @@ -0,0 +1,4 @@ +lyft/docker_build +lyft/golang_test_targets +lyft/golangci_file +lyft/golang_dockerfile diff --git a/boilerplate/update.sh b/boilerplate/update.sh new file mode 100755 index 000000000..bea661d9a --- /dev/null +++ b/boilerplate/update.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +OUT="$(mktemp -d)" +git clone git@github.com:lyft/boilerplate.git "${OUT}" + +echo "Updating the update.sh script." +cp "${OUT}/boilerplate/update.sh" "${DIR}/update.sh" +echo "" + + +CONFIG_FILE="${DIR}/update.cfg" +README="https://github.com/lyft/boilerplate/blob/master/Readme.rst" + +if [ ! -f "$CONFIG_FILE" ]; then + echo "$CONFIG_FILE not found." + echo "This file is required in order to select which features to include." + echo "See $README for more details." + exit 1 +fi + +if [ -z "$REPOSITORY" ]; then + echo '$REPOSITORY is required to run this script' + echo "See $README for more details." + exit 1 +fi + +while read directory; do + echo "***********************************************************************************" + echo "$directory is configured in update.cfg." + echo "-----------------------------------------------------------------------------------" + echo "syncing files from source." 
+ dir_path="${OUT}/boilerplate/${directory}" + rm -rf "${DIR}/${directory}" + mkdir -p $(dirname "${DIR}/${directory}") + cp -r "$dir_path" "${DIR}/${directory}" + if [ -f "${DIR}/${directory}/update.sh" ]; then + echo "executing ${DIR}/${directory}/update.sh" + "${DIR}/${directory}/update.sh" + fi + echo "***********************************************************************************" + echo "" +done < "$CONFIG_FILE" + +rm -rf "${OUT}" diff --git a/cmd/controller/cmd/root.go b/cmd/controller/cmd/root.go new file mode 100644 index 000000000..55923eb79 --- /dev/null +++ b/cmd/controller/cmd/root.go @@ -0,0 +1,235 @@ +package cmd + +import ( + "context" + "flag" + "fmt" + "os" + + "github.com/lyft/flytepropeller/pkg/controller/executors" + "sigs.k8s.io/controller-runtime/pkg/cache" + + "sigs.k8s.io/controller-runtime/pkg/client" + + "sigs.k8s.io/controller-runtime/pkg/manager" + + config2 "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flyteplugins/go/tasks" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytestdlib/config/viper" + "github.com/lyft/flytestdlib/version" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/profutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/operator-framework/operator-sdk/pkg/util/k8sutil" + "github.com/pkg/errors" + "github.com/spf13/pflag" + + "github.com/spf13/cobra" + + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/clientcmd" + + clientset "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + informers "github.com/lyft/flytepropeller/pkg/client/informers/externalversions" + "github.com/lyft/flytepropeller/pkg/controller" + "github.com/lyft/flytepropeller/pkg/signals" + restclient "k8s.io/client-go/rest" +) + +const ( + defaultNamespace = "all" + appName = "flytepropeller" +) + +var ( + cfgFile string + configAccessor = 
viper.NewAccessor(config.Options{StrictMode: true}) +) + +// rootCmd represents the base command when called without any subcommands +var rootCmd = &cobra.Command{ + Use: "flyte-propeller", + Short: "Operator for running Flyte Workflows", + Long: `Flyte Propeller runs a workflow to completion by recursing through the nodes, + handling their tasks to completion and propagating their status upstream.`, + PreRunE: initConfig, + Run: func(cmd *cobra.Command, args []string) { + executeRootCmd(config2.GetConfig()) + }, +} + +// Execute adds all child commands to the root command and sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() { + version.LogBuildInformation(appName) + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func init() { + // allows `$ flytepropeller --logtostderr` to work + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) + err := flag.CommandLine.Parse([]string{}) + if err != nil { + logAndExit(err) + } + + // Here you will define your flags and configuration settings. Cobra supports persistent flags, which, if defined + // here, will be global for your application. + rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", + "config file (default is $HOME/config.yaml)") + + configAccessor.InitializePflags(rootCmd.PersistentFlags()) +} + +func initConfig(_ *cobra.Command, _ []string) error { + configAccessor = viper.NewAccessor(config.Options{ + StrictMode: true, + SearchPaths: []string{cfgFile}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + if err != nil { + return err + } + + // Operator-SDK expects kube config to be in KUBERNETES_CONFIG env var. 
+ controllerCfg := config2.GetConfig() + if controllerCfg.KubeConfigPath != "" { + fmt.Printf("Setting env variable for operator-sdk, %v\n", controllerCfg.KubeConfigPath) + return os.Setenv(k8sutil.KubeConfigEnvVar, os.ExpandEnv(controllerCfg.KubeConfigPath)) + } + + fmt.Printf("Started in-cluster mode\n") + return nil +} + +func logAndExit(err error) { + logger.Error(context.Background(), err) + os.Exit(-1) +} + +func getKubeConfig(_ context.Context, cfg *config2.Config) (*kubernetes.Clientset, *restclient.Config, error) { + var kubecfg *restclient.Config + var err error + if cfg.KubeConfigPath != "" { + kubeConfigPath := os.ExpandEnv(cfg.KubeConfigPath) + kubecfg, err = clientcmd.BuildConfigFromFlags(cfg.MasterURL, kubeConfigPath) + if err != nil { + return nil, nil, errors.Wrapf(err, "Error building kubeconfig") + } + } else { + kubecfg, err = restclient.InClusterConfig() + if err != nil { + return nil, nil, errors.Wrapf(err, "Cannot get InCluster kubeconfig") + } + } + + kubeClient, err := kubernetes.NewForConfig(kubecfg) + if err != nil { + return nil, nil, errors.Wrapf(err, "Error building kubernetes clientset") + } + return kubeClient, kubecfg, err +} + +func sharedInformerOptions(cfg *config2.Config) []informers.SharedInformerOption { + opts := []informers.SharedInformerOption{ + informers.WithTweakListOptions(func(options *v1.ListOptions) { + options.LabelSelector = v1.FormatLabelSelector(controller.IgnoreCompletedWorkflowsLabelSelector()) + }), + } + if cfg.LimitNamespace != defaultNamespace { + opts = append(opts, informers.WithNamespace(cfg.LimitNamespace)) + } + return opts +} + +func executeRootCmd(cfg *config2.Config) { + baseCtx := context.TODO() + + // set up signals so we handle the first shutdown signal gracefully + ctx := signals.SetupSignalHandler(baseCtx) + + kubeClient, kubecfg, err := getKubeConfig(ctx, cfg) + if err != nil { + logger.Fatalf(ctx, "Error building kubernetes clientset: %s", err.Error()) + } + + flyteworkflowClient, err := 
clientset.NewForConfig(kubecfg) + if err != nil { + logger.Fatalf(ctx, "Error building example clientset: %s", err.Error()) + } + + opts := sharedInformerOptions(cfg) + flyteworkflowInformerFactory := informers.NewSharedInformerFactoryWithOptions(flyteworkflowClient, cfg.WorkflowReEval.Duration, opts...) + + // Add the propeller subscope because the MetricsPrefix only has "flyte:" to get uniform collection of metrics. + propellerScope := promutils.NewScope(cfg.MetricsPrefix).NewSubScope("propeller").NewSubScope(cfg.LimitNamespace) + + go func() { + err := profutils.StartProfilingServerWithDefaultHandlers(ctx, cfg.ProfilerPort.Port, nil) + if err != nil { + logger.Panicf(ctx, "Failed to Start profiling and metrics server. Error: %v", err) + } + }() + + limitNamespace := "" + if cfg.LimitNamespace != defaultNamespace { + limitNamespace = cfg.LimitNamespace + } + + err = flytek8s.Initialize(ctx, limitNamespace, cfg.DownstreamEval.Duration) + if err != nil { + logger.Panicf(ctx, "Failed to initialize k8s plugins. Error: %v", err) + } + + if err := tasks.Load(ctx); err != nil { + logger.Fatalf(ctx, "Failed to load task plugins. [%v]", err) + } + + mgr, err := manager.New(kubecfg, manager.Options{ + Namespace: limitNamespace, + SyncPeriod: &cfg.DownstreamEval.Duration, + NewClient: func(cache cache.Cache, config *restclient.Config, options client.Options) (i client.Client, e error) { + rawClient, err := client.New(kubecfg, client.Options{}) + if err != nil { + return nil, err + } + + return executors.NewFallbackClient(&client.DelegatingClient{ + Reader: &client.DelegatingReader{ + CacheReader: cache, + ClientReader: rawClient, + }, + Writer: rawClient, + StatusClient: rawClient, + }, rawClient), nil + }, + }) + if err != nil { + logger.Fatalf(ctx, "Failed to initialize controller run-time manager. 
Error: %v", err) + } + + c, err := controller.New(ctx, cfg, kubeClient, flyteworkflowClient, flyteworkflowInformerFactory, mgr, propellerScope) + + if err != nil { + logger.Fatalf(ctx, "Failed to start Controller - [%v]", err.Error()) + } else if c == nil { + logger.Fatalf(ctx, "Failed to start Controller, nil controller received.") + } + + go flyteworkflowInformerFactory.Start(ctx.Done()) + + if err = c.Run(ctx); err != nil { + logger.Fatalf(ctx, "Error running controller: %s", err.Error()) + } +} diff --git a/cmd/controller/main.go b/cmd/controller/main.go new file mode 100644 index 000000000..6b551ff73 --- /dev/null +++ b/cmd/controller/main.go @@ -0,0 +1,9 @@ +package main + +import ( + "github.com/lyft/flytepropeller/cmd/controller/cmd" +) + +func main() { + cmd.Execute() +} diff --git a/cmd/kubectl-flyte/cmd/compile.go b/cmd/kubectl-flyte/cmd/compile.go new file mode 100644 index 000000000..f43f164cf --- /dev/null +++ b/cmd/kubectl-flyte/cmd/compile.go @@ -0,0 +1,106 @@ +package cmd + +import ( + "fmt" + "io/ioutil" + "os" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler" + "github.com/lyft/flytepropeller/pkg/compiler/common" + compilerErrors "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/pkg/errors" + "github.com/spf13/cobra" +) + +type CompileOpts struct { + *RootOptions + inputFormat format + outputFormat format + protoFile string + outputPath string + dumpClosureYaml bool +} + +func NewCompileCommand(opts *RootOptions) *cobra.Command { + + compileOpts := &CompileOpts{ + RootOptions: opts, + } + compileCmd := &cobra.Command{ + Use: "compile", + Aliases: []string{"new", "compile"}, + Short: "Compile a workflow from core proto-buffer files and output a closure.", + Long: ``, + RunE: func(cmd *cobra.Command, args []string) error { + if err := requiredFlags(cmd, protofileKey, formatKey); err != nil { + return err + } + + fmt.Println("Line numbers in errors enabled") + 
compilerErrors.SetIncludeSource() + + return compileOpts.compileWorkflowCmd() + }, + } + + compileCmd.Flags().StringVarP(&compileOpts.protoFile, "input-file", "i", "", "Path of the workflow package proto-buffer file to be uploaded") + compileCmd.Flags().StringVarP(&compileOpts.inputFormat, "input-format", "f", formatProto, "Format of the provided file. Supported formats: proto (default), json, yaml") + compileCmd.Flags().StringVarP(&compileOpts.outputPath, "output-file", "o", "", "Path of the generated output file.") + compileCmd.Flags().StringVarP(&compileOpts.outputFormat, "output-format", "m", formatProto, "Format of the generated file. Supported formats: proto (default), json, yaml") + compileCmd.Flags().BoolVarP(&compileOpts.dumpClosureYaml, "dump-closure-yaml", "d", false, "Compiles and transforms, but does not create a workflow. OutputsRef ts to STDOUT.") + + return compileCmd +} + +func (c *CompileOpts) compileWorkflowCmd() error { + if c.protoFile == "" { + return errors.Errorf("Input file not specified") + } + fmt.Printf("Received protofiles : [%v].\n", c.protoFile) + + rawWf, err := ioutil.ReadFile(c.protoFile) + if err != nil { + return err + } + + wfClosure := core.WorkflowClosure{} + err = unmarshal(rawWf, c.inputFormat, &wfClosure) + if err != nil { + return errors.Wrapf(err, "Failed to unmarshal input Workflow") + } + + if c.dumpClosureYaml { + b, err := marshal(&wfClosure, formatYaml) + if err != nil { + return err + } + err = ioutil.WriteFile(c.protoFile+".yaml", b, os.ModePerm) + if err != nil { + return err + } + } + + compiledTasks, err := compileTasks(wfClosure.Tasks) + if err != nil { + return err + } + + compileWfClosure, err := compiler.CompileWorkflow(wfClosure.Workflow, []*core.WorkflowTemplate{}, compiledTasks, []common.InterfaceProvider{}) + if err != nil { + return err + } + + fmt.Printf("Workflow compiled successfully, creating output location: [%v] format [%v]\n", c.outputPath, c.outputFormat) + + o, err := marshal(compileWfClosure, 
c.outputFormat) + if err != nil { + return errors.Wrapf(err, "Failed to marshal final workflow.") + } + + if c.outputPath != "" { + return ioutil.WriteFile(c.outputPath, o, os.ModePerm) + } + fmt.Printf("%v", string(o)) + return nil +} diff --git a/cmd/kubectl-flyte/cmd/create.go b/cmd/kubectl-flyte/cmd/create.go new file mode 100644 index 000000000..d74219382 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/create.go @@ -0,0 +1,228 @@ +package cmd + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + + "github.com/ghodss/yaml" + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler" + "github.com/lyft/flytepropeller/pkg/compiler/common" + compilerErrors "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/lyft/flytepropeller/pkg/compiler/transformers/k8s" + + "github.com/pkg/errors" + "github.com/spf13/cobra" +) + +const ( + protofileKey = "proto-path" + formatKey = "format" + executionIDKey = "execution-id" + inputsKey = "input-path" + annotationsKey = "annotations" +) + +type format = string + +const ( + formatProto format = "proto" + formatJSON format = "json" + formatYaml format = "yaml" +) + +const createCmdName = "create" + +type CreateOpts struct { + *RootOptions + format format + execID string + inputsPath string + protoFile string + annotations *stringMapValue + dryRun bool +} + +func NewCreateCommand(opts *RootOptions) *cobra.Command { + + createOpts := &CreateOpts{ + RootOptions: opts, + } + + createCmd := &cobra.Command{ + Use: createCmdName, + Aliases: []string{"new", "compile"}, + Short: "Creates a new workflow from proto-buffer files.", + Long: ``, + RunE: func(cmd *cobra.Command, args []string) error { + if err := requiredFlags(cmd, protofileKey, formatKey); err != nil { + return err + } + + fmt.Println("Line numbers in errors enabled") + compilerErrors.SetIncludeSource() + + return 
createOpts.createWorkflowFromProto() + }, + } + + createCmd.Flags().StringVarP(&createOpts.protoFile, protofileKey, "p", "", "Path of the workflow package proto-buffer file to be uploaded") + createCmd.Flags().StringVarP(&createOpts.format, formatKey, "f", formatProto, "Format of the provided file. Supported formats: proto (default), json, yaml") + createCmd.Flags().StringVarP(&createOpts.execID, executionIDKey, "", "", "Execution Id of the Workflow to create.") + createCmd.Flags().StringVarP(&createOpts.inputsPath, inputsKey, "i", "", "Path to inputs file.") + createOpts.annotations = newStringMapValue() + createCmd.Flags().VarP(createOpts.annotations, annotationsKey, "a", "Defines extra annotations to declare on the created object.") + createCmd.Flags().BoolVarP(&createOpts.dryRun, "dry-run", "d", false, "Compiles and transforms, but does not create a workflow. OutputsRef ts to STDOUT.") + + return createCmd +} + +func unmarshal(in []byte, format format, message proto.Message) (err error) { + switch format { + case formatProto: + err = proto.Unmarshal(in, message) + case formatJSON: + err = jsonpb.Unmarshal(bytes.NewReader(in), message) + if err != nil { + err = errors.Wrapf(err, "Failed to unmarshal converted Json. [%v]", string(in)) + } + case formatYaml: + jsonRaw, err := yaml.YAMLToJSON(in) + if err != nil { + return errors.Wrapf(err, "Failed to convert yaml to JSON. 
[%v]", string(in)) + } + + return unmarshal(jsonRaw, formatJSON, message) + } + + return +} + +var jsonPbMarshaler = jsonpb.Marshaler{} + +func marshal(message proto.Message, format format) (raw []byte, err error) { + switch format { + case formatProto: + return proto.Marshal(message) + case formatJSON: + b := &bytes.Buffer{} + err := jsonPbMarshaler.Marshal(b, message) + if err != nil { + return nil, errors.Wrapf(err, "Failed to marshal Json.") + } + return b.Bytes(), nil + case formatYaml: + b, err := marshal(message, formatJSON) + if err != nil { + return nil, errors.Wrapf(err, "Failed to marshal JSON") + } + return yaml.JSONToYAML(b) + } + return nil, errors.Errorf("Unknown format type") +} + +func loadInputs(path string, format format) (c *core.LiteralMap, err error) { + // Support reading from s3, etc.? + var raw []byte + raw, err = ioutil.ReadFile(path) + if err != nil { + return + } + + c = &core.LiteralMap{} + err = unmarshal(raw, format, c) + return +} + +func compileTasks(tasks []*core.TaskTemplate) ([]*core.CompiledTask, error) { + res := make([]*core.CompiledTask, 0, len(tasks)) + for _, task := range tasks { + compiledTask, err := compiler.CompileTask(task) + if err != nil { + return nil, err + } + + res = append(res, compiledTask) + } + + return res, nil +} + +func (c *CreateOpts) createWorkflowFromProto() error { + fmt.Printf("Received protofiles : [%v] [%v].\n", c.protoFile, c.inputsPath) + rawWf, err := ioutil.ReadFile(c.protoFile) + if err != nil { + return err + } + + wfClosure := core.WorkflowClosure{} + err = unmarshal(rawWf, c.format, &wfClosure) + if err != nil { + return err + } + + compiledTasks, err := compileTasks(wfClosure.Tasks) + if err != nil { + return err + } + + wf, err := compiler.CompileWorkflow(wfClosure.Workflow, []*core.WorkflowTemplate{}, compiledTasks, []common.InterfaceProvider{}) + if err != nil { + return err + } + + var inputs *core.LiteralMap + if c.inputsPath != "" { + inputs, err = loadInputs(c.inputsPath, c.format) 
+ if err != nil { + return errors.Wrapf(err, "Failed to load inputs.") + } + } + + var executionID *core.WorkflowExecutionIdentifier + if len(c.execID) > 0 { + executionID = &core.WorkflowExecutionIdentifier{ + Name: c.execID, + Domain: wfClosure.Workflow.Id.Domain, + Project: wfClosure.Workflow.Id.Project, + } + } + + flyteWf, err := k8s.BuildFlyteWorkflow(wf, inputs, executionID, c.ConfigOverrides.Context.Namespace) + if err != nil { + return err + } + if flyteWf.Annotations == nil { + flyteWf.Annotations = *c.annotations.value + } else { + for key, val := range *c.annotations.value { + flyteWf.Annotations[key] = val + } + } + + if c.dryRun { + fmt.Printf("Dry Run mode enabled. Printing the compiled workflow.") + j, err := json.Marshal(flyteWf) + if err != nil { + return errors.Wrapf(err, "Failed to marshal final workflow to Propeller format.") + } + y, err := yaml.JSONToYAML(j) + if err != nil { + return errors.Wrapf(err, "Failed to marshal final workflow from json to yaml.") + } + fmt.Println(string(y)) + } else { + wf, err := c.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(c.ConfigOverrides.Context.Namespace).Create(flyteWf) + if err != nil { + return err + } + + fmt.Printf("Successfully created Flyte Workflow %v.\n", wf.Name) + } + return nil +} diff --git a/cmd/kubectl-flyte/cmd/create_test.go b/cmd/kubectl-flyte/cmd/create_test.go new file mode 100644 index 000000000..d23485791 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/create_test.go @@ -0,0 +1,292 @@ +package cmd + +import ( + "encoding/json" + "flag" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/ghodss/yaml" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/transformers/k8s" + "github.com/lyft/flytepropeller/pkg/utils" + 
"github.com/stretchr/testify/assert" +) + +var update = flag.Bool("update", false, "Update .golden files") + +func init() { +} + +func createEmptyVariableMap() *core.VariableMap { + res := &core.VariableMap{ + Variables: map[string]*core.Variable{}, + } + return res +} + +func createVariableMap(variableMap map[string]*core.Variable) *core.VariableMap { + res := &core.VariableMap{ + Variables: variableMap, + } + return res +} + +func TestCreate(t *testing.T) { + t.Run("Generate simple workflow", generateSimpleWorkflow) + t.Run("Generate workflow with inputs", generateWorkflowWithInputs) + t.Run("Compile", testCompile) +} + +func generateSimpleWorkflow(t *testing.T) { + if !*update { + t.SkipNow() + } + + t.Log("Generating golden files.") + closure := core.WorkflowClosure{ + Workflow: &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "workflow-id-123"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + }, + Nodes: []*core.Node{ + { + Id: "node-1", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task-1"}, + }, + }, + }, + }, + { + Id: "node-2", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task-2"}, + }, + }, + }, + }, + }, + }, + Tasks: []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task-1"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "myflyteimage:latest", + Command: []string{"execute-task"}, + Args: []string{"testArg"}, + }, + }, + }, + { + Id: &core.Identifier{Name: "task-2"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "myflyteimage:latest", + Command: []string{"execute-task"}, + Args: []string{"testArg"}, + }, + }, + }, + }, + } 
+ + marshaller := &jsonpb.Marshaler{} + s, err := marshaller.MarshalToString(&closure) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", "workflow.json.golden"), []byte(s), os.ModePerm)) + + m := map[string]interface{}{} + err = json.Unmarshal([]byte(s), &m) + assert.NoError(t, err) + + b, err := yaml.Marshal(m) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", "workflow.yaml.golden"), b, os.ModePerm)) + + raw, err := proto.Marshal(&closure) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", "workflow.pb.golden"), raw, os.ModePerm)) +} + +func generateWorkflowWithInputs(t *testing.T) { + if !*update { + t.SkipNow() + } + + t.Log("Generating golden files.") + closure := core.WorkflowClosure{ + Workflow: &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "workflow-with-inputs"}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + "y": { + Type: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, + }, + }, + }}, + }), + }, + Nodes: []*core.Node{ + { + Id: "node-1", + Inputs: []*core.Binding{ + {Var: "x", Binding: &core.BindingData{Value: &core.BindingData_Promise{Promise: &core.OutputReference{Var: "x"}}}}, + {Var: "y", Binding: &core.BindingData{Value: &core.BindingData_Promise{Promise: &core.OutputReference{Var: "y"}}}}, + }, + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task-1"}, + }, + }, + }, + }, + { + Id: "node-2", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task-2"}, + }, + }, + }, + }, + }, + }, + Tasks: 
[]*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task-1"}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + "y": { + Type: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, + }, + }, + }}, + }), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "myflyteimage:latest", + Command: []string{"execute-task"}, + Args: []string{"testArg"}, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2"}, + {Name: core.Resources_MEMORY, Value: "2048Mi"}, + }, + }, + }, + }, + }, + { + Id: &core.Identifier{Name: "task-2"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "myflyteimage:latest", + Command: []string{"execute-task"}, + Args: []string{"testArg"}, + }, + }, + }, + }, + } + + marshalGolden(t, &closure, "workflow_w_inputs") + sampleInputs := core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakeLiteral(2), + "y": utils.MustMakeLiteral([]interface{}{"val1", "val2", "val3"}), + }, + } + + marshalGolden(t, &sampleInputs, "inputs") +} + +func marshalGolden(t *testing.T, message proto.Message, filename string) { + marshaller := &jsonpb.Marshaler{} + s, err := marshaller.MarshalToString(message) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", filename+".json.golden"), []byte(s), os.ModePerm)) + + m := map[string]interface{}{} + err = json.Unmarshal([]byte(s), &m) + assert.NoError(t, err) + + b, err := yaml.Marshal(m) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", filename+".yaml.golden"), b, os.ModePerm)) + + raw, err 
:= proto.Marshal(message) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", filename+".pb.golden"), raw, os.ModePerm)) +} + +func testCompile(t *testing.T) { + f := func(t *testing.T, filePath, format string) { + raw, err := ioutil.ReadFile(filepath.Join("testdata", filePath)) + assert.NoError(t, err) + wf := &core.WorkflowClosure{} + err = unmarshal(raw, format, wf) + assert.NoError(t, err) + assert.NotNil(t, wf) + assert.Equal(t, 2, len(wf.Tasks)) + if len(wf.Tasks) == 2 { + c := wf.Tasks[0].GetContainer() + assert.NotNil(t, c) + compiledTasks, err := compileTasks(wf.Tasks) + assert.NoError(t, err) + compiledWf, err := compiler.CompileWorkflow(wf.Workflow, []*core.WorkflowTemplate{}, compiledTasks, []common.InterfaceProvider{}) + assert.NoError(t, err) + _, err = k8s.BuildFlyteWorkflow(compiledWf, nil, nil, "") + assert.NoError(t, err) + } + } + + t.Run("yaml", func(t *testing.T) { + f(t, "workflow.yaml.golden", formatYaml) + }) + + t.Run("json", func(t *testing.T) { + f(t, "workflow.json.golden", formatJSON) + }) + + t.Run("proto", func(t *testing.T) { + f(t, "workflow.pb.golden", formatProto) + }) +} diff --git a/cmd/kubectl-flyte/cmd/delete.go b/cmd/kubectl-flyte/cmd/delete.go new file mode 100644 index 000000000..7ddede1e0 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/delete.go @@ -0,0 +1,87 @@ +package cmd + +import ( + "fmt" + + "github.com/lyft/flytepropeller/pkg/controller" + "github.com/spf13/cobra" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type DeleteOpts struct { + *RootOptions + force bool + allCompleted bool + chunkSize int64 + limit int64 +} + +func NewDeleteCommand(opts *RootOptions) *cobra.Command { + + deleteOpts := &DeleteOpts{ + RootOptions: opts, + } + + // deleteCmd represents the delete command + deleteCmd := &cobra.Command{ + Use: "delete [workflow-name]", + Short: "delete a workflow", + Long: ``, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + name := args[0] + return 
deleteOpts.deleteWorkflow(name) + } + + return deleteOpts.deleteCompletedWorkflows() + }, + } + + deleteCmd.Flags().BoolVarP(&deleteOpts.force, "force", "f", false, "Enable force deletion to remove finalizers from a workflow.") + deleteCmd.Flags().BoolVarP(&deleteOpts.allCompleted, "all-completed", "a", false, "Delete all the workflows that have completed. Cannot be used with --force.") + deleteCmd.Flags().Int64VarP(&deleteOpts.chunkSize, "chunk-size", "c", 100, "When using all-completed, provide a chunk size to retrieve at once from the server.") + deleteCmd.Flags().Int64VarP(&deleteOpts.limit, "limit", "l", -1, "Only iterate over max limit records.") + + return deleteCmd +} + +func (d *DeleteOpts) deleteCompletedWorkflows() error { + if d.force && d.allCompleted { + return fmt.Errorf("cannot delete multiple workflows with --force") + } + if !d.allCompleted { + return fmt.Errorf("all completed | workflow name is required") + } + + t, err := d.GetTimeoutSeconds() + if err != nil { + return err + } + + p := v1.DeletePropagationBackground + return d.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(d.ConfigOverrides.Context.Namespace).DeleteCollection( + &v1.DeleteOptions{PropagationPolicy: &p}, v1.ListOptions{ + TimeoutSeconds: &t, + LabelSelector: v1.FormatLabelSelector(controller.CompletedWorkflowsLabelSelector()), + }, + ) + +} + +func (d *DeleteOpts) deleteWorkflow(name string) error { + p := v1.DeletePropagationBackground + if err := d.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(d.ConfigOverrides.Context.Namespace).Delete(name, &v1.DeleteOptions{PropagationPolicy: &p}); err != nil { + return err + } + if d.force { + w, err := d.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(d.ConfigOverrides.Context.Namespace).Get(name, v1.GetOptions{}) + if err != nil { + return err + } + w.SetFinalizers([]string{}) + if _, err := d.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(d.ConfigOverrides.Context.Namespace).Update(w); err != nil { + return err + } + } + 
return nil +} diff --git a/cmd/kubectl-flyte/cmd/get.go b/cmd/kubectl-flyte/cmd/get.go new file mode 100644 index 000000000..b1460c340 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/get.go @@ -0,0 +1,141 @@ +package cmd + +import ( + "fmt" + "strings" + + gotree "github.com/DiSiqueira/GoTree" + "github.com/lyft/flytepropeller/cmd/kubectl-flyte/cmd/printers" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/spf13/cobra" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type GetOpts struct { + *RootOptions + detailsEnabledFlag bool + limit int64 + chunkSize int64 +} + +func NewGetCommand(opts *RootOptions) *cobra.Command { + + getOpts := &GetOpts{ + RootOptions: opts, + } + + getCmd := &cobra.Command{ + Use: "get [opts] []", + Short: "Gets a single workflow or lists all workflows currently in execution", + Long: `use labels to filter`, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + name := args[0] + return getOpts.getWorkflow(name) + } + return getOpts.listWorkflows() + }, + } + + getCmd.Flags().BoolVarP(&getOpts.detailsEnabledFlag, "details", "d", false, "If details of node execs are desired.") + getCmd.Flags().Int64VarP(&getOpts.chunkSize, "chunk-size", "c", 100, "Use this much batch size.") + getCmd.Flags().Int64VarP(&getOpts.limit, "limit", "l", -1, "Only get limit records. 
-1 => all records.") + + return getCmd +} + +func (g *GetOpts) getWorkflow(name string) error { + parts := strings.Split(name, "/") + if len(parts) > 1 { + g.ConfigOverrides.Context.Namespace = parts[0] + name = parts[1] + } + w, err := g.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(g.ConfigOverrides.Context.Namespace).Get(name, v1.GetOptions{}) + if err != nil { + return err + } + wp := printers.WorkflowPrinter{} + tree := gotree.New("Workflow") + if err := wp.Print(tree, w); err != nil { + return err + } + fmt.Print(tree.Print()) + return nil +} + +func (g *GetOpts) iterateOverWorkflows(f func(*v1alpha1.FlyteWorkflow) error, batchSize int64, limit int64) error { + if limit > 0 && limit < batchSize { + batchSize = limit + } + t, err := g.GetTimeoutSeconds() + if err != nil { + return err + } + opts := &v1.ListOptions{ + Limit: batchSize, + TimeoutSeconds: &t, + } + var counter int64 + for { + wList, err := g.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(g.ConfigOverrides.Context.Namespace).List(*opts) + if err != nil { + return err + } + for _, w := range wList.Items { + if err := f(&w); err != nil { + return err + } + counter++ + if counter == limit { + return nil + } + } + if wList.Continue == "" { + return nil + } + opts.Continue = wList.Continue + } +} + +func (g *GetOpts) listWorkflows() error { + fmt.Printf("Listing workflows in [%s]\n", g.ConfigOverrides.Context.Namespace) + wp := printers.WorkflowPrinter{} + workflows := gotree.New("workflows") + var counter int64 + var succeeded = 0 + var failed = 0 + var running = 0 + var waiting = 0 + err := g.iterateOverWorkflows( + func(w *v1alpha1.FlyteWorkflow) error { + counter++ + if err := wp.PrintShort(workflows, w); err != nil { + return err + } + switch w.GetExecutionStatus().GetPhase() { + case v1alpha1.WorkflowPhaseReady: + waiting++ + case v1alpha1.WorkflowPhaseSuccess: + succeeded++ + case v1alpha1.WorkflowPhaseFailed: + failed++ + default: + running++ + } + if counter%g.chunkSize == 0 { + 
fmt.Println("") + fmt.Print(workflows.Print()) + workflows = gotree.New("\nworkflows") + } else { + fmt.Print(".") + } + return nil + }, g.chunkSize, g.limit) + if err != nil { + return err + } + fmt.Print(workflows.Print()) + fmt.Printf("Found %d workflows\n", counter) + fmt.Printf("Success: %d, Failed: %d, Running: %d, Waiting: %d\n", succeeded, failed, running, waiting) + return nil +} diff --git a/cmd/kubectl-flyte/cmd/printers/node.go b/cmd/kubectl-flyte/cmd/printers/node.go new file mode 100644 index 000000000..d1c935049 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/printers/node.go @@ -0,0 +1,157 @@ +package printers + +import ( + "fmt" + "strconv" + "strings" + "time" + + "k8s.io/apimachinery/pkg/util/sets" + + gotree "github.com/DiSiqueira/GoTree" + "github.com/fatih/color" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/task" + "github.com/lyft/flytepropeller/pkg/utils" +) + +var boldString = color.New(color.Bold) + +func ColorizeNodePhase(p v1alpha1.NodePhase) string { + switch p { + case v1alpha1.NodePhaseNotYetStarted: + return p.String() + case v1alpha1.NodePhaseRunning: + return color.YellowString("%s", p.String()) + case v1alpha1.NodePhaseSucceeded: + return color.HiGreenString("%s", p.String()) + case v1alpha1.NodePhaseFailed: + return color.HiRedString("%s", p.String()) + } + return color.CyanString("%s", p.String()) +} + +func CalculateRuntime(s v1alpha1.ExecutableNodeStatus) string { + if s.GetStartedAt() != nil { + if s.GetStoppedAt() != nil { + return s.GetStoppedAt().Sub(s.GetStartedAt().Time).String() + } + return time.Since(s.GetStartedAt().Time).String() + } + return "na" +} + +type NodePrinter struct { + NodeStatusPrinter +} + +func (p NodeStatusPrinter) BaseNodeInfo(node v1alpha1.BaseNode, nodeStatus v1alpha1.ExecutableNodeStatus) []string { + return []string{ + fmt.Sprintf("%s (%s)", 
boldString.Sprint(node.GetID()), node.GetKind().String()), + CalculateRuntime(nodeStatus), + ColorizeNodePhase(nodeStatus.GetPhase()), + nodeStatus.GetMessage(), + } +} + +func (p NodeStatusPrinter) NodeInfo(wName string, node v1alpha1.BaseNode, nodeStatus v1alpha1.ExecutableNodeStatus) []string { + resourceName, err := utils.FixedLengthUniqueIDForParts(task.IDMaxLength, wName, node.GetID(), strconv.Itoa(int(nodeStatus.GetAttempts()))) + if err != nil { + resourceName = "na" + } + return append( + p.BaseNodeInfo(node, nodeStatus), + fmt.Sprintf("resource=%s", resourceName), + ) +} + +func (p NodePrinter) BranchNodeInfo(node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) []string { + info := p.BaseNodeInfo(node, nodeStatus) + branchStatus := nodeStatus.GetOrCreateBranchStatus() + info = append(info, branchStatus.GetPhase().String()) + if branchStatus.GetFinalizedNode() != nil { + info = append(info, *branchStatus.GetFinalizedNode()) + } + return info + +} + +func (p NodePrinter) traverseNode(tree gotree.Tree, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) error { + switch node.GetKind() { + case v1alpha1.NodeKindBranch: + subTree := tree.Add(strings.Join(p.BranchNodeInfo(node, nodeStatus), " | ")) + f := func(nodeID *v1alpha1.NodeID) error { + if nodeID != nil { + ifNode, ok := w.GetNode(*nodeID) + if !ok { + return fmt.Errorf("failed to find branch node %s", *nodeID) + } + if err := p.traverseNode(subTree, w, ifNode, nodeStatus.GetNodeExecutionStatus(*nodeID)); err != nil { + return err + } + } + return nil + } + if err := f(node.GetBranchNode().GetIf().GetThenNode()); err != nil { + return err + } + if len(node.GetBranchNode().GetElseIf()) > 0 { + for _, n := range node.GetBranchNode().GetElseIf() { + if err := f(n.GetThenNode()); err != nil { + return err + } + } + } + if err := f(node.GetBranchNode().GetElse()); err != nil { + return err + } + case v1alpha1.NodeKindWorkflow: + if 
node.GetWorkflowNode().GetSubWorkflowRef() != nil { + s := w.FindSubWorkflow(*node.GetWorkflowNode().GetSubWorkflowRef()) + wp := WorkflowPrinter{} + cw := executors.NewSubContextualWorkflow(w, s, nodeStatus) + return wp.Print(tree, cw) + } + case v1alpha1.NodeKindTask: + sub := tree.Add(strings.Join(p.NodeInfo(w.GetName(), node, nodeStatus), " | ")) + if err := p.PrintRecursive(sub, w.GetName(), nodeStatus); err != nil { + return err + } + default: + _ = tree.Add(strings.Join(p.NodeInfo(w.GetName(), node, nodeStatus), " | ")) + } + return nil +} + +func (p NodePrinter) PrintList(tree gotree.Tree, w v1alpha1.ExecutableWorkflow, nodes []v1alpha1.ExecutableNode) error { + for _, n := range nodes { + s := w.GetNodeExecutionStatus(n.GetID()) + if err := p.traverseNode(tree, w, n, s); err != nil { + return err + } + } + return nil +} + +type NodeStatusPrinter struct { +} + +func (p NodeStatusPrinter) PrintRecursive(tree gotree.Tree, wfName string, s v1alpha1.ExecutableNodeStatus) error { + orderedKeys := sets.String{} + allStatuses := map[v1alpha1.NodeID]v1alpha1.ExecutableNodeStatus{} + s.VisitNodeStatuses(func(node v1alpha1.NodeID, status v1alpha1.ExecutableNodeStatus) { + orderedKeys.Insert(node) + allStatuses[node] = status + }) + + for _, id := range orderedKeys.List() { + ns := allStatuses[id] + sub := tree.Add(strings.Join(p.NodeInfo(wfName, &v1alpha1.NodeSpec{ID: id}, ns), " | ")) + if err := p.PrintRecursive(sub, wfName, ns); err != nil { + return err + } + } + + return nil +} diff --git a/cmd/kubectl-flyte/cmd/printers/workflow.go b/cmd/kubectl-flyte/cmd/printers/workflow.go new file mode 100644 index 000000000..f15da36bd --- /dev/null +++ b/cmd/kubectl-flyte/cmd/printers/workflow.go @@ -0,0 +1,63 @@ +package printers + +import ( + "fmt" + "time" + + gotree "github.com/DiSiqueira/GoTree" + "github.com/fatih/color" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/visualize" +) + +func 
ColorizeWorkflowPhase(p v1alpha1.WorkflowPhase) string { + switch p { + case v1alpha1.WorkflowPhaseReady: + return p.String() + case v1alpha1.WorkflowPhaseRunning: + return color.YellowString("%s", p.String()) + case v1alpha1.WorkflowPhaseSuccess: + return color.HiGreenString("%s", p.String()) + case v1alpha1.WorkflowPhaseFailed: + return color.HiRedString("%s", p.String()) + } + return color.CyanString("%s", p.String()) +} + +func CalculateWorkflowRuntime(s v1alpha1.ExecutableWorkflowStatus) string { + if s.GetStartedAt() != nil { + if s.GetStoppedAt() != nil { + return s.GetStoppedAt().Sub(s.GetStartedAt().Time).String() + } + return time.Since(s.GetStartedAt().Time).String() + } + return "na" +} + +type WorkflowPrinter struct { +} + +func (p WorkflowPrinter) Print(tree gotree.Tree, w v1alpha1.ExecutableWorkflow) error { + sortedNodes, err := visualize.TopologicalSort(w) + if err != nil { + return err + } + newTree := gotree.New(fmt.Sprintf("%s/%s [ExecId: %s] (%s %s %s)", + w.GetNamespace(), boldString.Sprint(w.GetName()), w.GetExecutionID(), CalculateWorkflowRuntime(w.GetExecutionStatus()), + ColorizeWorkflowPhase(w.GetExecutionStatus().GetPhase()), w.GetExecutionStatus().GetMessage())) + if tree != nil { + tree.AddTree(newTree) + } + np := NodePrinter{} + return np.PrintList(newTree, w, sortedNodes) +} + +func (p WorkflowPrinter) PrintShort(tree gotree.Tree, w v1alpha1.ExecutableWorkflow) error { + if tree == nil { + return fmt.Errorf("bad state in printer") + } + tree.Add(fmt.Sprintf("%s/%s [ExecId: %s] (%s %s) - Time SinceCreation(%s)", + w.GetNamespace(), boldString.Sprint(w.GetName()), w.GetExecutionID(), CalculateWorkflowRuntime(w.GetExecutionStatus()), + ColorizeWorkflowPhase(w.GetExecutionStatus().GetPhase()), time.Since(w.GetCreationTimestamp().Time))) + return nil +} diff --git a/cmd/kubectl-flyte/cmd/root.go b/cmd/kubectl-flyte/cmd/root.go new file mode 100644 index 000000000..f56009f43 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/root.go @@ -0,0 +1,116 
@@ +package cmd + +import ( + "context" + "flag" + "fmt" + "os" + "runtime" + "time" + + "github.com/lyft/flytestdlib/config/viper" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/version" + "github.com/spf13/pflag" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + + flyteclient "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + "github.com/spf13/cobra" +) + +func init() { + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) + err := flag.CommandLine.Parse([]string{}) + if err != nil { + logger.Error(context.TODO(), "Error in initializing: %v", err) + os.Exit(-1) + } +} + +type RootOptions struct { + *clientcmd.ConfigOverrides + allNamespaces bool + showSource bool + clientConfig clientcmd.ClientConfig + restConfig *rest.Config + kubeClient kubernetes.Interface + flyteClient flyteclient.Interface +} + +func (r *RootOptions) GetTimeoutSeconds() (int64, error) { + if r.Timeout != "" { + d, err := time.ParseDuration(r.Timeout) + if err != nil { + return 10, err + } + return int64(d.Seconds()), nil + } + return 10, nil + +} + +func (r *RootOptions) executeRootCmd() error { + ctx := context.TODO() + logger.Infof(ctx, "Go Version: %s", runtime.Version()) + logger.Infof(ctx, "Go OS/Arch: %s/%s", runtime.GOOS, runtime.GOARCH) + version.LogBuildInformation("kubectl-flyte") + return fmt.Errorf("use one of the sub-commands") +} + +func (r *RootOptions) ConfigureClient() error { + restConfig, err := r.clientConfig.ClientConfig() + if err != nil { + return err + } + k, err := kubernetes.NewForConfig(restConfig) + if err != nil { + return err + } + fc, err := flyteclient.NewForConfig(restConfig) + if err != nil { + return err + } + r.restConfig = restConfig + r.kubeClient = k + r.flyteClient = fc + return nil +} + +// NewCommand returns a new instance of an argo command +func NewFlyteCommand() *cobra.Command { + rootOpts := &RootOptions{} + command := &cobra.Command{ + Use: "kubectl-flyte", + 
Short: "kubectl-flyte allows launching and managing K8s native workflows", + Long: `Flyte is a serverless workflow processing platform built for native execution on K8s. + It is extensible and flexible to allow adding new operators and comes with many operators built in`, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + return rootOpts.ConfigureClient() + }, + RunE: func(cmd *cobra.Command, args []string) error { + return rootOpts.executeRootCmd() + }, + } + + command.AddCommand(NewDeleteCommand(rootOpts)) + command.AddCommand(NewGetCommand(rootOpts)) + command.AddCommand(NewVisualizeCommand(rootOpts)) + command.AddCommand(NewCreateCommand(rootOpts)) + command.AddCommand(NewCompileCommand(rootOpts)) + + loadingRules := clientcmd.NewDefaultClientConfigLoadingRules() + loadingRules.DefaultClientConfig = &clientcmd.DefaultClientConfig + rootOpts.ConfigOverrides = &clientcmd.ConfigOverrides{} + kflags := clientcmd.RecommendedConfigOverrideFlags("") + command.PersistentFlags().StringVar(&loadingRules.ExplicitPath, "kubeconfig", "", "Path to a kube config. 
Only required if out-of-cluster") + clientcmd.BindOverrideFlags(rootOpts.ConfigOverrides, command.PersistentFlags(), kflags) + rootOpts.clientConfig = clientcmd.NewInteractiveDeferredLoadingClientConfig(loadingRules, rootOpts.ConfigOverrides, os.Stdin) + + command.PersistentFlags().BoolVar(&rootOpts.allNamespaces, "all-namespaces", false, "Enable this flag to execute for all namespaces") + command.PersistentFlags().BoolVarP(&rootOpts.showSource, "show-source", "s", false, "Show line number for errors") + command.AddCommand(viper.GetConfigCommand()) + + return command +} diff --git a/cmd/kubectl-flyte/cmd/string_map_value.go b/cmd/kubectl-flyte/cmd/string_map_value.go new file mode 100644 index 000000000..6f067ab20 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/string_map_value.go @@ -0,0 +1,83 @@ +package cmd + +import ( + "bytes" + "fmt" + "regexp" + "strings" +) + +// Represents a pflag value that parses a string into a map +type stringMapValue struct { + value *map[string]string + changed bool +} + +func newStringMapValue() *stringMapValue { + return &stringMapValue{ + value: &map[string]string{}, + changed: false, + } +} + +var entryRegex = regexp.MustCompile("(?P<Key>[^,]+)=(?P<Value>[^,]+)") + +// Parses val into a map. Accepted format: a=1,b=2 +func (s *stringMapValue) Set(val string) error { + matches := entryRegex.FindAllStringSubmatch(val, -1) + out := make(map[string]string, len(matches)) + for _, entry := range matches { + if len(entry) != 3 { + return fmt.Errorf("invalid value for entry. Entries must be formatted as key=value. 
Found %v", + entry) + } + + out[strings.TrimSpace(entry[1])] = strings.TrimSpace(entry[2]) + } + + if !s.changed { + *s.value = out + } else { + for k, v := range out { + (*s.value)[k] = v + } + } + s.changed = true + return nil +} + +func (s *stringMapValue) Type() string { + return "stringToString" +} + +func (s *stringMapValue) String() string { + var buf bytes.Buffer + i := 0 + for k, v := range *s.value { + if i > 0 { + _, err := buf.WriteRune(',') + if err != nil { + return "" + } + } + + _, err := buf.WriteString(k) + if err != nil { + return "" + } + + _, err = buf.WriteRune('=') + if err != nil { + return "" + } + + _, err = buf.WriteString(v) + if err != nil { + return "" + } + + i++ + } + + return "[" + buf.String() + "]" +} diff --git a/cmd/kubectl-flyte/cmd/string_map_value_test.go b/cmd/kubectl-flyte/cmd/string_map_value_test.go new file mode 100644 index 000000000..ff3ae8d17 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/string_map_value_test.go @@ -0,0 +1,65 @@ +package cmd + +import ( + "fmt" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" +) + +func formatArg(values map[string]string) string { + res := "" + for key, value := range values { + res += fmt.Sprintf(",%v%v%v=%v%v%v", randSpaces(), key, randSpaces(), randSpaces(), value, randSpaces()) + } + + if len(values) > 0 { + return res[1:] + } + + return res +} + +func randSpaces() string { + res := "" + for cnt := rand.Int() % 10; cnt > 0; cnt-- { // nolint: gas + res += " " + } + + return res +} + +func runPositiveTest(t *testing.T, expected map[string]string) { + v := newStringMapValue() + assert.NoError(t, v.Set(formatArg(expected))) + + assert.Equal(t, len(expected), len(*v.value)) + assert.Equal(t, expected, *v.value) +} + +func TestSet(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + expected := map[string]string{ + "a": "1", + "b": "2", + "c": "3", + "d.sub": "x.y", + "e": "4", + } + runPositiveTest(t, expected) + }) + + t.Run("Empty arg", func(t *testing.T) { + 
expected := map[string]string{ + "": "", + "a": "1", + "b": "2", + "c": "3", + "d.sub": "x.y", + "e": "4", + } + + runPositiveTest(t, expected) + }) +} diff --git a/cmd/kubectl-flyte/cmd/testdata/inputs.json.golden b/cmd/kubectl-flyte/cmd/testdata/inputs.json.golden new file mode 100755 index 000000000..0af476b90 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/inputs.json.golden @@ -0,0 +1 @@ +{"literals":{"x":{"scalar":{"primitive":{"integer":"2"}}},"y":{"collection":{"literals":[{"scalar":{"primitive":{"stringValue":"val1"}}},{"scalar":{"primitive":{"stringValue":"val2"}}},{"scalar":{"primitive":{"stringValue":"val3"}}}]}}}} \ No newline at end of file diff --git a/cmd/kubectl-flyte/cmd/testdata/inputs.pb.golden b/cmd/kubectl-flyte/cmd/testdata/inputs.pb.golden new file mode 100755 index 000000000..5ff2c7009 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/inputs.pb.golden @@ -0,0 +1,19 @@ + + +x + + ++ +y&$ + + + +val1 + + + +val2 + + + +val3 \ No newline at end of file diff --git a/cmd/kubectl-flyte/cmd/testdata/inputs.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/inputs.yaml.golden new file mode 100755 index 000000000..efeefc56a --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/inputs.yaml.golden @@ -0,0 +1,17 @@ +literals: + x: + scalar: + primitive: + integer: "2" + "y": + collection: + literals: + - scalar: + primitive: + stringValue: val1 + - scalar: + primitive: + stringValue: val2 + - scalar: + primitive: + stringValue: val3 diff --git a/cmd/kubectl-flyte/cmd/testdata/workflow.json.golden b/cmd/kubectl-flyte/cmd/testdata/workflow.json.golden new file mode 100755 index 000000000..83c9dc265 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/workflow.json.golden @@ -0,0 +1 @@ 
+{"workflow":{"id":{"name":"workflow-id-123"},"interface":{"inputs":{"variables":{}}},"nodes":[{"id":"node-1","taskNode":{"referenceId":{"name":"task-1"}}},{"id":"node-2","taskNode":{"referenceId":{"name":"task-2"}}}]},"tasks":[{"id":{"name":"task-1"},"interface":{"inputs":{"variables":{}}},"container":{"image":"myflyteimage:latest","command":["execute-task"],"args":["testArg"]}},{"id":{"name":"task-2"},"interface":{"inputs":{"variables":{}}},"container":{"image":"myflyteimage:latest","command":["execute-task"],"args":["testArg"]}}]} \ No newline at end of file diff --git a/cmd/kubectl-flyte/cmd/testdata/workflow.pb.golden b/cmd/kubectl-flyte/cmd/testdata/workflow.pb.golden new file mode 100755 index 000000000..4bfb48a1d Binary files /dev/null and b/cmd/kubectl-flyte/cmd/testdata/workflow.pb.golden differ diff --git a/cmd/kubectl-flyte/cmd/testdata/workflow.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/workflow.yaml.golden new file mode 100755 index 000000000..261175022 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/workflow.yaml.golden @@ -0,0 +1,38 @@ +tasks: +- container: + args: + - testArg + command: + - execute-task + image: myflyteimage:latest + id: + name: task-1 + interface: + inputs: + variables: {} +- container: + args: + - testArg + command: + - execute-task + image: myflyteimage:latest + id: + name: task-2 + interface: + inputs: + variables: {} +workflow: + id: + name: workflow-id-123 + interface: + inputs: + variables: {} + nodes: + - id: node-1 + taskNode: + referenceId: + name: task-1 + - id: node-2 + taskNode: + referenceId: + name: task-2 diff --git a/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.json.golden b/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.json.golden new file mode 100755 index 000000000..08fd7df32 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.json.golden @@ -0,0 +1 @@ 
+{"workflow":{"id":{"name":"workflow-with-inputs"},"interface":{"inputs":{"variables":{"x":{"type":{"simple":"INTEGER"}},"y":{"type":{"collectionType":{"simple":"STRING"}}}}}},"nodes":[{"id":"node-1","inputs":[{"var":"x","binding":{"promise":{"var":"x"}}},{"var":"y","binding":{"promise":{"var":"y"}}}],"taskNode":{"referenceId":{"name":"task-1"}}},{"id":"node-2","taskNode":{"referenceId":{"name":"task-2"}}}]},"tasks":[{"id":{"name":"task-1"},"interface":{"inputs":{"variables":{"x":{"type":{"simple":"INTEGER"}},"y":{"type":{"collectionType":{"simple":"STRING"}}}}}},"container":{"image":"myflyteimage:latest","command":["execute-task"],"args":["testArg"],"resources":{"requests":[{"name":"CPU","value":"2"},{"name":"MEMORY","value":"2048Mi"}]}}},{"id":{"name":"task-2"},"interface":{"inputs":{"variables":{}}},"container":{"image":"myflyteimage:latest","command":["execute-task"],"args":["testArg"]}}]} \ No newline at end of file diff --git a/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.pb.golden b/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.pb.golden new file mode 100755 index 000000000..0db9aa296 Binary files /dev/null and b/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.pb.golden differ diff --git a/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.yaml.golden new file mode 100755 index 000000000..ce937c2c9 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.yaml.golden @@ -0,0 +1,67 @@ +tasks: +- container: + args: + - testArg + command: + - execute-task + image: myflyteimage:latest + resources: + requests: + - name: CPU + value: "2" + - name: MEMORY + value: 2048Mi + id: + name: task-1 + interface: + inputs: + variables: + x: + type: + simple: INTEGER + "y": + type: + collectionType: + simple: STRING +- container: + args: + - testArg + command: + - execute-task + image: myflyteimage:latest + id: + name: task-2 + interface: + inputs: + variables: {} +workflow: + id: + name: 
workflow-with-inputs + interface: + inputs: + variables: + x: + type: + simple: INTEGER + "y": + type: + collectionType: + simple: STRING + nodes: + - id: node-1 + inputs: + - binding: + promise: + var: x + var: x + - binding: + promise: + var: "y" + var: "y" + taskNode: + referenceId: + name: task-1 + - id: node-2 + taskNode: + referenceId: + name: task-2 diff --git a/cmd/kubectl-flyte/cmd/util.go b/cmd/kubectl-flyte/cmd/util.go new file mode 100644 index 000000000..a2346ae9a --- /dev/null +++ b/cmd/kubectl-flyte/cmd/util.go @@ -0,0 +1,18 @@ +package cmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func requiredFlags(cmd *cobra.Command, flags ...string) error { + for _, flag := range flags { + f := cmd.Flag(flag) + if f == nil { + return fmt.Errorf("unable to find Key [%v]", flag) + } + } + + return nil +} diff --git a/cmd/kubectl-flyte/cmd/visualize.go b/cmd/kubectl-flyte/cmd/visualize.go new file mode 100644 index 000000000..561363c3e --- /dev/null +++ b/cmd/kubectl-flyte/cmd/visualize.go @@ -0,0 +1,39 @@ +package cmd + +import ( + "fmt" + + "github.com/lyft/flytepropeller/pkg/visualize" + "github.com/spf13/cobra" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type VisualizeOpts struct { + *RootOptions +} + +func NewVisualizeCommand(opts *RootOptions) *cobra.Command { + + vizOpts := &VisualizeOpts{ + RootOptions: opts, + } + + visualizeCmd := &cobra.Command{ + Use: "visualize ", + Short: "Get GraphViz dot-formatted output.", + Long: `Generates GraphViz dot-formatted output for the workflow.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name := args[0] + w, err := vizOpts.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(vizOpts.ConfigOverrides.Context.Namespace).Get(name, v1.GetOptions{}) + if err != nil { + return err + } + + fmt.Printf("Dot-formatted: %v\n", visualize.WorkflowToGraphViz(w)) + return nil + }, + } + + return visualizeCmd +} diff --git a/cmd/kubectl-flyte/main.go b/cmd/kubectl-flyte/main.go 
new file mode 100644 index 000000000..f0e7fa82b --- /dev/null +++ b/cmd/kubectl-flyte/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "fmt" + "os" + + "github.com/lyft/flytepropeller/cmd/kubectl-flyte/cmd" +) + +func main() { + + rootCmd := cmd.NewFlyteCommand() + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/config.yaml b/config.yaml new file mode 100644 index 000000000..c50dbc4b3 --- /dev/null +++ b/config.yaml @@ -0,0 +1,95 @@ +# This is a sample configuration file. +# Real configuration when running inside K8s (local or otherwise) lives in a ConfigMap +propeller: + workers: 4 + workflow-reeval-duration: 10s + downstream-eval-duration: 5s + limit-namespace: "all" + prof-port: 11254 + metrics-prefix: flyte + enable-admin-launcher: true + max-ttl-hours: 1 + gc-interval: 500m + queue: + type: batch + queue: + type: bucket + rate: 20 + capacity: 100 + sub-queue: + type: bucket + rate: 100 + capacity: 1000 + kube-config: "$HOME/.kube/config" + publish-k8s-events: true +# Sample plugins config +plugins: + # Set of enabled plugins at root level + enabled-plugins: + - container + - waitable + - K8S-ARRAY + # All k8s plugins default configuration + k8s: + inject-finalizer: true + default-annotations: + - annotationKey1: annotationValue1 + resource-tolerations: + nvidia.com/gpu: + key: flyte/gpu + value: dedicated + operator: Equal + effect: NoSchedule + default-env-vars: + - AWS_METADATA_SERVICE_TIMEOUT: 5 + - AWS_METADATA_SERVICE_NUM_ATTEMPTS: 20 + - FLYTE_AWS_ENDPOINT: "http://minio.flyte:9000" + - FLYTE_AWS_ACCESS_KEY_ID: minio + - FLYTE_AWS_SECRET_ACCESS_KEY: miniostorage + # Spark Plugin configuration + spark: + spark-config-default: + - spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version: "2" + - spark.kubernetes.allocation.batch.size: "50" + - spark.hadoop.fs.s3a.acl.default: "BucketOwnerFullControl" + - spark.hadoop.fs.s3n.impl: "org.apache.hadoop.fs.s3a.S3AFileSystem" + - 
spark.hadoop.fs.AbstractFileSystem.s3n.impl: "org.apache.hadoop.fs.s3a.S3A" + - spark.hadoop.fs.s3.impl: "org.apache.hadoop.fs.s3a.S3AFileSystem" + - spark.hadoop.fs.AbstractFileSystem.s3.impl: "org.apache.hadoop.fs.s3a.S3A" + - spark.hadoop.fs.s3a.impl: "org.apache.hadoop.fs.s3a.S3AFileSystem" + - spark.hadoop.fs.AbstractFileSystem.s3a.impl: "org.apache.hadoop.fs.s3a.S3A" + - spark.hadoop.fs.s3a.multipart.threshold: "536870912" + - spark.blacklist.enabled: "true" + - spark.blacklist.timeout: "5m" + # Waitable plugin configuration + waitable: + console-uri: http://localhost:30081/console + # Logging configuration + logs: + kubernetes-enabled: true + kubernetes-url: "http://localhost:30082" +storage: + connection: + access-key: minio + auth-type: accesskey + disable-ssl: true + endpoint: http://localhost:9000 + region: us-east-1 + secret-key: miniostorage + cache: + max_size_mbs: 10 + target_gc_percent: 100 + container: myflytecontainer + type: minio +event: + type: admin + rate: 500 + capacity: 1000 +admin: + endpoint: localhost:8089 + insecure: true +errors: + show-source: true +logger: + level: 4 + show-source: true diff --git a/hack/boilerplate.go.txt b/hack/boilerplate.go.txt new file mode 100644 index 000000000..e69de29bb diff --git a/hack/custom-boilerplate.go.txt b/hack/custom-boilerplate.go.txt new file mode 100644 index 000000000..e69de29bb diff --git a/hack/update-codegen.sh b/hack/update-codegen.sh new file mode 100755 index 000000000..e116e7cc8 --- /dev/null +++ b/hack/update-codegen.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# Copyright 2017 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file was derived from https://github.com/kubernetes/sample-controller/blob/4d46ec53ca337e118754c5fc50f02634b6a83380/hack/update-codegen.sh + +set -o errexit +set -o nounset +set -o pipefail + +: "${RESOURCE_NAME:?should be set for CRD}" +: "${OPERATOR_PKG:?should be set for operator}" + +echo "Generating CRD: ${RESOURCE_NAME}, in package ${OPERATOR_PKG}..." + +SCRIPT_ROOT=$(dirname ${BASH_SOURCE})/.. +CODEGEN_PKG=${CODEGEN_PKG:-$(cd ${SCRIPT_ROOT}; ls -d -1 ./vendor/k8s.io/code-generator 2>/dev/null || echo ../code-generator)} + +# generate the code with: +# --output-base because this script should also be able to run inside the vendor dir of +# k8s.io/kubernetes. The output-base is needed for the generators to output into the vendor dir +# instead of the $GOPATH directly. For normal projects this can be dropped. +${CODEGEN_PKG}/generate-groups.sh "deepcopy,client,informer,lister" \ + ${OPERATOR_PKG}/pkg/client \ + ${OPERATOR_PKG}/pkg/apis \ + ${RESOURCE_NAME}:v1alpha1 \ + --output-base "$(dirname ${BASH_SOURCE})/../../../.." \ + --go-header-file ${SCRIPT_ROOT}/hack/boilerplate.go.txt + +# To use your own boilerplate text use: +# --go-header-file ${SCRIPT_ROOT}/hack/custom-boilerplate.go.txt diff --git a/hack/verify-codegen.sh b/hack/verify-codegen.sh new file mode 100755 index 000000000..fb944feda --- /dev/null +++ b/hack/verify-codegen.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +# Copyright 2017 The Kubernetes Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file was derived from https://raw.githubusercontent.com/kubernetes/sample-controller/7ec2c1043bd0e5b511bfdf79eb215bc429effa24/hack/verify-codegen.sh + +set -o errexit +set -o nounset +set -o pipefail + +SCRIPT_ROOT=$(dirname "${BASH_SOURCE}")/.. + +DIFFROOT="${SCRIPT_ROOT}/pkg" +TMP_DIFFROOT="${SCRIPT_ROOT}/_tmp/pkg" +_tmp="${SCRIPT_ROOT}/_tmp" + +cleanup() { + rm -rf "${_tmp}" +} +trap "cleanup" EXIT SIGINT + +cleanup + +mkdir -p "${TMP_DIFFROOT}" +cp -a "${DIFFROOT}"/* "${TMP_DIFFROOT}" + +"${SCRIPT_ROOT}/hack/update-codegen.sh" +echo "diffing ${DIFFROOT} against freshly generated codegen" +ret=0 +diff -Naupr "${DIFFROOT}" "${TMP_DIFFROOT}" || ret=$? +cp -a "${TMP_DIFFROOT}"/* "${DIFFROOT}" +if [[ $ret -eq 0 ]] +then + echo "${DIFFROOT} up to date." +else + echo "${DIFFROOT} is out of date. 
Please run hack/update-codegen.sh" + exit 1 +fi diff --git a/pkg/apis/flyteworkflow/register.go b/pkg/apis/flyteworkflow/register.go new file mode 100644 index 000000000..6ea43f456 --- /dev/null +++ b/pkg/apis/flyteworkflow/register.go @@ -0,0 +1,5 @@ +package flyteworkflow + +const ( + GroupName = "flyte.lyft.com" +) diff --git a/pkg/apis/flyteworkflow/v1alpha1/branch.go b/pkg/apis/flyteworkflow/v1alpha1/branch.go new file mode 100644 index 000000000..692e73e0e --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/branch.go @@ -0,0 +1,100 @@ +package v1alpha1 + +import ( + "bytes" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type Error struct { + *core.Error +} + +func (in Error) UnmarshalJSON(b []byte) error { + in.Error = &core.Error{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.Error) +} + +func (in Error) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.Error); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Error) DeepCopyInto(out *Error) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this + +} + +type BooleanExpression struct { + *core.BooleanExpression +} + +func (in BooleanExpression) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.BooleanExpression); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (in *BooleanExpression) UnmarshalJSON(b []byte) error { + in.BooleanExpression = &core.BooleanExpression{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.BooleanExpression) +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *BooleanExpression) DeepCopyInto(out *BooleanExpression) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +type IfBlock struct { + Condition BooleanExpression `json:"condition"` + ThenNode *NodeID `json:"then"` +} + +func (in IfBlock) GetCondition() *core.BooleanExpression { + return in.Condition.BooleanExpression +} + +func (in *IfBlock) GetThenNode() *NodeID { + return in.ThenNode +} + +type BranchNodeSpec struct { + If IfBlock `json:"if"` + ElseIf []*IfBlock `json:"elseIf,omitempty"` + Else *NodeID `json:"else,omitempty"` + ElseFail *Error `json:"elseFail,omitempty"` +} + +func (in *BranchNodeSpec) GetIf() ExecutableIfBlock { + return &in.If +} + +func (in *BranchNodeSpec) GetElse() *NodeID { + return in.Else +} + +func (in *BranchNodeSpec) GetElseIf() []ExecutableIfBlock { + elifs := make([]ExecutableIfBlock, 0, len(in.ElseIf)) + for _, b := range in.ElseIf { + elifs = append(elifs, b) + } + return elifs +} + +func (in *BranchNodeSpec) GetElseFail() *core.Error { + if in.ElseFail != nil { + return in.ElseFail.Error + } + return nil +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/branch_test.go b/pkg/apis/flyteworkflow/v1alpha1/branch_test.go new file mode 100644 index 000000000..3f23f5553 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/branch_test.go @@ -0,0 +1,22 @@ +package v1alpha1_test + +import ( + "encoding/json" + "io/ioutil" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" +) + +func TestMarshalUnMarshal_BranchTask(t *testing.T) { + r, err := ioutil.ReadFile("testdata/branch.json") + assert.NoError(t, err) + o := v1alpha1.NodeSpec{} + err = json.Unmarshal(r, &o) + assert.NoError(t, err) + assert.NotNil(t, o.BranchNode.If) + assert.Equal(t, core.ComparisonExpression_GT, 
o.BranchNode.If.Condition.BooleanExpression.GetComparison().Operator) + assert.Equal(t, 1, len(o.InputBindings)) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/doc.go b/pkg/apis/flyteworkflow/v1alpha1/doc.go new file mode 100644 index 000000000..37762e696 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/doc.go @@ -0,0 +1,5 @@ +// +k8s:deepcopy-gen=package + +// Package v1alpha1 is the v1alpha1 version of the API. +// +groupName=flyteworkflow.flyte.net +package v1alpha1 diff --git a/pkg/apis/flyteworkflow/v1alpha1/identifier.go b/pkg/apis/flyteworkflow/v1alpha1/identifier.go new file mode 100644 index 000000000..ff102680c --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/identifier.go @@ -0,0 +1,45 @@ +package v1alpha1 + +import ( + "bytes" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type Identifier struct { + *core.Identifier +} + +func (in *Identifier) UnmarshalJSON(b []byte) error { + in.Identifier = &core.Identifier{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.Identifier) +} + +func (in *Identifier) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.Identifier); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (in *Identifier) DeepCopyInto(out *Identifier) { + *out = *in +} + +type WorkflowExecutionIdentifier struct { + *core.WorkflowExecutionIdentifier +} + +func (in *WorkflowExecutionIdentifier) DeepCopyInto(out *WorkflowExecutionIdentifier) { + *out = *in +} + +type TaskExecutionIdentifier struct { + *core.TaskExecutionIdentifier +} + +func (in *TaskExecutionIdentifier) DeepCopyInto(out *TaskExecutionIdentifier) { + *out = *in +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go new file mode 100644 index 000000000..deea5db50 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -0,0 +1,391 @@ +package v1alpha1 + +import ( + "context" + + v1 "k8s.io/api/core/v1" + metav1 
"k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + types2 "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flytestdlib/storage" +) + +// The intention of these interfaces is to decouple the algorithm and usage from the actual CRD definition. +// this would help in ease of changes underneath without affecting the code. + +//go:generate mockery -all + +type CustomState map[string]interface{} +type WorkflowID = string +type TaskID = string +type NodeID = string +type LaunchPlanRefID = Identifier +type ExecutionID = WorkflowExecutionIdentifier + +// NodeKind refers to the type of Node. +type NodeKind string + +func (n NodeKind) String() string { + return string(n) +} + +type DataReference = storage.DataReference + +const ( + // TODO Should we default a NodeKindTask to empty? thus we can assume all unspecified nodetypes as task + NodeKindTask NodeKind = "task" + NodeKindBranch NodeKind = "branch" // A Branch node with conditions + NodeKindWorkflow NodeKind = "workflow" // Either an inline workflow or a remote workflow definition + NodeKindStart NodeKind = "start" // Start node is a special node + NodeKindEnd NodeKind = "end" +) + +// NodePhase indicates the current state of the Node (phase). 
A node progresses through these states +type NodePhase int + +const ( + NodePhaseNotYetStarted NodePhase = iota + NodePhaseQueued + NodePhaseRunning + NodePhaseFailing + NodePhaseSucceeding + NodePhaseSucceeded + NodePhaseFailed + NodePhaseSkipped + NodePhaseRetryableFailure +) + +func (p NodePhase) String() string { + switch p { + case NodePhaseNotYetStarted: + return "NotYetStarted" + case NodePhaseQueued: + return "Queued" + case NodePhaseRunning: + return "Running" + case NodePhaseSucceeding: + return "Succeeding" + case NodePhaseSucceeded: + return "Succeeded" + case NodePhaseFailed: + return "Failed" + case NodePhaseSkipped: + return "Skipped" + case NodePhaseRetryableFailure: + return "RetryableFailure" + } + return "Unknown" +} + +// WorkflowPhase indicates current state of the Workflow. +type WorkflowPhase int + +const ( + WorkflowPhaseReady WorkflowPhase = iota + WorkflowPhaseRunning + WorkflowPhaseSucceeding + WorkflowPhaseSuccess + WorkflowPhaseFailing + WorkflowPhaseFailed + WorkflowPhaseAborted +) + +func (p WorkflowPhase) String() string { + switch p { + case WorkflowPhaseReady: + return "Ready" + case WorkflowPhaseRunning: + return "Running" + case WorkflowPhaseSuccess: + return "Succeeded" + case WorkflowPhaseFailed: + return "Failed" + case WorkflowPhaseFailing: + return "Failing" + case WorkflowPhaseSucceeding: + return "Succeeding" + case WorkflowPhaseAborted: + return "Aborted" + } + return "Unknown" +} + +// A branchNode has its own Phases. 
These are used by the child nodes to ensure that the branch node is in the right state +type BranchNodePhase int + +const ( + BranchNodeNotYetEvaluated BranchNodePhase = iota + BranchNodeSuccess + BranchNodeError +) + +func (b BranchNodePhase) String() string { + switch b { + case BranchNodeNotYetEvaluated: + return "NotYetEvaluated" + case BranchNodeSuccess: + return "BranchEvalSuccess" + case BranchNodeError: + return "BranchEvalFailed" + } + return "Undefined" +} + +// TaskType is a dynamic enumeration, that is defined by configuration +type TaskType = string + +// Interface for a Task that can be executed +type ExecutableTask interface { + TaskType() TaskType + CoreTask() *core.TaskTemplate +} + +// Interface for the executable If block +type ExecutableIfBlock interface { + GetCondition() *core.BooleanExpression + GetThenNode() *NodeID +} + +// Interface for branch node status. This is the mutable API for a branch node +type ExecutableBranchNodeStatus interface { + GetPhase() BranchNodePhase + GetFinalizedNode() *NodeID +} + +type MutableBranchNodeStatus interface { + ExecutableBranchNodeStatus + + SetBranchNodeError() + SetBranchNodeSuccess(id NodeID) +} + +// Interface for dynamic node status. +type ExecutableDynamicNodeStatus interface { + GetDynamicNodePhase() DynamicNodePhase +} + +type MutableDynamicNodeStatus interface { + ExecutableDynamicNodeStatus + + SetDynamicNodePhase(phase DynamicNodePhase) +} + +// Interface for Branch node. All the methods are purely read only except for the GetExecutionStatus. +// Phase returns ExecutableBranchNodeStatus, which permits some mutations +type ExecutableBranchNode interface { + GetIf() ExecutableIfBlock + GetElse() *NodeID + GetElseIf() []ExecutableIfBlock + GetElseFail() *core.Error +} + +type ExecutableWorkflowNodeStatus interface { + // Name of the child execution. We only store name since the project and domain will be + // the same as the parent workflow execution. 
+ GetWorkflowExecutionName() string +} + +type MutableWorkflowNodeStatus interface { + ExecutableWorkflowNodeStatus + + // Sets the name of the child execution. We only store name since the project and domain + // will be the same as the parent workflow execution. + SetWorkflowExecutionName(name string) +} + +type MutableNodeStatus interface { + // Mutation API's + SetDataDir(DataReference) + SetParentNodeID(n *NodeID) + SetParentTaskID(t *core.TaskExecutionIdentifier) + UpdatePhase(phase NodePhase, occurredAt metav1.Time, reason string) + IncrementAttempts() uint32 + SetCached() + ResetDirty() + + GetOrCreateBranchStatus() MutableBranchNodeStatus + GetOrCreateWorkflowStatus() MutableWorkflowNodeStatus + ClearWorkflowStatus() + GetOrCreateTaskStatus() MutableTaskNodeStatus + ClearTaskStatus() + GetOrCreateSubWorkflowStatus() MutableSubWorkflowNodeStatus + ClearSubWorkflowStatus() + GetOrCreateDynamicNodeStatus() MutableDynamicNodeStatus + ClearDynamicNodeStatus() +} + +// Interface for a Node Phase. This provides a mutable API. 
+type ExecutableNodeStatus interface { + NodeStatusGetter + MutableNodeStatus + NodeStatusVisitor + GetPhase() NodePhase + GetQueuedAt() *metav1.Time + GetStoppedAt() *metav1.Time + GetStartedAt() *metav1.Time + GetLastUpdatedAt() *metav1.Time + GetParentNodeID() *NodeID + GetParentTaskID() *core.TaskExecutionIdentifier + GetDataDir() DataReference + GetMessage() string + GetAttempts() uint32 + GetWorkflowNodeStatus() ExecutableWorkflowNodeStatus + GetTaskNodeStatus() ExecutableTaskNodeStatus + GetSubWorkflowNodeStatus() ExecutableSubWorkflowNodeStatus + + IsCached() bool + IsDirty() bool +} + +type ExecutableSubWorkflowNodeStatus interface { + GetPhase() WorkflowPhase +} + +type MutableSubWorkflowNodeStatus interface { + ExecutableSubWorkflowNodeStatus + SetPhase(phase WorkflowPhase) +} + +type ExecutableTaskNodeStatus interface { + GetPhase() types2.TaskPhase + GetPhaseVersion() uint32 + GetCustomState() types2.CustomState +} + +type MutableTaskNodeStatus interface { + ExecutableTaskNodeStatus + SetPhase(phase types2.TaskPhase) + SetPhaseVersion(version uint32) + SetCustomState(state types2.CustomState) +} + +// Interface for a Child Workflow Node +type ExecutableWorkflowNode interface { + GetLaunchPlanRefID() *LaunchPlanRefID + GetSubWorkflowRef() *WorkflowID +} + +type BaseNode interface { + GetID() NodeID + GetKind() NodeKind +} + +// Interface for the Executable Node +type ExecutableNode interface { + BaseNode + IsStartNode() bool + IsEndNode() bool + GetTaskID() *TaskID + GetBranchNode() ExecutableBranchNode + GetWorkflowNode() ExecutableWorkflowNode + GetOutputAlias() []Alias + GetInputBindings() []*Binding + GetResources() *v1.ResourceRequirements + GetConfig() *v1.ConfigMap + GetRetryStrategy() *RetryStrategy +} + +// Interface for the Workflow Phase. 
This is the mutable portion for a Workflow +type ExecutableWorkflowStatus interface { + NodeStatusGetter + UpdatePhase(p WorkflowPhase, msg string) + GetPhase() WorkflowPhase + GetStoppedAt() *metav1.Time + GetStartedAt() *metav1.Time + GetLastUpdatedAt() *metav1.Time + IsTerminated() bool + GetMessage() string + SetDataDir(DataReference) + GetDataDir() DataReference + GetOutputReference() DataReference + SetOutputReference(reference DataReference) + IncFailedAttempts() + SetMessage(msg string) + ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, name NodeID) (storage.DataReference, error) +} + +type BaseWorkflow interface { + StartNode() ExecutableNode + GetID() WorkflowID + // From returns all nodes that can be reached directly + // from the node with the given unique name. + FromNode(name NodeID) ([]NodeID, error) + GetNode(nodeID NodeID) (ExecutableNode, bool) +} + +type BaseWorkflowWithStatus interface { + BaseWorkflow + NodeStatusGetter +} + +// This interface captures the methods available on any workflow (top level or child). 
The Meta section is available +// only for the top level workflow +type ExecutableSubWorkflow interface { + BaseWorkflow + GetOutputBindings() []*Binding + GetOnFailureNode() ExecutableNode + GetNodes() []NodeID + GetConnections() *Connections + GetOutputs() *OutputVarMap +} + +// WorkflowMeta provides an interface to retrieve labels, annotations and other concepts that are declared only once +// for the top level workflow +type WorkflowMeta interface { + GetExecutionID() ExecutionID + GetK8sWorkflowID() types.NamespacedName + NewControllerRef() metav1.OwnerReference + GetNamespace() string + GetCreationTimestamp() metav1.Time + GetAnnotations() map[string]string + GetLabels() map[string]string + GetName() string + GetServiceAccountName() string +} + +type WorkflowMetaExtended interface { + WorkflowMeta + GetTask(id TaskID) (ExecutableTask, error) + FindSubWorkflow(subID WorkflowID) ExecutableSubWorkflow + GetExecutionStatus() ExecutableWorkflowStatus +} + +// A Top level Workflow is a combination of WorkflowMeta and an ExecutableSubWorkflow +type ExecutableWorkflow interface { + ExecutableSubWorkflow + WorkflowMetaExtended + NodeStatusGetter +} + +type NodeStatusGetter interface { + GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus +} + +type NodeStatusMap = map[NodeID]ExecutableNodeStatus + +type NodeStatusVisitFn = func(node NodeID, status ExecutableNodeStatus) + +type NodeStatusVisitor interface { + VisitNodeStatuses(visitor NodeStatusVisitFn) +} + +// Simple callback that can be used to indicate that the workflow with WorkflowID should be re-enqueued for examination. 
+type EnqueueWorkflow func(workflowID WorkflowID) + +func GetOutputsFile(outputDir DataReference) DataReference { + return outputDir + "/outputs.pb" +} + +func GetInputsFile(inputDir DataReference) DataReference { + return inputDir + "/inputs.pb" +} + +func GetOutputErrorFile(inputDir DataReference) DataReference { + return inputDir + "/error.pb" +} + +func GetFutureFile() string { + return "futures.pb" +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseNode.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseNode.go new file mode 100644 index 000000000..a6ed88364 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseNode.go @@ -0,0 +1,39 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// BaseNode is an autogenerated mock type for the BaseNode type +type BaseNode struct { + mock.Mock +} + +// GetID provides a mock function with given fields: +func (_m *BaseNode) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetKind provides a mock function with given fields: +func (_m *BaseNode) GetKind() v1alpha1.NodeKind { + ret := _m.Called() + + var r0 v1alpha1.NodeKind + if rf, ok := ret.Get(0).(func() v1alpha1.NodeKind); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.NodeKind) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go new file mode 100644 index 000000000..5f4dbb259 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go @@ -0,0 +1,87 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// BaseWorkflow is an autogenerated mock type for the BaseWorkflow type +type BaseWorkflow struct { + mock.Mock +} + +// FromNode provides a mock function with given fields: name +func (_m *BaseWorkflow) FromNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetID provides a mock function with given fields: +func (_m *BaseWorkflow) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNode provides a mock function with given fields: nodeID +func (_m *BaseWorkflow) GetNode(nodeID string) (v1alpha1.ExecutableNode, bool) { + ret := _m.Called(nodeID) + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNode); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// StartNode provides a mock function with given fields: +func (_m *BaseWorkflow) StartNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go 
b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go new file mode 100644 index 000000000..fd3375116 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go @@ -0,0 +1,103 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// BaseWorkflowWithStatus is an autogenerated mock type for the BaseWorkflowWithStatus type +type BaseWorkflowWithStatus struct { + mock.Mock +} + +// FromNode provides a mock function with given fields: name +func (_m *BaseWorkflowWithStatus) FromNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetID provides a mock function with given fields: +func (_m *BaseWorkflowWithStatus) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNode provides a mock function with given fields: nodeID +func (_m *BaseWorkflowWithStatus) GetNode(nodeID string) (v1alpha1.ExecutableNode, bool) { + ret := _m.Called(nodeID) + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNode); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNodeExecutionStatus provides a mock function with given fields: id +func (_m *BaseWorkflowWithStatus) GetNodeExecutionStatus(id string) 
v1alpha1.ExecutableNodeStatus { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} + +// StartNode provides a mock function with given fields: +func (_m *BaseWorkflowWithStatus) StartNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNode.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNode.go new file mode 100644 index 000000000..6c2c3948f --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNode.go @@ -0,0 +1,76 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableBranchNode is an autogenerated mock type for the ExecutableBranchNode type +type ExecutableBranchNode struct { + mock.Mock +} + +// GetElse provides a mock function with given fields: +func (_m *ExecutableBranchNode) GetElse() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} + +// GetElseFail provides a mock function with given fields: +func (_m *ExecutableBranchNode) GetElseFail() *core.Error { + ret := _m.Called() + + var r0 *core.Error + if rf, ok := ret.Get(0).(func() *core.Error); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.Error) + } + } + + return r0 +} + +// GetElseIf provides a mock function with given fields: +func (_m *ExecutableBranchNode) GetElseIf() []v1alpha1.ExecutableIfBlock { + ret := _m.Called() + + var r0 []v1alpha1.ExecutableIfBlock + if rf, ok := ret.Get(0).(func() []v1alpha1.ExecutableIfBlock); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]v1alpha1.ExecutableIfBlock) + } + } + + return r0 +} + +// GetIf provides a mock function with given fields: +func (_m *ExecutableBranchNode) GetIf() v1alpha1.ExecutableIfBlock { + ret := _m.Called() + + var r0 v1alpha1.ExecutableIfBlock + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableIfBlock); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableIfBlock) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNodeStatus.go new file mode 100644 index 000000000..a24a93704 --- /dev/null +++ 
b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNodeStatus.go @@ -0,0 +1,41 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableBranchNodeStatus is an autogenerated mock type for the ExecutableBranchNodeStatus type +type ExecutableBranchNodeStatus struct { + mock.Mock +} + +// GetFinalizedNode provides a mock function with given fields: +func (_m *ExecutableBranchNodeStatus) GetFinalizedNode() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *ExecutableBranchNodeStatus) GetPhase() v1alpha1.BranchNodePhase { + ret := _m.Called() + + var r0 v1alpha1.BranchNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.BranchNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.BranchNodePhase) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableDynamicNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableDynamicNodeStatus.go new file mode 100644 index 000000000..fc8819ba3 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableDynamicNodeStatus.go @@ -0,0 +1,25 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableDynamicNodeStatus is an autogenerated mock type for the ExecutableDynamicNodeStatus type +type ExecutableDynamicNodeStatus struct { + mock.Mock +} + +// GetDynamicNodePhase provides a mock function with given fields: +func (_m *ExecutableDynamicNodeStatus) GetDynamicNodePhase() v1alpha1.DynamicNodePhase { + ret := _m.Called() + + var r0 v1alpha1.DynamicNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.DynamicNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.DynamicNodePhase) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableIfBlock.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableIfBlock.go new file mode 100644 index 000000000..7e29c8b37 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableIfBlock.go @@ -0,0 +1,43 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// ExecutableIfBlock is an autogenerated mock type for the ExecutableIfBlock type +type ExecutableIfBlock struct { + mock.Mock +} + +// GetCondition provides a mock function with given fields: +func (_m *ExecutableIfBlock) GetCondition() *core.BooleanExpression { + ret := _m.Called() + + var r0 *core.BooleanExpression + if rf, ok := ret.Get(0).(func() *core.BooleanExpression); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.BooleanExpression) + } + } + + return r0 +} + +// GetThenNode provides a mock function with given fields: +func (_m *ExecutableIfBlock) GetThenNode() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go new file mode 100644 index 000000000..0a7fa9e59 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go @@ -0,0 +1,196 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1 "k8s.io/api/core/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableNode is an autogenerated mock type for the ExecutableNode type +type ExecutableNode struct { + mock.Mock +} + +// GetBranchNode provides a mock function with given fields: +func (_m *ExecutableNode) GetBranchNode() v1alpha1.ExecutableBranchNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableBranchNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableBranchNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableBranchNode) + } + } + + return r0 +} + +// GetConfig provides a mock function with given fields: +func (_m *ExecutableNode) GetConfig() *v1.ConfigMap { + ret := _m.Called() + + var r0 *v1.ConfigMap + if rf, ok := ret.Get(0).(func() *v1.ConfigMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.ConfigMap) + } + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *ExecutableNode) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetInputBindings provides a mock function with given fields: +func (_m *ExecutableNode) GetInputBindings() []*v1alpha1.Binding { + ret := _m.Called() + + var r0 []*v1alpha1.Binding + if rf, ok := ret.Get(0).(func() []*v1alpha1.Binding); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*v1alpha1.Binding) + } + } + + return r0 +} + +// GetKind provides a mock function with given fields: +func (_m *ExecutableNode) GetKind() v1alpha1.NodeKind { + ret := _m.Called() + + var r0 v1alpha1.NodeKind + if rf, ok := ret.Get(0).(func() v1alpha1.NodeKind); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.NodeKind) + } + + return r0 +} + +// GetOutputAlias provides a mock function with 
given fields: +func (_m *ExecutableNode) GetOutputAlias() []v1alpha1.Alias { + ret := _m.Called() + + var r0 []v1alpha1.Alias + if rf, ok := ret.Get(0).(func() []v1alpha1.Alias); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]v1alpha1.Alias) + } + } + + return r0 +} + +// GetResources provides a mock function with given fields: +func (_m *ExecutableNode) GetResources() *v1.ResourceRequirements { + ret := _m.Called() + + var r0 *v1.ResourceRequirements + if rf, ok := ret.Get(0).(func() *v1.ResourceRequirements); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.ResourceRequirements) + } + } + + return r0 +} + +// GetRetryStrategy provides a mock function with given fields: +func (_m *ExecutableNode) GetRetryStrategy() *v1alpha1.RetryStrategy { + ret := _m.Called() + + var r0 *v1alpha1.RetryStrategy + if rf, ok := ret.Get(0).(func() *v1alpha1.RetryStrategy); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.RetryStrategy) + } + } + + return r0 +} + +// GetTaskID provides a mock function with given fields: +func (_m *ExecutableNode) GetTaskID() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} + +// GetWorkflowNode provides a mock function with given fields: +func (_m *ExecutableNode) GetWorkflowNode() v1alpha1.ExecutableWorkflowNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableWorkflowNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableWorkflowNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableWorkflowNode) + } + } + + return r0 +} + +// IsEndNode provides a mock function with given fields: +func (_m *ExecutableNode) IsEndNode() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + 
return r0 +} + +// IsStartNode provides a mock function with given fields: +func (_m *ExecutableNode) IsStartNode() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go new file mode 100644 index 000000000..ee91f4ad9 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go @@ -0,0 +1,407 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" +import storage "github.com/lyft/flytestdlib/storage" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableNodeStatus is an autogenerated mock type for the ExecutableNodeStatus type +type ExecutableNodeStatus struct { + mock.Mock +} + +// ClearDynamicNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearDynamicNodeStatus() { + _m.Called() +} + +// ClearSubWorkflowStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearSubWorkflowStatus() { + _m.Called() +} + +// ClearTaskStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearTaskStatus() { + _m.Called() +} + +// ClearWorkflowStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearWorkflowStatus() { + _m.Called() +} + +// GetAttempts provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetAttempts() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// GetDataDir provides a mock function with given fields: +func (_m 
*ExecutableNodeStatus) GetDataDir() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + +// GetLastUpdatedAt provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetLastUpdatedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetMessage provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetMessage() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNodeExecutionStatus provides a mock function with given fields: id +func (_m *ExecutableNodeStatus) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} + +// GetOrCreateBranchStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateBranchStatus() v1alpha1.MutableBranchNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableBranchNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableBranchNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableBranchNodeStatus) + } + } + + return r0 +} + +// GetOrCreateDynamicNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateDynamicNodeStatus() v1alpha1.MutableDynamicNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableDynamicNodeStatus + if rf, ok := 
ret.Get(0).(func() v1alpha1.MutableDynamicNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableDynamicNodeStatus) + } + } + + return r0 +} + +// GetOrCreateSubWorkflowStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateSubWorkflowStatus() v1alpha1.MutableSubWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableSubWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableSubWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableSubWorkflowNodeStatus) + } + } + + return r0 +} + +// GetOrCreateTaskStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateTaskStatus() v1alpha1.MutableTaskNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableTaskNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableTaskNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableTaskNodeStatus) + } + } + + return r0 +} + +// GetOrCreateWorkflowStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateWorkflowStatus() v1alpha1.MutableWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableWorkflowNodeStatus) + } + } + + return r0 +} + +// GetParentNodeID provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetParentNodeID() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} + +// GetParentTaskID provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetParentTaskID() *core.TaskExecutionIdentifier { + ret := 
_m.Called() + + var r0 *core.TaskExecutionIdentifier + if rf, ok := ret.Get(0).(func() *core.TaskExecutionIdentifier); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TaskExecutionIdentifier) + } + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetPhase() v1alpha1.NodePhase { + ret := _m.Called() + + var r0 v1alpha1.NodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.NodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.NodePhase) + } + + return r0 +} + +// GetQueuedAt provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetQueuedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetStartedAt provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetStartedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetStoppedAt provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetStoppedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetSubWorkflowNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetSubWorkflowNodeStatus() v1alpha1.ExecutableSubWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableSubWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableSubWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableSubWorkflowNodeStatus) + } + } + + return r0 +} + +// GetTaskNodeStatus 
provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetTaskNodeStatus() v1alpha1.ExecutableTaskNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableTaskNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableTaskNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableTaskNodeStatus) + } + } + + return r0 +} + +// GetWorkflowNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetWorkflowNodeStatus() v1alpha1.ExecutableWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableWorkflowNodeStatus) + } + } + + return r0 +} + +// IncrementAttempts provides a mock function with given fields: +func (_m *ExecutableNodeStatus) IncrementAttempts() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// IsCached provides a mock function with given fields: +func (_m *ExecutableNodeStatus) IsCached() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// IsDirty provides a mock function with given fields: +func (_m *ExecutableNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// ResetDirty provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ResetDirty() { + _m.Called() +} + +// SetCached provides a mock function with given fields: +func (_m *ExecutableNodeStatus) SetCached() { + _m.Called() +} + +// SetDataDir provides a mock function with given fields: _a0 +func (_m 
*ExecutableNodeStatus) SetDataDir(_a0 storage.DataReference) { + _m.Called(_a0) +} + +// SetParentNodeID provides a mock function with given fields: n +func (_m *ExecutableNodeStatus) SetParentNodeID(n *string) { + _m.Called(n) +} + +// SetParentTaskID provides a mock function with given fields: t +func (_m *ExecutableNodeStatus) SetParentTaskID(t *core.TaskExecutionIdentifier) { + _m.Called(t) +} + +// UpdatePhase provides a mock function with given fields: phase, occurredAt, reason +func (_m *ExecutableNodeStatus) UpdatePhase(phase v1alpha1.NodePhase, occurredAt v1.Time, reason string) { + _m.Called(phase, occurredAt, reason) +} + +// VisitNodeStatuses provides a mock function with given fields: visitor +func (_m *ExecutableNodeStatus) VisitNodeStatuses(visitor func(string, v1alpha1.ExecutableNodeStatus)) { + _m.Called(visitor) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go new file mode 100644 index 000000000..132991edb --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go @@ -0,0 +1,167 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableSubWorkflow is an autogenerated mock type for the ExecutableSubWorkflow type +type ExecutableSubWorkflow struct { + mock.Mock +} + +// FromNode provides a mock function with given fields: name +func (_m *ExecutableSubWorkflow) FromNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetConnections provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetConnections() *v1alpha1.Connections { + ret := _m.Called() + + var r0 *v1alpha1.Connections + if rf, ok := ret.Get(0).(func() *v1alpha1.Connections); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.Connections) + } + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNode provides a mock function with given fields: nodeID +func (_m *ExecutableSubWorkflow) GetNode(nodeID string) (v1alpha1.ExecutableNode, bool) { + ret := _m.Called(nodeID) + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNode); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNodes provides a mock function with given 
fields: +func (_m *ExecutableSubWorkflow) GetNodes() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// GetOnFailureNode provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetOnFailureNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} + +// GetOutputBindings provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetOutputBindings() []*v1alpha1.Binding { + ret := _m.Called() + + var r0 []*v1alpha1.Binding + if rf, ok := ret.Get(0).(func() []*v1alpha1.Binding); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*v1alpha1.Binding) + } + } + + return r0 +} + +// GetOutputs provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetOutputs() *v1alpha1.OutputVarMap { + ret := _m.Called() + + var r0 *v1alpha1.OutputVarMap + if rf, ok := ret.Get(0).(func() *v1alpha1.OutputVarMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.OutputVarMap) + } + } + + return r0 +} + +// StartNode provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) StartNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflowNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflowNodeStatus.go new file mode 100644 index 000000000..c90dc4598 --- /dev/null +++ 
b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflowNodeStatus.go @@ -0,0 +1,25 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableSubWorkflowNodeStatus is an autogenerated mock type for the ExecutableSubWorkflowNodeStatus type +type ExecutableSubWorkflowNodeStatus struct { + mock.Mock +} + +// GetPhase provides a mock function with given fields: +func (_m *ExecutableSubWorkflowNodeStatus) GetPhase() v1alpha1.WorkflowPhase { + ret := _m.Called() + + var r0 v1alpha1.WorkflowPhase + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowPhase) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTask.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTask.go new file mode 100644 index 000000000..61700e7ec --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTask.go @@ -0,0 +1,41 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// ExecutableTask is an autogenerated mock type for the ExecutableTask type +type ExecutableTask struct { + mock.Mock +} + +// CoreTask provides a mock function with given fields: +func (_m *ExecutableTask) CoreTask() *core.TaskTemplate { + ret := _m.Called() + + var r0 *core.TaskTemplate + if rf, ok := ret.Get(0).(func() *core.TaskTemplate); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TaskTemplate) + } + } + + return r0 +} + +// TaskType provides a mock function with given fields: +func (_m *ExecutableTask) TaskType() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go new file mode 100644 index 000000000..c1c2c7e25 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go @@ -0,0 +1,55 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import types "github.com/lyft/flyteplugins/go/tasks/v1/types" + +// ExecutableTaskNodeStatus is an autogenerated mock type for the ExecutableTaskNodeStatus type +type ExecutableTaskNodeStatus struct { + mock.Mock +} + +// GetCustomState provides a mock function with given fields: +func (_m *ExecutableTaskNodeStatus) GetCustomState() map[string]interface{} { + ret := _m.Called() + + var r0 map[string]interface{} + if rf, ok := ret.Get(0).(func() map[string]interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *ExecutableTaskNodeStatus) GetPhase() types.TaskPhase { + ret := _m.Called() + + var r0 types.TaskPhase + if rf, ok := ret.Get(0).(func() types.TaskPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.TaskPhase) + } + + return r0 +} + +// GetPhaseVersion provides a mock function with given fields: +func (_m *ExecutableTaskNodeStatus) GetPhaseVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go new file mode 100644 index 000000000..d14b3d0f9 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go @@ -0,0 +1,370 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import types "k8s.io/apimachinery/pkg/types" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableWorkflow is an autogenerated mock type for the ExecutableWorkflow type +type ExecutableWorkflow struct { + mock.Mock +} + +// FindSubWorkflow provides a mock function with given fields: subID +func (_m *ExecutableWorkflow) FindSubWorkflow(subID string) v1alpha1.ExecutableSubWorkflow { + ret := _m.Called(subID) + + var r0 v1alpha1.ExecutableSubWorkflow + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableSubWorkflow); ok { + r0 = rf(subID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableSubWorkflow) + } + } + + return r0 +} + +// FromNode provides a mock function with given fields: name +func (_m *ExecutableWorkflow) FromNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAnnotations provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetConnections provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetConnections() *v1alpha1.Connections { + ret := _m.Called() + + var r0 *v1alpha1.Connections + if rf, ok := ret.Get(0).(func() *v1alpha1.Connections); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.Connections) + } + } + + 
return r0 +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +// GetExecutionID provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { + ret := _m.Called() + + var r0 v1alpha1.WorkflowExecutionIdentifier + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) + } + + return r0 +} + +// GetExecutionStatus provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetExecutionStatus() v1alpha1.ExecutableWorkflowStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableWorkflowStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableWorkflowStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableWorkflowStatus) + } + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetK8sWorkflowID provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetK8sWorkflowID() types.NamespacedName { + ret := _m.Called() + + var r0 types.NamespacedName + if rf, ok := ret.Get(0).(func() types.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.NamespacedName) + } + + return r0 +} + +// GetLabels provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if 
ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetName provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNamespace provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNode provides a mock function with given fields: nodeID +func (_m *ExecutableWorkflow) GetNode(nodeID string) (v1alpha1.ExecutableNode, bool) { + ret := _m.Called(nodeID) + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNode); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNodeExecutionStatus provides a mock function with given fields: id +func (_m *ExecutableWorkflow) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} + +// GetNodes provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetNodes() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// GetOnFailureNode provides a mock function with given fields: +func (_m 
*ExecutableWorkflow) GetOnFailureNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} + +// GetOutputBindings provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetOutputBindings() []*v1alpha1.Binding { + ret := _m.Called() + + var r0 []*v1alpha1.Binding + if rf, ok := ret.Get(0).(func() []*v1alpha1.Binding); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*v1alpha1.Binding) + } + } + + return r0 +} + +// GetOutputs provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetOutputs() *v1alpha1.OutputVarMap { + ret := _m.Called() + + var r0 *v1alpha1.OutputVarMap + if rf, ok := ret.Get(0).(func() *v1alpha1.OutputVarMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.OutputVarMap) + } + } + + return r0 +} + +// GetServiceAccountName provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetServiceAccountName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetTask provides a mock function with given fields: id +func (_m *ExecutableWorkflow) GetTask(id string) (v1alpha1.ExecutableTask, error) { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableTask + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableTask); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableTask) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewControllerRef provides a mock function with given fields: +func (_m *ExecutableWorkflow) NewControllerRef() v1.OwnerReference { + ret := 
_m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} + +// StartNode provides a mock function with given fields: +func (_m *ExecutableWorkflow) StartNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNode.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNode.go new file mode 100644 index 000000000..3cf799a20 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNode.go @@ -0,0 +1,43 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableWorkflowNode is an autogenerated mock type for the ExecutableWorkflowNode type +type ExecutableWorkflowNode struct { + mock.Mock +} + +// GetLaunchPlanRefID provides a mock function with given fields: +func (_m *ExecutableWorkflowNode) GetLaunchPlanRefID() *v1alpha1.Identifier { + ret := _m.Called() + + var r0 *v1alpha1.Identifier + if rf, ok := ret.Get(0).(func() *v1alpha1.Identifier); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.Identifier) + } + } + + return r0 +} + +// GetSubWorkflowRef provides a mock function with given fields: +func (_m *ExecutableWorkflowNode) GetSubWorkflowRef() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNodeStatus.go 
b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNodeStatus.go new file mode 100644 index 000000000..fbf954421 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNodeStatus.go @@ -0,0 +1,24 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// ExecutableWorkflowNodeStatus is an autogenerated mock type for the ExecutableWorkflowNodeStatus type +type ExecutableWorkflowNodeStatus struct { + mock.Mock +} + +// GetWorkflowExecutionName provides a mock function with given fields: +func (_m *ExecutableWorkflowNodeStatus) GetWorkflowExecutionName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go new file mode 100644 index 000000000..8de18b32c --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go @@ -0,0 +1,194 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import context "context" +import mock "github.com/stretchr/testify/mock" +import storage "github.com/lyft/flytestdlib/storage" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableWorkflowStatus is an autogenerated mock type for the ExecutableWorkflowStatus type +type ExecutableWorkflowStatus struct { + mock.Mock +} + +// ConstructNodeDataDir provides a mock function with given fields: ctx, constructor, name +func (_m *ExecutableWorkflowStatus) ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, name string) (storage.DataReference, error) { + ret := _m.Called(ctx, constructor, name) + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func(context.Context, storage.ReferenceConstructor, string) storage.DataReference); ok { + r0 = rf(ctx, constructor, name) + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, storage.ReferenceConstructor, string) error); ok { + r1 = rf(ctx, constructor, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetDataDir provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetDataDir() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + +// GetLastUpdatedAt provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetLastUpdatedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetMessage provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetMessage() string { + ret := _m.Called() + + var 
r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNodeExecutionStatus provides a mock function with given fields: id +func (_m *ExecutableWorkflowStatus) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} + +// GetOutputReference provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetOutputReference() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetPhase() v1alpha1.WorkflowPhase { + ret := _m.Called() + + var r0 v1alpha1.WorkflowPhase + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowPhase) + } + + return r0 +} + +// GetStartedAt provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetStartedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetStoppedAt provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetStoppedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// IncFailedAttempts provides a mock function with given fields: +func (_m 
*ExecutableWorkflowStatus) IncFailedAttempts() { + _m.Called() +} + +// IsTerminated provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) IsTerminated() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// SetDataDir provides a mock function with given fields: _a0 +func (_m *ExecutableWorkflowStatus) SetDataDir(_a0 storage.DataReference) { + _m.Called(_a0) +} + +// SetMessage provides a mock function with given fields: msg +func (_m *ExecutableWorkflowStatus) SetMessage(msg string) { + _m.Called(msg) +} + +// SetOutputReference provides a mock function with given fields: reference +func (_m *ExecutableWorkflowStatus) SetOutputReference(reference storage.DataReference) { + _m.Called(reference) +} + +// UpdatePhase provides a mock function with given fields: p, msg +func (_m *ExecutableWorkflowStatus) UpdatePhase(p v1alpha1.WorkflowPhase, msg string) { + _m.Called(p, msg) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go new file mode 100644 index 000000000..fcf090d22 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go @@ -0,0 +1,51 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// MutableBranchNodeStatus is an autogenerated mock type for the MutableBranchNodeStatus type +type MutableBranchNodeStatus struct { + mock.Mock +} + +// GetFinalizedNode provides a mock function with given fields: +func (_m *MutableBranchNodeStatus) GetFinalizedNode() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *MutableBranchNodeStatus) GetPhase() v1alpha1.BranchNodePhase { + ret := _m.Called() + + var r0 v1alpha1.BranchNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.BranchNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.BranchNodePhase) + } + + return r0 +} + +// SetBranchNodeError provides a mock function with given fields: +func (_m *MutableBranchNodeStatus) SetBranchNodeError() { + _m.Called() +} + +// SetBranchNodeSuccess provides a mock function with given fields: id +func (_m *MutableBranchNodeStatus) SetBranchNodeSuccess(id string) { + _m.Called(id) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go new file mode 100644 index 000000000..0208256ad --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go @@ -0,0 +1,30 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// MutableDynamicNodeStatus is an autogenerated mock type for the MutableDynamicNodeStatus type +type MutableDynamicNodeStatus struct { + mock.Mock +} + +// GetDynamicNodePhase provides a mock function with given fields: +func (_m *MutableDynamicNodeStatus) GetDynamicNodePhase() v1alpha1.DynamicNodePhase { + ret := _m.Called() + + var r0 v1alpha1.DynamicNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.DynamicNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.DynamicNodePhase) + } + + return r0 +} + +// SetDynamicNodePhase provides a mock function with given fields: phase +func (_m *MutableDynamicNodeStatus) SetDynamicNodePhase(phase v1alpha1.DynamicNodePhase) { + _m.Called(phase) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go new file mode 100644 index 000000000..0ef3f0b33 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go @@ -0,0 +1,158 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" +import storage "github.com/lyft/flytestdlib/storage" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// MutableNodeStatus is an autogenerated mock type for the MutableNodeStatus type +type MutableNodeStatus struct { + mock.Mock +} + +// ClearDynamicNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearDynamicNodeStatus() { + _m.Called() +} + +// ClearSubWorkflowStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearSubWorkflowStatus() { + _m.Called() +} + +// ClearTaskStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearTaskStatus() { + _m.Called() +} + +// ClearWorkflowStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearWorkflowStatus() { + _m.Called() +} + +// GetOrCreateBranchStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateBranchStatus() v1alpha1.MutableBranchNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableBranchNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableBranchNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableBranchNodeStatus) + } + } + + return r0 +} + +// GetOrCreateDynamicNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateDynamicNodeStatus() v1alpha1.MutableDynamicNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableDynamicNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableDynamicNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableDynamicNodeStatus) + } + } + + return r0 +} + +// GetOrCreateSubWorkflowStatus provides a mock function with given fields: +func (_m 
*MutableNodeStatus) GetOrCreateSubWorkflowStatus() v1alpha1.MutableSubWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableSubWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableSubWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableSubWorkflowNodeStatus) + } + } + + return r0 +} + +// GetOrCreateTaskStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateTaskStatus() v1alpha1.MutableTaskNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableTaskNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableTaskNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableTaskNodeStatus) + } + } + + return r0 +} + +// GetOrCreateWorkflowStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateWorkflowStatus() v1alpha1.MutableWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableWorkflowNodeStatus) + } + } + + return r0 +} + +// IncrementAttempts provides a mock function with given fields: +func (_m *MutableNodeStatus) IncrementAttempts() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// ResetDirty provides a mock function with given fields: +func (_m *MutableNodeStatus) ResetDirty() { + _m.Called() +} + +// SetCached provides a mock function with given fields: +func (_m *MutableNodeStatus) SetCached() { + _m.Called() +} + +// SetDataDir provides a mock function with given fields: _a0 +func (_m *MutableNodeStatus) SetDataDir(_a0 storage.DataReference) { + _m.Called(_a0) +} + +// SetParentNodeID provides a mock function with given fields: n +func (_m 
*MutableNodeStatus) SetParentNodeID(n *string) { + _m.Called(n) +} + +// SetParentTaskID provides a mock function with given fields: t +func (_m *MutableNodeStatus) SetParentTaskID(t *core.TaskExecutionIdentifier) { + _m.Called(t) +} + +// UpdatePhase provides a mock function with given fields: phase, occurredAt, reason +func (_m *MutableNodeStatus) UpdatePhase(phase v1alpha1.NodePhase, occurredAt v1.Time, reason string) { + _m.Called(phase, occurredAt, reason) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go new file mode 100644 index 000000000..b194a6373 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go @@ -0,0 +1,30 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// MutableSubWorkflowNodeStatus is an autogenerated mock type for the MutableSubWorkflowNodeStatus type +type MutableSubWorkflowNodeStatus struct { + mock.Mock +} + +// GetPhase provides a mock function with given fields: +func (_m *MutableSubWorkflowNodeStatus) GetPhase() v1alpha1.WorkflowPhase { + ret := _m.Called() + + var r0 v1alpha1.WorkflowPhase + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowPhase) + } + + return r0 +} + +// SetPhase provides a mock function with given fields: phase +func (_m *MutableSubWorkflowNodeStatus) SetPhase(phase v1alpha1.WorkflowPhase) { + _m.Called(phase) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go new file mode 100644 index 000000000..adade901a --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go @@ -0,0 +1,70 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import types "github.com/lyft/flyteplugins/go/tasks/v1/types" + +// MutableTaskNodeStatus is an autogenerated mock type for the MutableTaskNodeStatus type +type MutableTaskNodeStatus struct { + mock.Mock +} + +// GetCustomState provides a mock function with given fields: +func (_m *MutableTaskNodeStatus) GetCustomState() map[string]interface{} { + ret := _m.Called() + + var r0 map[string]interface{} + if rf, ok := ret.Get(0).(func() map[string]interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *MutableTaskNodeStatus) GetPhase() types.TaskPhase { + ret := _m.Called() + + var r0 types.TaskPhase + if rf, ok := ret.Get(0).(func() types.TaskPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.TaskPhase) + } + + return r0 +} + +// GetPhaseVersion provides a mock function with given fields: +func (_m *MutableTaskNodeStatus) GetPhaseVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// SetCustomState provides a mock function with given fields: state +func (_m *MutableTaskNodeStatus) SetCustomState(state map[string]interface{}) { + _m.Called(state) +} + +// SetPhase provides a mock function with given fields: phase +func (_m *MutableTaskNodeStatus) SetPhase(phase types.TaskPhase) { + _m.Called(phase) +} + +// SetPhaseVersion provides a mock function with given fields: version +func (_m *MutableTaskNodeStatus) SetPhaseVersion(version uint32) { + _m.Called(version) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go new file mode 100644 index 000000000..a6dc4c8c9 --- /dev/null +++ 
b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go @@ -0,0 +1,29 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// MutableWorkflowNodeStatus is an autogenerated mock type for the MutableWorkflowNodeStatus type +type MutableWorkflowNodeStatus struct { + mock.Mock +} + +// GetWorkflowExecutionName provides a mock function with given fields: +func (_m *MutableWorkflowNodeStatus) GetWorkflowExecutionName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// SetWorkflowExecutionName provides a mock function with given fields: name +func (_m *MutableWorkflowNodeStatus) SetWorkflowExecutionName(name string) { + _m.Called(name) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go new file mode 100644 index 000000000..fff448c48 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go @@ -0,0 +1,27 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// NodeStatusGetter is an autogenerated mock type for the NodeStatusGetter type +type NodeStatusGetter struct { + mock.Mock +} + +// GetNodeExecutionStatus provides a mock function with given fields: id +func (_m *NodeStatusGetter) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusVisitor.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusVisitor.go new file mode 100644 index 000000000..fc30a240b --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusVisitor.go @@ -0,0 +1,16 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// NodeStatusVisitor is an autogenerated mock type for the NodeStatusVisitor type +type NodeStatusVisitor struct { + mock.Mock +} + +// VisitNodeStatuses provides a mock function with given fields: visitor +func (_m *NodeStatusVisitor) VisitNodeStatuses(visitor func(string, v1alpha1.ExecutableNodeStatus)) { + _m.Called(visitor) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go new file mode 100644 index 000000000..6149479c4 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go @@ -0,0 +1,143 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import types "k8s.io/apimachinery/pkg/types" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// WorkflowMeta is an autogenerated mock type for the WorkflowMeta type +type WorkflowMeta struct { + mock.Mock +} + +// GetAnnotations provides a mock function with given fields: +func (_m *WorkflowMeta) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *WorkflowMeta) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +// GetExecutionID provides a mock function with given fields: +func (_m *WorkflowMeta) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { + ret := _m.Called() + + var r0 v1alpha1.WorkflowExecutionIdentifier + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) + } + + return r0 +} + +// GetK8sWorkflowID provides a mock function with given fields: +func (_m *WorkflowMeta) GetK8sWorkflowID() types.NamespacedName { + ret := _m.Called() + + var r0 types.NamespacedName + if rf, ok := ret.Get(0).(func() types.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.NamespacedName) + } + + return r0 +} + +// GetLabels provides a mock function with given fields: +func (_m *WorkflowMeta) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil 
{ + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetName provides a mock function with given fields: +func (_m *WorkflowMeta) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNamespace provides a mock function with given fields: +func (_m *WorkflowMeta) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetServiceAccountName provides a mock function with given fields: +func (_m *WorkflowMeta) GetServiceAccountName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// NewControllerRef provides a mock function with given fields: +func (_m *WorkflowMeta) NewControllerRef() v1.OwnerReference { + ret := _m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go new file mode 100644 index 000000000..a9d69c6b7 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go @@ -0,0 +1,198 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import types "k8s.io/apimachinery/pkg/types" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// WorkflowMetaExtended is an autogenerated mock type for the WorkflowMetaExtended type +type WorkflowMetaExtended struct { + mock.Mock +} + +// FindSubWorkflow provides a mock function with given fields: subID +func (_m *WorkflowMetaExtended) FindSubWorkflow(subID string) v1alpha1.ExecutableSubWorkflow { + ret := _m.Called(subID) + + var r0 v1alpha1.ExecutableSubWorkflow + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableSubWorkflow); ok { + r0 = rf(subID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableSubWorkflow) + } + } + + return r0 +} + +// GetAnnotations provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +// GetExecutionID provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { + ret := _m.Called() + + var r0 v1alpha1.WorkflowExecutionIdentifier + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) + } + + return r0 +} + +// GetExecutionStatus provides a mock function with given fields: +func (_m *WorkflowMetaExtended) 
GetExecutionStatus() v1alpha1.ExecutableWorkflowStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableWorkflowStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableWorkflowStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableWorkflowStatus) + } + } + + return r0 +} + +// GetK8sWorkflowID provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetK8sWorkflowID() types.NamespacedName { + ret := _m.Called() + + var r0 types.NamespacedName + if rf, ok := ret.Get(0).(func() types.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.NamespacedName) + } + + return r0 +} + +// GetLabels provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetName provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNamespace provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetServiceAccountName provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetServiceAccountName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetTask provides a mock function with given fields: id +func (_m *WorkflowMetaExtended) GetTask(id string) (v1alpha1.ExecutableTask, error) 
{ + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableTask + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableTask); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableTask) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewControllerRef provides a mock function with given fields: +func (_m *WorkflowMetaExtended) NewControllerRef() v1.OwnerReference { + ret := _m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go new file mode 100644 index 000000000..a02c99c94 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -0,0 +1,512 @@ +package v1alpha1 + +import ( + "encoding/json" + "reflect" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type BranchNodeStatus struct { + Phase BranchNodePhase `json:"phase"` + FinalizedNodeID *NodeID `json:"finalNodeId"` +} + +func (in *BranchNodeStatus) GetPhase() BranchNodePhase { + return in.Phase +} + +func (in *BranchNodeStatus) SetBranchNodeError() { + in.Phase = BranchNodeError +} + +func (in *BranchNodeStatus) SetBranchNodeSuccess(id NodeID) { + in.Phase = BranchNodeSuccess + in.FinalizedNodeID = &id +} + +func (in *BranchNodeStatus) GetFinalizedNode() *NodeID { + return in.FinalizedNodeID +} + +func (in *BranchNodeStatus) Equals(other *BranchNodeStatus) bool { + if in == nil && other == nil { + return true + } + if in != nil && other != nil { + phaseEqual := in.Phase == other.Phase + if phaseEqual { + if in.FinalizedNodeID == nil && other.FinalizedNodeID == nil { + return true + } + if in.FinalizedNodeID != nil && 
other.FinalizedNodeID != nil { + return *in.FinalizedNodeID == *other.FinalizedNodeID + } + return false + } + return false + } + return false +} + +type DynamicNodePhase int + +const ( + DynamicNodePhaseNone DynamicNodePhase = iota + DynamicNodePhaseExecuting +) + +type DynamicNodeStatus struct { + Phase DynamicNodePhase `json:"phase"` +} + +func (s *DynamicNodeStatus) GetDynamicNodePhase() DynamicNodePhase { + return s.Phase +} + +func (s *DynamicNodeStatus) SetDynamicNodePhase(phase DynamicNodePhase) { + s.Phase = phase +} + +func (s *DynamicNodeStatus) Equals(o *DynamicNodeStatus) bool { + if s == nil && o == nil { + return true + } + if s != nil && o != nil { + return s.Phase == o.Phase + } + return false +} + +type SubWorkflowNodeStatus struct { + Phase WorkflowPhase `json:"phase"` +} + +func (s SubWorkflowNodeStatus) GetPhase() WorkflowPhase { + return s.Phase +} + +func (s *SubWorkflowNodeStatus) SetPhase(phase WorkflowPhase) { + s.Phase = phase +} + +type WorkflowNodeStatus struct { + WorkflowName string `json:"name"` +} + +func (in *WorkflowNodeStatus) SetWorkflowExecutionName(name string) { + in.WorkflowName = name +} + +func (in *WorkflowNodeStatus) GetWorkflowExecutionName() string { + return in.WorkflowName +} + +type NodeStatus struct { + Phase NodePhase `json:"phase"` + QueuedAt *metav1.Time `json:"queuedAt,omitempty"` + StartedAt *metav1.Time `json:"startedAt,omitempty"` + StoppedAt *metav1.Time `json:"stoppedAt,omitempty"` + LastUpdatedAt *metav1.Time `json:"lastUpdatedAt,omitempty"` + Message string `json:"message,omitempty"` + DataDir DataReference `json:"dataDir,omitempty"` + Attempts uint32 `json:"attempts"` + Cached bool `json:"cached"` + dirty bool + // This is useful only for branch nodes. 
If this is set, then it can be used to determine if execution can proceed + ParentNode *NodeID `json:"parentNode,omitempty"` + ParentTask *TaskExecutionIdentifier `json:"parentTask,omitempty"` + BranchStatus *BranchNodeStatus `json:"branchStatus,omitempty"` + SubNodeStatus map[NodeID]*NodeStatus `json:"subNodeStatus,omitempty"` + // We can store the outputs at this layer + + WorkflowNodeStatus *WorkflowNodeStatus `json:"workflowNodeStatus,omitempty"` + TaskNodeStatus *TaskNodeStatus `json:",omitempty"` + SubWorkflowNodeStatus *SubWorkflowNodeStatus `json:"subWorkflowStatus,omitempty"` + DynamicNodeStatus *DynamicNodeStatus `json:"dynamicNodeStatus,omitempty"` +} + +func (in NodeStatus) VisitNodeStatuses(visitor NodeStatusVisitFn) { + for n, s := range in.SubNodeStatus { + visitor(n, s) + } +} + +func (in *NodeStatus) ClearWorkflowStatus() { + in.WorkflowNodeStatus = nil +} + +func (in *NodeStatus) ClearTaskStatus() { + in.TaskNodeStatus = nil +} + +func (in *NodeStatus) GetLastUpdatedAt() *metav1.Time { + return in.LastUpdatedAt +} + +func (in *NodeStatus) GetAttempts() uint32 { + return in.Attempts +} + +func (in *NodeStatus) SetCached() { + in.Cached = true + in.setDirty() +} + +func (in *NodeStatus) setDirty() { + in.dirty = true +} +func (in *NodeStatus) IsCached() bool { + return in.Cached +} + +func (in *NodeStatus) IsDirty() bool { + return in.dirty +} + +// ResetDirty is for unit tests, shouldn't be used in actual logic. 
+func (in *NodeStatus) ResetDirty() { + in.dirty = false +} + +func (in *NodeStatus) IncrementAttempts() uint32 { + in.Attempts++ + in.setDirty() + return in.Attempts +} + +func (in *NodeStatus) GetOrCreateDynamicNodeStatus() MutableDynamicNodeStatus { + if in.DynamicNodeStatus == nil { + in.setDirty() + in.DynamicNodeStatus = &DynamicNodeStatus{} + } + + return in.DynamicNodeStatus +} + +func (in *NodeStatus) ClearDynamicNodeStatus() { + in.DynamicNodeStatus = nil +} + +func (in *NodeStatus) GetOrCreateBranchStatus() MutableBranchNodeStatus { + if in.BranchStatus == nil { + in.BranchStatus = &BranchNodeStatus{} + } + + in.setDirty() + return in.BranchStatus +} + +func (in *NodeStatus) GetWorkflowNodeStatus() ExecutableWorkflowNodeStatus { + if in.WorkflowNodeStatus == nil { + return nil + } + + in.setDirty() + return in.WorkflowNodeStatus +} + +func (in *NodeStatus) GetPhase() NodePhase { + return in.Phase +} + +func (in *NodeStatus) GetMessage() string { + return in.Message +} + +func IsPhaseTerminal(phase NodePhase) bool { + return phase == NodePhaseSucceeded || phase == NodePhaseFailed || phase == NodePhaseSkipped +} + +func (in *NodeStatus) GetOrCreateTaskStatus() MutableTaskNodeStatus { + if in.TaskNodeStatus == nil { + in.TaskNodeStatus = &TaskNodeStatus{} + } + + in.setDirty() + return in.TaskNodeStatus +} + +func (in *NodeStatus) UpdatePhase(p NodePhase, occurredAt metav1.Time, reason string) { + if in.Phase == p { + // We will not update the phase multiple times. 
This prevents the comparison from returning false positive + return + } + + in.Phase = p + in.Message = reason + if len(reason) > maxMessageSize { + in.Message = reason[:maxMessageSize] + } + + n := occurredAt + if occurredAt.IsZero() { + n = metav1.Now() + } + + if p == NodePhaseQueued && in.QueuedAt == nil { + in.QueuedAt = &n + } else if p == NodePhaseRunning && in.StartedAt == nil { + in.StartedAt = &n + } else if IsPhaseTerminal(p) && in.StoppedAt == nil { + if in.StartedAt == nil { + in.StartedAt = &n + } + + in.StoppedAt = &n + } + + if in.Phase != p { + in.LastUpdatedAt = &n + } + + in.setDirty() +} + +func (in *NodeStatus) GetStartedAt() *metav1.Time { + return in.StartedAt +} + +func (in *NodeStatus) GetStoppedAt() *metav1.Time { + return in.StoppedAt +} + +func (in *NodeStatus) GetQueuedAt() *metav1.Time { + return in.QueuedAt +} + +func (in *NodeStatus) GetParentNodeID() *NodeID { + return in.ParentNode +} + +func (in *NodeStatus) GetParentTaskID() *core.TaskExecutionIdentifier { + if in.ParentTask != nil { + return in.ParentTask.TaskExecutionIdentifier + } + return nil +} + +func (in *NodeStatus) SetParentNodeID(n *NodeID) { + in.ParentNode = n + in.setDirty() +} + +func (in *NodeStatus) SetParentTaskID(t *core.TaskExecutionIdentifier) { + in.ParentTask = &TaskExecutionIdentifier{ + TaskExecutionIdentifier: t, + } + in.setDirty() +} + +func (in *NodeStatus) GetOrCreateWorkflowStatus() MutableWorkflowNodeStatus { + if in.WorkflowNodeStatus == nil { + in.WorkflowNodeStatus = &WorkflowNodeStatus{} + } + + in.setDirty() + return in.WorkflowNodeStatus +} + +func (in NodeStatus) GetTaskNodeStatus() ExecutableTaskNodeStatus { + // Explicitly return nil here to avoid a misleading non-nil interface. 
+ if in.TaskNodeStatus == nil { + return nil + } + + return in.TaskNodeStatus +} + +func (in NodeStatus) GetSubWorkflowNodeStatus() ExecutableSubWorkflowNodeStatus { + if in.SubWorkflowNodeStatus == nil { + return nil + } + + return in.SubWorkflowNodeStatus +} + +func (in NodeStatus) GetOrCreateSubWorkflowStatus() MutableSubWorkflowNodeStatus { + if in.SubWorkflowNodeStatus == nil { + in.SubWorkflowNodeStatus = &SubWorkflowNodeStatus{} + } + + return in.SubWorkflowNodeStatus +} + +func (in *NodeStatus) ClearSubWorkflowStatus() { + in.SubWorkflowNodeStatus = nil +} + +func (in *NodeStatus) GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus { + n, ok := in.SubNodeStatus[id] + if ok { + return n + } + if in.SubNodeStatus == nil { + in.SubNodeStatus = make(map[NodeID]*NodeStatus) + } + newNodeStatus := &NodeStatus{} + newNodeStatus.SetParentTaskID(in.GetParentTaskID()) + newNodeStatus.SetParentNodeID(in.GetParentNodeID()) + + in.SubNodeStatus[id] = newNodeStatus + return newNodeStatus +} + +func (in *NodeStatus) IsTerminated() bool { + return in.GetPhase() == NodePhaseFailed || in.GetPhase() == NodePhaseSkipped || in.GetPhase() == NodePhaseSucceeded +} + +func (in *NodeStatus) GetDataDir() DataReference { + return in.DataDir +} + +func (in *NodeStatus) SetDataDir(d DataReference) { + in.DataDir = d + in.setDirty() +} + +func (in *NodeStatus) Equals(other *NodeStatus) bool { + // Assuming in is never nil + if other == nil { + return false + } + + if in.Attempts != other.Attempts { + return false + } + + if in.Phase != other.Phase { + return false + } + + if !reflect.DeepEqual(in.TaskNodeStatus, other.TaskNodeStatus) { + return false + } + + if in.DataDir != other.DataDir { + return false + } + + if in.ParentNode != nil && other.ParentNode != nil { + if *in.ParentNode != *other.ParentNode { + return false + } + } else if !(in.ParentNode == other.ParentNode) { + // Both are not nil + return false + } + + if !reflect.DeepEqual(in.ParentTask, other.ParentTask) { + 
return false + } + + if len(in.SubNodeStatus) != len(other.SubNodeStatus) { + return false + } + + for k, v := range in.SubNodeStatus { + otherV, ok := other.SubNodeStatus[k] + if !ok { + return false + } + if !v.Equals(otherV) { + return false + } + } + + return in.BranchStatus.Equals(other.BranchStatus) // && in.DynamicNodeStatus.Equals(other.DynamicNodeStatus) +} + +// THIS IS NOT AUTO GENERATED +func (in *CustomState) DeepCopyInto(out *CustomState) { + if in == nil || *in == nil { + return + } + + raw, err := json.Marshal(in) + if err != nil { + return + } + + err = json.Unmarshal(raw, out) + if err != nil { + return + } +} + +func (in *CustomState) DeepCopy() *CustomState { + if in == nil || *in == nil { + return nil + } + + out := &CustomState{} + in.DeepCopyInto(out) + return out +} + +type TaskNodeStatus struct { + Phase types.TaskPhase `json:"phase,omitempty"` + PhaseVersion uint32 `json:"phaseVersion,omitempty"` + CustomState types.CustomState `json:"custom,omitempty"` +} + +func (in *TaskNodeStatus) SetPhase(phase types.TaskPhase) { + in.Phase = phase +} + +func (in *TaskNodeStatus) SetPhaseVersion(version uint32) { + in.PhaseVersion = version +} + +func (in *TaskNodeStatus) SetCustomState(state types.CustomState) { + in.CustomState = state +} + +func (in TaskNodeStatus) GetPhase() types.TaskPhase { + return in.Phase +} + +func (in TaskNodeStatus) GetPhaseVersion() uint32 { + return in.PhaseVersion +} + +func (in TaskNodeStatus) GetCustomState() types.CustomState { + return in.CustomState +} + +func (in *TaskNodeStatus) UpdatePhase(phase types.TaskPhase, phaseVersion uint32) { + in.Phase = phase + in.PhaseVersion = phaseVersion +} + +func (in *TaskNodeStatus) UpdateCustomState(state types.CustomState) { + in.CustomState = state +} + +func (in *TaskNodeStatus) DeepCopyInto(out *TaskNodeStatus) { + if in == nil { + return + } + + raw, err := json.Marshal(in) + if err != nil { + return + } + + err = json.Unmarshal(raw, out) + if err != nil { + return + } +} 
+ +func (in *TaskNodeStatus) DeepCopy() *TaskNodeStatus { + if in == nil { + return nil + } + + out := &TaskNodeStatus{} + in.DeepCopyInto(out) + return out +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go b/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go new file mode 100644 index 000000000..f458c1c38 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go @@ -0,0 +1,156 @@ +package v1alpha1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsPhaseTerminal(t *testing.T) { + assert.True(t, IsPhaseTerminal(NodePhaseFailed)) + assert.True(t, IsPhaseTerminal(NodePhaseSkipped)) + assert.True(t, IsPhaseTerminal(NodePhaseSucceeded)) + + assert.False(t, IsPhaseTerminal(NodePhaseFailing)) + assert.False(t, IsPhaseTerminal(NodePhaseRunning)) + assert.False(t, IsPhaseTerminal(NodePhaseNotYetStarted)) +} + +func TestNodeStatus_Equals(t *testing.T) { + one := &NodeStatus{} + var other *NodeStatus + assert.False(t, one.Equals(other)) + + other = &NodeStatus{} + assert.True(t, one.Equals(other)) + + one.Phase = NodePhaseRunning + assert.False(t, one.Equals(other)) + + other.Phase = one.Phase + assert.True(t, one.Equals(other)) + + one.DataDir = "data-dir" + assert.False(t, one.Equals(other)) + + other.DataDir = one.DataDir + assert.True(t, one.Equals(other)) + + parentNode := "x" + one.ParentNode = &parentNode + assert.False(t, one.Equals(other)) + + parentNode2 := "y" + other.ParentNode = &parentNode2 + assert.False(t, one.Equals(other)) + + other.ParentNode = &parentNode + assert.True(t, one.Equals(other)) + + one.BranchStatus = &BranchNodeStatus{} + assert.False(t, one.Equals(other)) + other.BranchStatus = &BranchNodeStatus{} + assert.True(t, one.Equals(other)) + + node := "x" + one.SubNodeStatus = map[NodeID]*NodeStatus{ + node: {}, + } + assert.False(t, one.Equals(other)) + other.SubNodeStatus = map[NodeID]*NodeStatus{ + node: {}, + } + assert.True(t, one.Equals(other)) + + one.SubNodeStatus[node].Phase = 
NodePhaseRunning + assert.False(t, one.Equals(other)) + other.SubNodeStatus[node].Phase = NodePhaseRunning + assert.True(t, one.Equals(other)) +} + +func TestBranchNodeStatus_Equals(t *testing.T) { + var one *BranchNodeStatus + var other *BranchNodeStatus + assert.True(t, one.Equals(other)) + one = &BranchNodeStatus{} + + assert.False(t, one.Equals(other)) + other = &BranchNodeStatus{} + + assert.True(t, one.Equals(other)) + + one.Phase = BranchNodeError + assert.False(t, one.Equals(other)) + other.Phase = one.Phase + + assert.True(t, one.Equals(other)) + + node := "x" + one.FinalizedNodeID = &node + assert.False(t, one.Equals(other)) + + node2 := "y" + other.FinalizedNodeID = &node2 + assert.False(t, one.Equals(other)) + + node2 = node + other.FinalizedNodeID = &node2 + assert.True(t, one.Equals(other)) +} + +func TestDynamicNodeStatus_Equals(t *testing.T) { + var one *DynamicNodeStatus + var other *DynamicNodeStatus + assert.True(t, one.Equals(other)) + one = &DynamicNodeStatus{} + + assert.False(t, one.Equals(other)) + other = &DynamicNodeStatus{} + + assert.True(t, one.Equals(other)) + + one.Phase = DynamicNodePhaseExecuting + assert.False(t, one.Equals(other)) + other.Phase = one.Phase + + assert.True(t, one.Equals(other)) +} + +func TestCustomState_DeepCopyInto(t *testing.T) { + t.Run("Nil", func(t *testing.T) { + var in CustomState + var out CustomState + in.DeepCopyInto(&out) + assert.Nil(t, in) + assert.Nil(t, out) + }) + + t.Run("Not nil in", func(t *testing.T) { + in := CustomState(map[string]interface{}{ + "key1": "hello", + }) + + var out CustomState + in.DeepCopyInto(&out) + assert.NotNil(t, out) + assert.Equal(t, 1, len(out)) + }) +} + +func TestCustomState_DeepCopy(t *testing.T) { + t.Run("Nil", func(t *testing.T) { + var in CustomState + assert.Nil(t, in) + assert.Nil(t, in.DeepCopy()) + }) + + t.Run("Not nil in", func(t *testing.T) { + in := CustomState(map[string]interface{}{ + "key1": "hello", + }) + + out := in.DeepCopy() + assert.NotNil(t, 
out) + assert.Equal(t, 1, len(*out)) + }) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/nodes.go b/pkg/apis/flyteworkflow/v1alpha1/nodes.go new file mode 100644 index 000000000..3a6fa0bfe --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/nodes.go @@ -0,0 +1,193 @@ +package v1alpha1 + +import ( + "bytes" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + typesv1 "k8s.io/api/core/v1" +) + +var marshaler = jsonpb.Marshaler{} + +type OutputVarMap struct { + *core.VariableMap +} + +func (in *OutputVarMap) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.VariableMap); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func (in *OutputVarMap) UnmarshalJSON(b []byte) error { + in.VariableMap = &core.VariableMap{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.VariableMap) +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *OutputVarMap) DeepCopyInto(out *OutputVarMap) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +type Binding struct { + *core.Binding +} + +func (in *Binding) UnmarshalJSON(b []byte) error { + in.Binding = &core.Binding{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.Binding) +} + +func (in *Binding) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.Binding); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *Binding) DeepCopyInto(out *Binding) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +// Strategy to be used to Retry a node that is in RetryableFailure state +type RetryStrategy struct { + // MinAttempts implies the atleast n attempts to try this node before giving up. The atleast here is because we may + // fail to write the attempt information and end up retrying again. + // Also `0` and `1` both mean atleast one attempt will be done. 0 is a degenerate case. + MinAttempts *int `json:"minAttempts"` + // TODO Add retrydelay? +} + +type Alias struct { + core.Alias +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Alias) DeepCopyInto(out *Alias) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +type NodeMetadata struct { + core.NodeMetadata +} + +func (in *NodeMetadata) DeepCopyInto(out *NodeMetadata) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +type NodeSpec struct { + ID NodeID `json:"id"` + Resources *typesv1.ResourceRequirements `json:"resources,omitempty"` + Kind NodeKind `json:"kind"` + BranchNode *BranchNodeSpec `json:"branch,omitempty"` + TaskRef *TaskID `json:"task,omitempty"` + WorkflowNode *WorkflowNodeSpec `json:"workflow,omitempty"` + InputBindings []*Binding `json:"inputBindings,omitempty"` + Config *typesv1.ConfigMap `json:"config,omitempty"` + RetryStrategy *RetryStrategy `json:"retry,omitempty"` + OutputAliases []Alias `json:"outputAlias,omitempty"` + + // SecurityContext holds pod-level security attributes and common container settings. + // Optional: Defaults to empty. See type description for default values of each field. 
+ // +optional + SecurityContext *typesv1.PodSecurityContext `json:"securityContext,omitempty" protobuf:"bytes,14,opt,name=securityContext"` + // ImagePullSecrets is an optional list of references to secrets in the same namespace to use for pulling any of the images used by this PodSpec. + // If specified, these secrets will be passed to individual puller implementations for them to use. For example, + // in the case of docker, only DockerConfig type secrets are honored. + // More info: https://kubernetes.io/docs/concepts/containers/images#specifying-imagepullsecrets-on-a-pod + // +optional + // +patchMergeKey=name + // +patchStrategy=merge + ImagePullSecrets []typesv1.LocalObjectReference `json:"imagePullSecrets,omitempty" patchStrategy:"merge" patchMergeKey:"name" protobuf:"bytes,15,rep,name=imagePullSecrets"` + // Specifies the hostname of the Pod + // If not specified, the pod's hostname will be set to a system-defined value. + // +optional + Hostname string `json:"hostname,omitempty" protobuf:"bytes,16,opt,name=hostname"` + // If specified, the fully qualified Pod hostname will be "...svc.". + // If not specified, the pod will not have a domainname at all. + // +optional + Subdomain string `json:"subdomain,omitempty" protobuf:"bytes,17,opt,name=subdomain"` + // If specified, the pod's scheduling constraints + // +optional + Affinity *typesv1.Affinity `json:"affinity,omitempty" protobuf:"bytes,18,opt,name=affinity"` + // If specified, the pod will be dispatched by specified scheduler. + // If not specified, the pod will be dispatched by default scheduler. + // +optional + SchedulerName string `json:"schedulerName,omitempty" protobuf:"bytes,19,opt,name=schedulerName"` + // If specified, the pod's tolerations. + // +optional + Tolerations []typesv1.Toleration `json:"tolerations,omitempty" protobuf:"bytes,22,opt,name=tolerations"` + // StartTime before the system will actively try to mark it failed and kill associated containers. 
+ // Value must be a positive integer. + // +optional + ActiveDeadlineSeconds *int64 `json:"activeDeadlineSeconds,omitempty"` +} + +func (in *NodeSpec) GetRetryStrategy() *RetryStrategy { + return in.RetryStrategy +} + +func (in *NodeSpec) GetConfig() *typesv1.ConfigMap { + return in.Config +} + +func (in *NodeSpec) GetResources() *typesv1.ResourceRequirements { + return in.Resources +} + +func (in *NodeSpec) GetOutputAlias() []Alias { + return in.OutputAliases +} + +func (in *NodeSpec) GetWorkflowNode() ExecutableWorkflowNode { + if in.WorkflowNode == nil { + return nil + } + return in.WorkflowNode +} + +func (in *NodeSpec) GetBranchNode() ExecutableBranchNode { + if in.BranchNode == nil { + return nil + } + return in.BranchNode +} + +func (in *NodeSpec) GetTaskID() *TaskID { + return in.TaskRef +} + +func (in *NodeSpec) GetKind() NodeKind { + return in.Kind +} + +func (in *NodeSpec) GetID() NodeID { + return in.ID +} + +func (in *NodeSpec) IsStartNode() bool { + return in.ID == StartNodeID +} + +func (in *NodeSpec) IsEndNode() bool { + return in.ID == EndNodeID +} + +func (in *NodeSpec) GetInputBindings() []*Binding { + return in.InputBindings +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/register.go b/pkg/apis/flyteworkflow/v1alpha1/register.go new file mode 100644 index 000000000..56772feed --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/register.go @@ -0,0 +1,38 @@ +package v1alpha1 + +import ( + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +const FlyteWorkflowKind = "flyteworkflow" + +// SchemeGroupVersion is group version used to register these objects +var SchemeGroupVersion = schema.GroupVersion{Group: flyteworkflow.GroupName, Version: "v1alpha1"} + +// GetKind takes an unqualified kind and returns back a Group qualified GroupKind +func Kind(kind string) schema.GroupKind { + return 
SchemeGroupVersion.WithKind(kind).GroupKind() +} + +// Resource takes an unqualified resource and returns a Group qualified GroupResource +func Resource(resource string) schema.GroupResource { + return SchemeGroupVersion.WithResource(resource).GroupResource() +} + +var ( + SchemeBuilder = runtime.NewSchemeBuilder(addKnownTypes) + AddToScheme = SchemeBuilder.AddToScheme +) + +// Adds the list of known types to Scheme. +func addKnownTypes(scheme *runtime.Scheme) error { + scheme.AddKnownTypes(SchemeGroupVersion, + &FlyteWorkflow{}, + &FlyteWorkflowList{}, + ) + metav1.AddToGroupVersion(scheme, SchemeGroupVersion) + return nil +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/subworkflow.go b/pkg/apis/flyteworkflow/v1alpha1/subworkflow.go new file mode 100644 index 000000000..a7d7532b9 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/subworkflow.go @@ -0,0 +1,23 @@ +package v1alpha1 + +type WorkflowNodeSpec struct { + // Either one of the two + LaunchPlanRefID *LaunchPlanRefID `json:"launchPlanRefId,omitempty"` + // We currently want the SubWorkflow to be completely contained in the node. this is because + // We use the node status to store the information of the execution. + // Important Note: This may cause a bloat in case we use the same SubWorkflow in multiple nodes. The recommended + // technique for that is to use launch plan refs. This is because we will end up executing the launch plan refs as + // disparate executions in Flyte propeller. This is potentially better as it prevents us from hitting the storage limit + // in etcd + //+optional. 
+ // Workflow *WorkflowSpec `json:"workflow,omitempty"` + SubWorkflowReference *WorkflowID `json:"subWorkflowRef,omitempty"` +} + +func (in *WorkflowNodeSpec) GetLaunchPlanRefID() *LaunchPlanRefID { + return in.LaunchPlanRefID +} + +func (in *WorkflowNodeSpec) GetSubWorkflowRef() *WorkflowID { + return in.SubWorkflowReference +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/tasks.go b/pkg/apis/flyteworkflow/v1alpha1/tasks.go new file mode 100644 index 000000000..bcf922730 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/tasks.go @@ -0,0 +1,39 @@ +package v1alpha1 + +import ( + "bytes" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type TaskSpec struct { + *core.TaskTemplate +} + +func (in *TaskSpec) TaskType() TaskType { + return in.Type +} + +func (in *TaskSpec) CoreTask() *core.TaskTemplate { + return in.TaskTemplate +} + +func (in *TaskSpec) DeepCopyInto(out *TaskSpec) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +func (in *TaskSpec) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.TaskTemplate); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (in *TaskSpec) UnmarshalJSON(b []byte) error { + in.TaskTemplate = &core.TaskTemplate{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.TaskTemplate) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/tasks_test.go b/pkg/apis/flyteworkflow/v1alpha1/tasks_test.go new file mode 100644 index 000000000..302022313 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/tasks_test.go @@ -0,0 +1,20 @@ +package v1alpha1_test + +import ( + "encoding/json" + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" +) + +func TestTaskSpec(t *testing.T) { + j, err := ReadYamlFileAsJSON("testdata/task.yaml") + assert.NoError(t, err) + + task := &v1alpha1.TaskSpec{} + 
assert.NoError(t, json.Unmarshal(j, task)) + + assert.NotNil(t, task.CoreTask()) + assert.Equal(t, "demo", task.TaskType()) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/testdata/branch.json b/pkg/apis/flyteworkflow/v1alpha1/testdata/branch.json new file mode 100644 index 000000000..73d91d8e8 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/testdata/branch.json @@ -0,0 +1,34 @@ +{ + "branch": { + "if": { + "condition": { + "comparison": { + "operator": "GT", + "leftValue": { + "primitive": { + "integer": "5" + } + }, + "rightValue": { + "var": "x" + } + } + }, + "then": "foo1" + }, + "else": "foo2" + }, + "id": "foobranch", + "inputBindings": [ + { + "binding": { + "promise": { + "nodeId": "start", + "var": "x" + } + }, + "var": "x" + } + ], + "kind": "branch" +} \ No newline at end of file diff --git a/pkg/apis/flyteworkflow/v1alpha1/testdata/connections.json b/pkg/apis/flyteworkflow/v1alpha1/testdata/connections.json new file mode 100644 index 000000000..c671cdbdd --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/testdata/connections.json @@ -0,0 +1,15 @@ +{ + "n1": [ + "n2", + "n3" + ], + "n2": [ + "n4" + ], + "n3": [ + "n4" + ], + "n4": [ + "n5" + ] +} \ No newline at end of file diff --git a/pkg/apis/flyteworkflow/v1alpha1/testdata/task.yaml b/pkg/apis/flyteworkflow/v1alpha1/testdata/task.yaml new file mode 100644 index 000000000..d99786db5 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/testdata/task.yaml @@ -0,0 +1,33 @@ +id: + name: add-one-and-print +type: "demo" +interface: + inputs: + variables: + value_to_print: + type: + simple: INTEGER + outputs: + variables: + out: + type: + simple: INTEGER +metadata: + runtime: + version: 1.19.0b7 + timeout: 0s +container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=add_one_and_print + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflyteimage:abc123 + resources: + requests: + - value: "0.000" + - value: "2.000" + - value: 
2048Mi + diff --git a/pkg/apis/flyteworkflow/v1alpha1/testdata/workflowspec.yaml b/pkg/apis/flyteworkflow/v1alpha1/testdata/workflowspec.yaml new file mode 100644 index 000000000..c83f8182b --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/testdata/workflowspec.yaml @@ -0,0 +1,220 @@ +workflow.apiVersion: flyte.lyft.com/v1alpha1 +kind: flyteworkflow +metadata: + creationTimestamp: null + generateName: dummy-workflow-1-0- + labels: + execution-id: "" + workflow-id: dummy-workflow-1-0 +inputs: + literals: + triggered_date: + scalar: + primitive: + datetime: 2018-08-01T18:09:18Z +spec: + connections: + add-one-and-print-0: + - sum-non-none-0 + add-one-and-print-1: + - add-one-and-print-2 + - sum-and-print-0 + add-one-and-print-2: + - sum-and-print-0 + add-one-and-print-3: + - sum-non-none-0 + start-node: + - add-one-and-print-0 + - add-one-and-print-3 + - print-every-time-0 + sum-and-print-0: + - end-node + - print-every-time-0 + sum-non-none-0: + - add-one-and-print-1 + - sum-and-print-0 + id: dummy-workflow-1-0 + nodes: + add-one-and-print-0: + activeDeadlineSeconds: 0 + id: add-one-and-print-0 + input_bindings: + - binding: + scalar: + primitive: + integer: "3" + var: value_to_print + kind: task + resources: {} + status: + phase: 0 + task_ref: add-one-and-print + add-one-and-print-1: + activeDeadlineSeconds: 0 + id: add-one-and-print-1 + input_bindings: + - binding: + promise: + nodeId: sum-non-none-0 + var: out + var: value_to_print + kind: task + resources: {} + status: + phase: 0 + task_ref: add-one-and-print + add-one-and-print-2: + activeDeadlineSeconds: 0 + id: add-one-and-print-2 + input_bindings: + - binding: + promise: + nodeId: add-one-and-print-1 + var: out + var: value_to_print + kind: task + resources: {} + status: + phase: 0 + task_ref: add-one-and-print + add-one-and-print-3: + activeDeadlineSeconds: 0 + id: add-one-and-print-3 + input_bindings: + - binding: + scalar: + primitive: + integer: "101" + var: value_to_print + kind: task + resources: {} 
+ status: + phase: 0 + task_ref: add-one-and-print + end-node: + id: end-node + input_bindings: + - binding: + promise: + nodeId: sum-and-print-0 + var: out + var: output + kind: end + resources: {} + status: + phase: 0 + print-every-time-0: + activeDeadlineSeconds: 0 + id: print-every-time-0 + input_bindings: + - binding: + promise: + nodeId: start-node + var: triggered_date + var: date_triggered + - binding: + promise: + nodeId: sum-and-print-0 + var: out_blob + var: in_blob + - binding: + promise: + nodeId: sum-and-print-0 + var: multi_blob + var: multi_blob + - binding: + promise: + nodeId: sum-and-print-0 + var: out + var: value_to_print + kind: task + resources: {} + status: + phase: 0 + task_ref: print-every-time + start-node: + id: start-node + kind: start + resources: {} + status: + phase: 0 + sum-and-print-0: + activeDeadlineSeconds: 0 + id: sum-and-print-0 + input_bindings: + - binding: + collection: + bindings: + - promise: + nodeId: sum-non-none-0 + var: out + - promise: + nodeId: add-one-and-print-1 + var: out + - promise: + nodeId: add-one-and-print-2 + var: out + - scalar: + primitive: + integer: "100" + var: values_to_add + kind: task + resources: {} + status: + phase: 0 + task_ref: sum-and-print + sum-non-none-0: + activeDeadlineSeconds: 0 + id: sum-non-none-0 + input_bindings: + - binding: + collection: + bindings: + - promise: + nodeId: add-one-and-print-0 + var: out + - promise: + nodeId: add-one-and-print-3 + var: out + var: values_to_print + kind: task + resources: {} + status: + phase: 0 + task_ref: sum-non-none +status: + phase: 0 +tasks: + add-one-and-print: + container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=add_one_and_print + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflyteimage:abc123 + resources: + requests: + - value: "0.000" + - value: "2.000" + - value: 2048Mi + id: + name: add-one-and-print + interface: + inputs: + variables: + value_to_print: 
+ type: + simple: INTEGER + outputs: + variables: + out: + type: + simple: INTEGER + metadata: + runtime: + version: 1.19.0b7 + timeout: 0s + type: "7" diff --git a/pkg/apis/flyteworkflow/v1alpha1/workflow.go b/pkg/apis/flyteworkflow/v1alpha1/workflow.go new file mode 100644 index 000000000..c53a680a9 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/workflow.go @@ -0,0 +1,232 @@ +package v1alpha1 + +import ( + "bytes" + "encoding/json" + + "k8s.io/apimachinery/pkg/types" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/pkg/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const StartNodeID = "start-node" +const EndNodeID = "end-node" + +// +genclient +// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object + +// FlyteWorkflow: represents one Execution Workflow object +type FlyteWorkflow struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + *WorkflowSpec `json:"spec"` + Inputs *Inputs `json:"inputs,omitempty"` + ExecutionID ExecutionID `json:"executionId"` + Tasks map[TaskID]*TaskSpec `json:"tasks"` + SubWorkflows map[WorkflowID]*WorkflowSpec `json:"subWorkflows,omitempty"` + // StartTime before the system will actively try to mark it failed and kill associated containers. + // Value must be a positive integer. + // +optional + ActiveDeadlineSeconds *int64 `json:"activeDeadlineSeconds,omitempty"` + // Specifies the time when the workflow has been accepted into the system. + AcceptedAt *metav1.Time `json:"acceptedAt,omitEmpty"` + // ServiceAccountName is the name of the ServiceAccount to use to run this pod. + // More info: https://kubernetes.io/docs/tasks/configure-pod-container/configure-service-account/ + // +optional + ServiceAccountName string `json:"serviceAccountName,omitempty" protobuf:"bytes,8,opt,name=serviceAccountName"` + // Status is the only mutable section in the workflow. 
It holds all the execution information + Status WorkflowStatus `json:"status,omitempty"` +} + +var FlyteWorkflowGVK = SchemeGroupVersion.WithKind(FlyteWorkflowKind) + +func (in *FlyteWorkflow) NewControllerRef() metav1.OwnerReference { + // TODO Open Issue - https://github.com/kubernetes/client-go/issues/308 + // For some reason the CRD does not have the GVK correctly populated. So we will fake it. + if len(in.GroupVersionKind().Group) == 0 || len(in.GroupVersionKind().Kind) == 0 || len(in.GroupVersionKind().Version) == 0 { + return *metav1.NewControllerRef(in, FlyteWorkflowGVK) + } + return *metav1.NewControllerRef(in, in.GroupVersionKind()) +} + +func (in *FlyteWorkflow) GetTask(id TaskID) (ExecutableTask, error) { + t, ok := in.Tasks[id] + if !ok { + return nil, errors.Errorf("Unable to find task with Id [%v]", id) + } + return t, nil +} + +func (in *FlyteWorkflow) GetExecutionStatus() ExecutableWorkflowStatus { + return &in.Status +} + +func (in *FlyteWorkflow) GetK8sWorkflowID() types.NamespacedName { + return types.NamespacedName{ + Name: in.GetName(), + Namespace: in.GetNamespace(), + } +} + +func (in *FlyteWorkflow) GetExecutionID() ExecutionID { + return in.ExecutionID +} + +func (in *FlyteWorkflow) FindSubWorkflow(subID WorkflowID) ExecutableSubWorkflow { + s, ok := in.SubWorkflows[subID] + if !ok { + return nil + } + return s +} + +func (in *FlyteWorkflow) GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus { + return in.Status.GetNodeExecutionStatus(id) +} + +func (in *FlyteWorkflow) GetServiceAccountName() string { + return in.ServiceAccountName +} + +type Inputs struct { + *core.LiteralMap +} + +func (in *Inputs) UnmarshalJSON(b []byte) error { + in.LiteralMap = &core.LiteralMap{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.LiteralMap) +} + +func (in *Inputs) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.LiteralMap); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// 
DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Inputs) DeepCopyInto(out *Inputs) { + *out = *in + // We do not manipulate the object, so it's OK to share the pointer. + // Once we figure out the autogenerate story we can replace this +} + +type Connections struct { + DownstreamEdges map[NodeID][]NodeID + UpstreamEdges map[NodeID][]NodeID +} + +func (in *Connections) UnmarshalJSON(b []byte) error { + in.DownstreamEdges = map[NodeID][]NodeID{} + err := json.Unmarshal(b, &in.DownstreamEdges) + if err != nil { + return err + } + in.UpstreamEdges = map[NodeID][]NodeID{} + for from, nodes := range in.DownstreamEdges { + for _, to := range nodes { + if _, ok := in.UpstreamEdges[to]; !ok { + in.UpstreamEdges[to] = []NodeID{} + } + in.UpstreamEdges[to] = append(in.UpstreamEdges[to], from) + } + } + return nil +} + +func (in *Connections) MarshalJSON() ([]byte, error) { + return json.Marshal(in.DownstreamEdges) +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Connections) DeepCopyInto(out *Connections) { + *out = *in + // We do not manipulate the object, so it's OK to share the maps. + // Once we figure out the autogenerate story we can replace this +} + +// WorkflowSpec is the spec for the actual Flyte Workflow (DAG) +type WorkflowSpec struct { + ID WorkflowID `json:"id"` + Nodes map[NodeID]*NodeSpec `json:"nodes"` + + // Defines the set of connections (both data dependencies and execution dependencies) that the graph is + // formed of. The execution engine will respect and follow these connections as it determines which nodes + // can and should be executed. + Connections Connections `json:"connections"` + + // Defines a single node to execute in case the system determined the Workflow has failed. + OnFailure *NodeSpec `json:"onFailure,omitempty"` + + // Defines the declaration of the outputs types and names this workflow is expected to generate.
+ Outputs *OutputVarMap `json:"outputs,omitempty"` + + // Defines the data links used to construct the final outputs of the workflow. Bindings will typically + // refer to specific outputs of a subset of the nodes executed in the Workflow. When executing the end-node, + // the execution engine will traverse these bindings and assemble the final set of outputs of the workflow. + OutputBindings []*Binding `json:"outputBindings,omitempty"` +} + +func (in *WorkflowSpec) StartNode() ExecutableNode { + n, ok := in.Nodes[StartNodeID] + if !ok { + return nil + } + return n +} + +func (in *WorkflowSpec) GetID() WorkflowID { + return in.ID +} + +func (in *WorkflowSpec) FromNode(name NodeID) ([]NodeID, error) { + if _, ok := in.Nodes[name]; !ok { + return nil, errors.Errorf("Bad Node [%v], is not defined in the Workflow [%v]", name, in.ID) + } + downstreamNodes := in.Connections.DownstreamEdges[name] + return downstreamNodes, nil +} + +func (in *WorkflowSpec) GetOutputs() *OutputVarMap { + return in.Outputs +} + +func (in *WorkflowSpec) GetNode(nodeID NodeID) (ExecutableNode, bool) { + n, ok := in.Nodes[nodeID] + return n, ok +} + +func (in *WorkflowSpec) GetConnections() *Connections { + return &in.Connections +} + +func (in *WorkflowSpec) GetOutputBindings() []*Binding { + return in.OutputBindings +} + +func (in *WorkflowSpec) GetOnFailureNode() ExecutableNode { + if in.OnFailure == nil { + return nil + } + return in.OnFailure +} + +func (in *WorkflowSpec) GetNodes() []NodeID { + nodeIds := make([]NodeID, 0, len(in.Nodes)) + for id := range in.Nodes { + nodeIds = append(nodeIds, id) + } + return nodeIds +} + +// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object +// FlyteWorkflowList is a list of FlyteWorkflow resources +type FlyteWorkflowList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata"` + Items []FlyteWorkflow `json:"items"` +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go 
b/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go new file mode 100644 index 000000000..4027d4d95 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go @@ -0,0 +1,154 @@ +package v1alpha1 + +import ( + "context" + + "github.com/lyft/flytestdlib/storage" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const maxMessageSize = 1024 + +type WorkflowStatus struct { + Phase WorkflowPhase `json:"phase"` + StartedAt *metav1.Time `json:"startedAt,omitempty"` + StoppedAt *metav1.Time `json:"stoppedAt,omitempty"` + LastUpdatedAt *metav1.Time `json:"lastUpdatedAt,omitempty"` + Message string `json:"message,omitempty"` + DataDir DataReference `json:"dataDir,omitempty"` + OutputReference DataReference `json:"outputRef,omitempty"` + + // We can store the outputs at this layer + // We can also store a cross section of nodes being executed currently here. This could be an optimization + + NodeStatus map[NodeID]*NodeStatus `json:"nodeStatus,omitempty"` + + // Number of Attempts completed with rounds resulting in error. this is used to cap out poison pill workflows + // that spin in an error loop. The value should be set at the global level and will be enforced. 
At the end of + // the retries the workflow will fail + FailedAttempts uint32 `json:"failedAttempts,omitempty"` +} + +func IsWorkflowPhaseTerminal(p WorkflowPhase) bool { + return p == WorkflowPhaseFailed || p == WorkflowPhaseSuccess || p == WorkflowPhaseAborted +} + +func (in *WorkflowStatus) SetMessage(msg string) { + in.Message = msg +} + +func (in *WorkflowStatus) UpdatePhase(p WorkflowPhase, msg string) { + in.Phase = p + in.Message = msg + if len(msg) > maxMessageSize { + in.Message = msg[:maxMessageSize] + } + + n := metav1.Now() + if in.StartedAt == nil { + in.StartedAt = &n + } + + if IsWorkflowPhaseTerminal(p) && in.StoppedAt == nil { + in.StoppedAt = &n + } + + in.LastUpdatedAt = &n +} + +func (in *WorkflowStatus) IncFailedAttempts() { + in.FailedAttempts++ +} + +func (in *WorkflowStatus) GetPhase() WorkflowPhase { + return in.Phase +} + +func (in *WorkflowStatus) GetStartedAt() *metav1.Time { + return in.StartedAt +} + +func (in *WorkflowStatus) GetStoppedAt() *metav1.Time { + return in.StoppedAt +} + +func (in *WorkflowStatus) GetLastUpdatedAt() *metav1.Time { + return in.LastUpdatedAt +} + +func (in *WorkflowStatus) IsTerminated() bool { + return in.Phase == WorkflowPhaseSuccess || in.Phase == WorkflowPhaseFailed || in.Phase == WorkflowPhaseAborted +} + +func (in *WorkflowStatus) GetMessage() string { + return in.Message +} + +func (in *WorkflowStatus) GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus { + n, ok := in.NodeStatus[id] + if ok { + return n + } + if in.NodeStatus == nil { + in.NodeStatus = make(map[NodeID]*NodeStatus) + } + newNodeStatus := &NodeStatus{} + in.NodeStatus[id] = newNodeStatus + return newNodeStatus +} + +func (in *WorkflowStatus) ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, name NodeID) (storage.DataReference, error) { + return constructor.ConstructReference(ctx, in.GetDataDir(), name, "data") +} + +func (in *WorkflowStatus) GetDataDir() DataReference { + return in.DataDir +} + 
+func (in *WorkflowStatus) SetDataDir(d DataReference) { + in.DataDir = d +} + +func (in *WorkflowStatus) GetOutputReference() DataReference { + return in.OutputReference +} + +func (in *WorkflowStatus) SetOutputReference(reference DataReference) { + in.OutputReference = reference +} + +func (in *WorkflowStatus) Equals(other *WorkflowStatus) bool { + // Assuming in is never nil! + if other == nil { + return false + } + if in.FailedAttempts != other.FailedAttempts { + return false + } + if in.Phase != other.Phase { + return false + } + // We will not compare the time and message + if in.DataDir != other.DataDir { + return false + } + + if in.OutputReference != other.OutputReference { + return false + } + + if len(in.NodeStatus) != len(other.NodeStatus) { + return false + } + + for k, v := range in.NodeStatus { + otherV, ok := other.NodeStatus[k] + if !ok { + return false + } + if !v.Equals(otherV) { + return false + } + } + return true +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/workflow_status_test.go b/pkg/apis/flyteworkflow/v1alpha1/workflow_status_test.go new file mode 100644 index 000000000..9d53caac7 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/workflow_status_test.go @@ -0,0 +1,54 @@ +package v1alpha1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsWorkflowPhaseTerminal(t *testing.T) { + assert.True(t, IsWorkflowPhaseTerminal(WorkflowPhaseFailed)) + assert.True(t, IsWorkflowPhaseTerminal(WorkflowPhaseSuccess)) + + assert.False(t, IsWorkflowPhaseTerminal(WorkflowPhaseFailing)) + assert.False(t, IsWorkflowPhaseTerminal(WorkflowPhaseSucceeding)) + assert.False(t, IsWorkflowPhaseTerminal(WorkflowPhaseReady)) + assert.False(t, IsWorkflowPhaseTerminal(WorkflowPhaseRunning)) +} + +func TestWorkflowStatus_Equals(t *testing.T) { + one := &WorkflowStatus{} + other := &WorkflowStatus{} + assert.True(t, one.Equals(other)) + + one.Phase = WorkflowPhaseRunning + assert.False(t, one.Equals(other)) + + other.Phase = one.Phase + 
assert.True(t, one.Equals(other)) + + one.DataDir = "data-dir" + assert.False(t, one.Equals(other)) + other.DataDir = one.DataDir + assert.True(t, one.Equals(other)) + + node := "x" + one.NodeStatus = map[NodeID]*NodeStatus{ + node: {}, + } + assert.False(t, one.Equals(other)) + other.NodeStatus = map[NodeID]*NodeStatus{ + node: {}, + } + assert.True(t, one.Equals(other)) + + one.NodeStatus[node].Phase = NodePhaseRunning + assert.False(t, one.Equals(other)) + other.NodeStatus[node].Phase = NodePhaseRunning + assert.True(t, one.Equals(other)) + + one.OutputReference = "out" + assert.False(t, one.Equals(other)) + other.OutputReference = "out" + assert.True(t, one.Equals(other)) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/workflow_test.go b/pkg/apis/flyteworkflow/v1alpha1/workflow_test.go new file mode 100644 index 000000000..c50d118de --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/workflow_test.go @@ -0,0 +1,51 @@ +package v1alpha1_test + +import ( + "encoding/json" + "io/ioutil" + "testing" + + "github.com/ghodss/yaml" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" +) + +func TestMarshalUnmarshal_Connections(t *testing.T) { + r, err := ioutil.ReadFile("testdata/connections.json") + assert.NoError(t, err) + o := v1alpha1.Connections{} + err = json.Unmarshal(r, &o) + assert.NoError(t, err) + assert.Equal(t, map[v1alpha1.NodeID][]v1alpha1.NodeID{ + "n1": {"n2", "n3"}, + "n2": {"n4"}, + "n3": {"n4"}, + "n4": {"n5"}, + }, o.DownstreamEdges) + assert.Equal(t, []v1alpha1.NodeID{"n1"}, o.UpstreamEdges["n2"]) + assert.Equal(t, []v1alpha1.NodeID{"n1"}, o.UpstreamEdges["n3"]) + assert.Equal(t, []v1alpha1.NodeID{"n4"}, o.UpstreamEdges["n5"]) + assert.True(t, sets.NewString(o.UpstreamEdges["n4"]...).Equal(sets.NewString("n2", "n3"))) +} + +func ReadYamlFileAsJSON(path string) ([]byte, error) { + r, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + return 
yaml.YAMLToJSON(r) +} + +func TestWorkflowSpec(t *testing.T) { + j, err := ReadYamlFileAsJSON("testdata/workflowspec.yaml") + assert.NoError(t, err) + w := &v1alpha1.FlyteWorkflow{} + err = json.Unmarshal(j, w) + assert.NoError(t, err) + assert.NotNil(t, w.WorkflowSpec) + assert.Nil(t, w.GetOnFailureNode()) + assert.Equal(t, 7, len(w.Connections.DownstreamEdges)) + assert.Equal(t, 8, len(w.Connections.UpstreamEdges)) + +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go b/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go new file mode 100644 index 000000000..a4cd7186a --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go @@ -0,0 +1,676 @@ +// +build !ignore_autogenerated + +// Code generated by deepcopy-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + v1 "k8s.io/api/core/v1" + runtime "k8s.io/apimachinery/pkg/runtime" +) + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Alias. +func (in *Alias) DeepCopy() *Alias { + if in == nil { + return nil + } + out := new(Alias) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Binding. +func (in *Binding) DeepCopy() *Binding { + if in == nil { + return nil + } + out := new(Binding) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BooleanExpression. +func (in *BooleanExpression) DeepCopy() *BooleanExpression { + if in == nil { + return nil + } + out := new(BooleanExpression) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *BranchNodeSpec) DeepCopyInto(out *BranchNodeSpec) { + *out = *in + in.If.DeepCopyInto(&out.If) + if in.ElseIf != nil { + in, out := &in.ElseIf, &out.ElseIf + *out = make([]*IfBlock, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(IfBlock) + (*in).DeepCopyInto(*out) + } + } + } + if in.Else != nil { + in, out := &in.Else, &out.Else + *out = new(string) + **out = **in + } + if in.ElseFail != nil { + in, out := &in.ElseFail, &out.ElseFail + *out = (*in).DeepCopy() + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BranchNodeSpec. +func (in *BranchNodeSpec) DeepCopy() *BranchNodeSpec { + if in == nil { + return nil + } + out := new(BranchNodeSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *BranchNodeStatus) DeepCopyInto(out *BranchNodeStatus) { + *out = *in + if in.FinalizedNodeID != nil { + in, out := &in.FinalizedNodeID, &out.FinalizedNodeID + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BranchNodeStatus. +func (in *BranchNodeStatus) DeepCopy() *BranchNodeStatus { + if in == nil { + return nil + } + out := new(BranchNodeStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Connections. +func (in *Connections) DeepCopy() *Connections { + if in == nil { + return nil + } + out := new(Connections) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *DynamicNodeStatus) DeepCopyInto(out *DynamicNodeStatus) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DynamicNodeStatus. 
+func (in *DynamicNodeStatus) DeepCopy() *DynamicNodeStatus { + if in == nil { + return nil + } + out := new(DynamicNodeStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Error. +func (in *Error) DeepCopy() *Error { + if in == nil { + return nil + } + out := new(Error) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FlyteWorkflow) DeepCopyInto(out *FlyteWorkflow) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + if in.WorkflowSpec != nil { + in, out := &in.WorkflowSpec, &out.WorkflowSpec + *out = new(WorkflowSpec) + (*in).DeepCopyInto(*out) + } + if in.Inputs != nil { + in, out := &in.Inputs, &out.Inputs + *out = (*in).DeepCopy() + } + in.ExecutionID.DeepCopyInto(&out.ExecutionID) + if in.Tasks != nil { + in, out := &in.Tasks, &out.Tasks + *out = make(map[string]*TaskSpec, len(*in)) + for key, val := range *in { + var outVal *TaskSpec + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = (*in).DeepCopy() + } + (*out)[key] = outVal + } + } + if in.SubWorkflows != nil { + in, out := &in.SubWorkflows, &out.SubWorkflows + *out = make(map[string]*WorkflowSpec, len(*in)) + for key, val := range *in { + var outVal *WorkflowSpec + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = new(WorkflowSpec) + (*in).DeepCopyInto(*out) + } + (*out)[key] = outVal + } + } + if in.ActiveDeadlineSeconds != nil { + in, out := &in.ActiveDeadlineSeconds, &out.ActiveDeadlineSeconds + *out = new(int64) + **out = **in + } + if in.AcceptedAt != nil { + in, out := &in.AcceptedAt, &out.AcceptedAt + *out = (*in).DeepCopy() + } + in.Status.DeepCopyInto(&out.Status) + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FlyteWorkflow. 
+func (in *FlyteWorkflow) DeepCopy() *FlyteWorkflow { + if in == nil { + return nil + } + out := new(FlyteWorkflow) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *FlyteWorkflow) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FlyteWorkflowList) DeepCopyInto(out *FlyteWorkflowList) { + *out = *in + out.TypeMeta = in.TypeMeta + out.ListMeta = in.ListMeta + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]FlyteWorkflow, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FlyteWorkflowList. +func (in *FlyteWorkflowList) DeepCopy() *FlyteWorkflowList { + if in == nil { + return nil + } + out := new(FlyteWorkflowList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *FlyteWorkflowList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Identifier. +func (in *Identifier) DeepCopy() *Identifier { + if in == nil { + return nil + } + out := new(Identifier) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *IfBlock) DeepCopyInto(out *IfBlock) { + *out = *in + in.Condition.DeepCopyInto(&out.Condition) + if in.ThenNode != nil { + in, out := &in.ThenNode, &out.ThenNode + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new IfBlock. +func (in *IfBlock) DeepCopy() *IfBlock { + if in == nil { + return nil + } + out := new(IfBlock) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Inputs. +func (in *Inputs) DeepCopy() *Inputs { + if in == nil { + return nil + } + out := new(Inputs) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NodeMetadata. +func (in *NodeMetadata) DeepCopy() *NodeMetadata { + if in == nil { + return nil + } + out := new(NodeMetadata) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *NodeSpec) DeepCopyInto(out *NodeSpec) { + *out = *in + if in.Resources != nil { + in, out := &in.Resources, &out.Resources + *out = new(v1.ResourceRequirements) + (*in).DeepCopyInto(*out) + } + if in.BranchNode != nil { + in, out := &in.BranchNode, &out.BranchNode + *out = new(BranchNodeSpec) + (*in).DeepCopyInto(*out) + } + if in.TaskRef != nil { + in, out := &in.TaskRef, &out.TaskRef + *out = new(string) + **out = **in + } + if in.WorkflowNode != nil { + in, out := &in.WorkflowNode, &out.WorkflowNode + *out = new(WorkflowNodeSpec) + (*in).DeepCopyInto(*out) + } + if in.InputBindings != nil { + in, out := &in.InputBindings, &out.InputBindings + *out = make([]*Binding, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = (*in).DeepCopy() + } + } + } + if in.Config != nil { + in, out := &in.Config, &out.Config + *out = new(v1.ConfigMap) + (*in).DeepCopyInto(*out) + } + if in.RetryStrategy != nil { + in, out := &in.RetryStrategy, &out.RetryStrategy + *out = new(RetryStrategy) + (*in).DeepCopyInto(*out) + } + if in.OutputAliases != nil { + in, out := &in.OutputAliases, &out.OutputAliases + *out = make([]Alias, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.SecurityContext != nil { + in, out := &in.SecurityContext, &out.SecurityContext + *out = new(v1.PodSecurityContext) + (*in).DeepCopyInto(*out) + } + if in.ImagePullSecrets != nil { + in, out := &in.ImagePullSecrets, &out.ImagePullSecrets + *out = make([]v1.LocalObjectReference, len(*in)) + copy(*out, *in) + } + if in.Affinity != nil { + in, out := &in.Affinity, &out.Affinity + *out = new(v1.Affinity) + (*in).DeepCopyInto(*out) + } + if in.Tolerations != nil { + in, out := &in.Tolerations, &out.Tolerations + *out = make([]v1.Toleration, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.ActiveDeadlineSeconds != nil { + in, out := &in.ActiveDeadlineSeconds, &out.ActiveDeadlineSeconds + *out 
= new(int64) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NodeSpec. +func (in *NodeSpec) DeepCopy() *NodeSpec { + if in == nil { + return nil + } + out := new(NodeSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *NodeStatus) DeepCopyInto(out *NodeStatus) { + *out = *in + if in.QueuedAt != nil { + in, out := &in.QueuedAt, &out.QueuedAt + *out = (*in).DeepCopy() + } + if in.StartedAt != nil { + in, out := &in.StartedAt, &out.StartedAt + *out = (*in).DeepCopy() + } + if in.StoppedAt != nil { + in, out := &in.StoppedAt, &out.StoppedAt + *out = (*in).DeepCopy() + } + if in.LastUpdatedAt != nil { + in, out := &in.LastUpdatedAt, &out.LastUpdatedAt + *out = (*in).DeepCopy() + } + if in.ParentNode != nil { + in, out := &in.ParentNode, &out.ParentNode + *out = new(string) + **out = **in + } + if in.ParentTask != nil { + in, out := &in.ParentTask, &out.ParentTask + *out = (*in).DeepCopy() + } + if in.BranchStatus != nil { + in, out := &in.BranchStatus, &out.BranchStatus + *out = new(BranchNodeStatus) + (*in).DeepCopyInto(*out) + } + if in.SubNodeStatus != nil { + in, out := &in.SubNodeStatus, &out.SubNodeStatus + *out = make(map[string]*NodeStatus, len(*in)) + for key, val := range *in { + var outVal *NodeStatus + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = new(NodeStatus) + (*in).DeepCopyInto(*out) + } + (*out)[key] = outVal + } + } + if in.WorkflowNodeStatus != nil { + in, out := &in.WorkflowNodeStatus, &out.WorkflowNodeStatus + *out = new(WorkflowNodeStatus) + **out = **in + } + if in.TaskNodeStatus != nil { + in, out := &in.TaskNodeStatus, &out.TaskNodeStatus + *out = (*in).DeepCopy() + } + if in.SubWorkflowNodeStatus != nil { + in, out := &in.SubWorkflowNodeStatus, &out.SubWorkflowNodeStatus + *out = new(SubWorkflowNodeStatus) + **out 
= **in + } + if in.DynamicNodeStatus != nil { + in, out := &in.DynamicNodeStatus, &out.DynamicNodeStatus + *out = new(DynamicNodeStatus) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NodeStatus. +func (in *NodeStatus) DeepCopy() *NodeStatus { + if in == nil { + return nil + } + out := new(NodeStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new OutputVarMap. +func (in *OutputVarMap) DeepCopy() *OutputVarMap { + if in == nil { + return nil + } + out := new(OutputVarMap) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RetryStrategy) DeepCopyInto(out *RetryStrategy) { + *out = *in + if in.MinAttempts != nil { + in, out := &in.MinAttempts, &out.MinAttempts + *out = new(int) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RetryStrategy. +func (in *RetryStrategy) DeepCopy() *RetryStrategy { + if in == nil { + return nil + } + out := new(RetryStrategy) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SubWorkflowNodeStatus) DeepCopyInto(out *SubWorkflowNodeStatus) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SubWorkflowNodeStatus. +func (in *SubWorkflowNodeStatus) DeepCopy() *SubWorkflowNodeStatus { + if in == nil { + return nil + } + out := new(SubWorkflowNodeStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TaskExecutionIdentifier. 
+func (in *TaskExecutionIdentifier) DeepCopy() *TaskExecutionIdentifier { + if in == nil { + return nil + } + out := new(TaskExecutionIdentifier) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TaskSpec. +func (in *TaskSpec) DeepCopy() *TaskSpec { + if in == nil { + return nil + } + out := new(TaskSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkflowExecutionIdentifier. +func (in *WorkflowExecutionIdentifier) DeepCopy() *WorkflowExecutionIdentifier { + if in == nil { + return nil + } + out := new(WorkflowExecutionIdentifier) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *WorkflowNodeSpec) DeepCopyInto(out *WorkflowNodeSpec) { + *out = *in + if in.LaunchPlanRefID != nil { + in, out := &in.LaunchPlanRefID, &out.LaunchPlanRefID + *out = (*in).DeepCopy() + } + if in.SubWorkflowReference != nil { + in, out := &in.SubWorkflowReference, &out.SubWorkflowReference + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkflowNodeSpec. +func (in *WorkflowNodeSpec) DeepCopy() *WorkflowNodeSpec { + if in == nil { + return nil + } + out := new(WorkflowNodeSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *WorkflowNodeStatus) DeepCopyInto(out *WorkflowNodeStatus) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkflowNodeStatus. 
+func (in *WorkflowNodeStatus) DeepCopy() *WorkflowNodeStatus { + if in == nil { + return nil + } + out := new(WorkflowNodeStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *WorkflowSpec) DeepCopyInto(out *WorkflowSpec) { + *out = *in + if in.Nodes != nil { + in, out := &in.Nodes, &out.Nodes + *out = make(map[string]*NodeSpec, len(*in)) + for key, val := range *in { + var outVal *NodeSpec + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = new(NodeSpec) + (*in).DeepCopyInto(*out) + } + (*out)[key] = outVal + } + } + in.Connections.DeepCopyInto(&out.Connections) + if in.OnFailure != nil { + in, out := &in.OnFailure, &out.OnFailure + *out = new(NodeSpec) + (*in).DeepCopyInto(*out) + } + if in.Outputs != nil { + in, out := &in.Outputs, &out.Outputs + *out = (*in).DeepCopy() + } + if in.OutputBindings != nil { + in, out := &in.OutputBindings, &out.OutputBindings + *out = make([]*Binding, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = (*in).DeepCopy() + } + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkflowSpec. +func (in *WorkflowSpec) DeepCopy() *WorkflowSpec { + if in == nil { + return nil + } + out := new(WorkflowSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *WorkflowStatus) DeepCopyInto(out *WorkflowStatus) { + *out = *in + if in.StartedAt != nil { + in, out := &in.StartedAt, &out.StartedAt + *out = (*in).DeepCopy() + } + if in.StoppedAt != nil { + in, out := &in.StoppedAt, &out.StoppedAt + *out = (*in).DeepCopy() + } + if in.LastUpdatedAt != nil { + in, out := &in.LastUpdatedAt, &out.LastUpdatedAt + *out = (*in).DeepCopy() + } + if in.NodeStatus != nil { + in, out := &in.NodeStatus, &out.NodeStatus + *out = make(map[string]*NodeStatus, len(*in)) + for key, val := range *in { + var outVal *NodeStatus + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = new(NodeStatus) + (*in).DeepCopyInto(*out) + } + (*out)[key] = outVal + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkflowStatus. +func (in *WorkflowStatus) DeepCopy() *WorkflowStatus { + if in == nil { + return nil + } + out := new(WorkflowStatus) + in.DeepCopyInto(out) + return out +} diff --git a/pkg/client/clientset/versioned/clientset.go b/pkg/client/clientset/versioned/clientset.go new file mode 100644 index 000000000..93756ede0 --- /dev/null +++ b/pkg/client/clientset/versioned/clientset.go @@ -0,0 +1,82 @@ +// Code generated by client-gen. DO NOT EDIT. + +package versioned + +import ( + flyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + discovery "k8s.io/client-go/discovery" + rest "k8s.io/client-go/rest" + flowcontrol "k8s.io/client-go/util/flowcontrol" +) + +type Interface interface { + Discovery() discovery.DiscoveryInterface + FlyteworkflowV1alpha1() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface + // Deprecated: please explicitly pick a version if possible. + Flyteworkflow() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface +} + +// Clientset contains the clients for groups. Each group has exactly one +// version included in a Clientset. 
+type Clientset struct { + *discovery.DiscoveryClient + flyteworkflowV1alpha1 *flyteworkflowv1alpha1.FlyteworkflowV1alpha1Client +} + +// FlyteworkflowV1alpha1 retrieves the FlyteworkflowV1alpha1Client +func (c *Clientset) FlyteworkflowV1alpha1() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface { + return c.flyteworkflowV1alpha1 +} + +// Deprecated: Flyteworkflow retrieves the default version of FlyteworkflowClient. +// Please explicitly pick a version. +func (c *Clientset) Flyteworkflow() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface { + return c.flyteworkflowV1alpha1 +} + +// Discovery retrieves the DiscoveryClient +func (c *Clientset) Discovery() discovery.DiscoveryInterface { + if c == nil { + return nil + } + return c.DiscoveryClient +} + +// NewForConfig creates a new Clientset for the given config. +func NewForConfig(c *rest.Config) (*Clientset, error) { + configShallowCopy := *c + if configShallowCopy.RateLimiter == nil && configShallowCopy.QPS > 0 { + configShallowCopy.RateLimiter = flowcontrol.NewTokenBucketRateLimiter(configShallowCopy.QPS, configShallowCopy.Burst) + } + var cs Clientset + var err error + cs.flyteworkflowV1alpha1, err = flyteworkflowv1alpha1.NewForConfig(&configShallowCopy) + if err != nil { + return nil, err + } + + cs.DiscoveryClient, err = discovery.NewDiscoveryClientForConfig(&configShallowCopy) + if err != nil { + return nil, err + } + return &cs, nil +} + +// NewForConfigOrDie creates a new Clientset for the given config and +// panics if there is an error in the config. +func NewForConfigOrDie(c *rest.Config) *Clientset { + var cs Clientset + cs.flyteworkflowV1alpha1 = flyteworkflowv1alpha1.NewForConfigOrDie(c) + + cs.DiscoveryClient = discovery.NewDiscoveryClientForConfigOrDie(c) + return &cs +} + +// New creates a new Clientset for the given RESTClient. 
+func New(c rest.Interface) *Clientset { + var cs Clientset + cs.flyteworkflowV1alpha1 = flyteworkflowv1alpha1.New(c) + + cs.DiscoveryClient = discovery.NewDiscoveryClient(c) + return &cs +} diff --git a/pkg/client/clientset/versioned/doc.go b/pkg/client/clientset/versioned/doc.go new file mode 100644 index 000000000..0e0c2a890 --- /dev/null +++ b/pkg/client/clientset/versioned/doc.go @@ -0,0 +1,4 @@ +// Code generated by client-gen. DO NOT EDIT. + +// This package has the automatically generated clientset. +package versioned diff --git a/pkg/client/clientset/versioned/fake/clientset_generated.go b/pkg/client/clientset/versioned/fake/clientset_generated.go new file mode 100644 index 000000000..65395e5db --- /dev/null +++ b/pkg/client/clientset/versioned/fake/clientset_generated.go @@ -0,0 +1,66 @@ +// Code generated by client-gen. DO NOT EDIT. + +package fake + +import ( + clientset "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + flyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + fakeflyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/discovery" + fakediscovery "k8s.io/client-go/discovery/fake" + "k8s.io/client-go/testing" +) + +// NewSimpleClientset returns a clientset that will respond with the provided objects. +// It's backed by a very simple object tracker that processes creates, updates and deletions as-is, +// without applying any validations and/or defaults. It shouldn't be considered a replacement +// for a real clientset and is mostly useful in simple unit tests. 
+func NewSimpleClientset(objects ...runtime.Object) *Clientset { + o := testing.NewObjectTracker(scheme, codecs.UniversalDecoder()) + for _, obj := range objects { + if err := o.Add(obj); err != nil { + panic(err) + } + } + + cs := &Clientset{} + cs.discovery = &fakediscovery.FakeDiscovery{Fake: &cs.Fake} + cs.AddReactor("*", "*", testing.ObjectReaction(o)) + cs.AddWatchReactor("*", func(action testing.Action) (handled bool, ret watch.Interface, err error) { + gvr := action.GetResource() + ns := action.GetNamespace() + watch, err := o.Watch(gvr, ns) + if err != nil { + return false, nil, err + } + return true, watch, nil + }) + + return cs +} + +// Clientset implements clientset.Interface. Meant to be embedded into a +// struct to get a default implementation. This makes faking out just the method +// you want to test easier. +type Clientset struct { + testing.Fake + discovery *fakediscovery.FakeDiscovery +} + +func (c *Clientset) Discovery() discovery.DiscoveryInterface { + return c.discovery +} + +var _ clientset.Interface = &Clientset{} + +// FlyteworkflowV1alpha1 retrieves the FlyteworkflowV1alpha1Client +func (c *Clientset) FlyteworkflowV1alpha1() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface { + return &fakeflyteworkflowv1alpha1.FakeFlyteworkflowV1alpha1{Fake: &c.Fake} +} + +// Flyteworkflow retrieves the FlyteworkflowV1alpha1Client +func (c *Clientset) Flyteworkflow() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface { + return &fakeflyteworkflowv1alpha1.FakeFlyteworkflowV1alpha1{Fake: &c.Fake} +} diff --git a/pkg/client/clientset/versioned/fake/doc.go b/pkg/client/clientset/versioned/fake/doc.go new file mode 100644 index 000000000..3630ed1cd --- /dev/null +++ b/pkg/client/clientset/versioned/fake/doc.go @@ -0,0 +1,4 @@ +// Code generated by client-gen. DO NOT EDIT. + +// This package has the automatically generated fake clientset. 
+package fake diff --git a/pkg/client/clientset/versioned/fake/register.go b/pkg/client/clientset/versioned/fake/register.go new file mode 100644 index 000000000..23a9c3a39 --- /dev/null +++ b/pkg/client/clientset/versioned/fake/register.go @@ -0,0 +1,40 @@ +// Code generated by client-gen. DO NOT EDIT. + +package fake + +import ( + flyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + runtime "k8s.io/apimachinery/pkg/runtime" + schema "k8s.io/apimachinery/pkg/runtime/schema" + serializer "k8s.io/apimachinery/pkg/runtime/serializer" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" +) + +var scheme = runtime.NewScheme() +var codecs = serializer.NewCodecFactory(scheme) +var parameterCodec = runtime.NewParameterCodec(scheme) +var localSchemeBuilder = runtime.SchemeBuilder{ + flyteworkflowv1alpha1.AddToScheme, +} + +// AddToScheme adds all types of this clientset into the given scheme. This allows composition +// of clientsets, like in: +// +// import ( +// "k8s.io/client-go/kubernetes" +// clientsetscheme "k8s.io/client-go/kubernetes/scheme" +// aggregatorclientsetscheme "k8s.io/kube-aggregator/pkg/client/clientset_generated/clientset/scheme" +// ) +// +// kclientset, _ := kubernetes.NewForConfig(c) +// _ = aggregatorclientsetscheme.AddToScheme(clientsetscheme.Scheme) +// +// After this, RawExtensions in Kubernetes types will serialize kube-aggregator types +// correctly. +var AddToScheme = localSchemeBuilder.AddToScheme + +func init() { + v1.AddToGroupVersion(scheme, schema.GroupVersion{Version: "v1"}) + utilruntime.Must(AddToScheme(scheme)) +} diff --git a/pkg/client/clientset/versioned/scheme/doc.go b/pkg/client/clientset/versioned/scheme/doc.go new file mode 100644 index 000000000..14db57a58 --- /dev/null +++ b/pkg/client/clientset/versioned/scheme/doc.go @@ -0,0 +1,4 @@ +// Code generated by client-gen. DO NOT EDIT. 
+ +// This package contains the scheme of the automatically generated clientset. +package scheme diff --git a/pkg/client/clientset/versioned/scheme/register.go b/pkg/client/clientset/versioned/scheme/register.go new file mode 100644 index 000000000..6323cb326 --- /dev/null +++ b/pkg/client/clientset/versioned/scheme/register.go @@ -0,0 +1,40 @@ +// Code generated by client-gen. DO NOT EDIT. + +package scheme + +import ( + flyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + runtime "k8s.io/apimachinery/pkg/runtime" + schema "k8s.io/apimachinery/pkg/runtime/schema" + serializer "k8s.io/apimachinery/pkg/runtime/serializer" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" +) + +var Scheme = runtime.NewScheme() +var Codecs = serializer.NewCodecFactory(Scheme) +var ParameterCodec = runtime.NewParameterCodec(Scheme) +var localSchemeBuilder = runtime.SchemeBuilder{ + flyteworkflowv1alpha1.AddToScheme, +} + +// AddToScheme adds all types of this clientset into the given scheme. This allows composition +// of clientsets, like in: +// +// import ( +// "k8s.io/client-go/kubernetes" +// clientsetscheme "k8s.io/client-go/kubernetes/scheme" +// aggregatorclientsetscheme "k8s.io/kube-aggregator/pkg/client/clientset_generated/clientset/scheme" +// ) +// +// kclientset, _ := kubernetes.NewForConfig(c) +// _ = aggregatorclientsetscheme.AddToScheme(clientsetscheme.Scheme) +// +// After this, RawExtensions in Kubernetes types will serialize kube-aggregator types +// correctly. 
+var AddToScheme = localSchemeBuilder.AddToScheme + +func init() { + v1.AddToGroupVersion(Scheme, schema.GroupVersion{Version: "v1"}) + utilruntime.Must(AddToScheme(Scheme)) +} diff --git a/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/doc.go b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/doc.go new file mode 100644 index 000000000..93a7ca4e0 --- /dev/null +++ b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/doc.go @@ -0,0 +1,4 @@ +// Code generated by client-gen. DO NOT EDIT. + +// This package has the automatically generated typed clients. +package v1alpha1 diff --git a/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/doc.go b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/doc.go new file mode 100644 index 000000000..2b5ba4c8e --- /dev/null +++ b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/doc.go @@ -0,0 +1,4 @@ +// Code generated by client-gen. DO NOT EDIT. + +// Package fake has the automatically generated clients. +package fake diff --git a/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow.go b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow.go new file mode 100644 index 000000000..c9f48bec4 --- /dev/null +++ b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow.go @@ -0,0 +1,124 @@ +// Code generated by client-gen. DO NOT EDIT. 
+ +package fake + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + labels "k8s.io/apimachinery/pkg/labels" + schema "k8s.io/apimachinery/pkg/runtime/schema" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" + testing "k8s.io/client-go/testing" +) + +// FakeFlyteWorkflows implements FlyteWorkflowInterface +type FakeFlyteWorkflows struct { + Fake *FakeFlyteworkflowV1alpha1 + ns string +} + +var flyteworkflowsResource = schema.GroupVersionResource{Group: "flyteworkflow.flyte.net", Version: "v1alpha1", Resource: "flyteworkflows"} + +var flyteworkflowsKind = schema.GroupVersionKind{Group: "flyteworkflow.flyte.net", Version: "v1alpha1", Kind: "FlyteWorkflow"} + +// Get takes name of the flyteWorkflow, and returns the corresponding flyteWorkflow object, and an error if there is any. +func (c *FakeFlyteWorkflows) Get(name string, options v1.GetOptions) (result *v1alpha1.FlyteWorkflow, err error) { + obj, err := c.Fake. + Invokes(testing.NewGetAction(flyteworkflowsResource, c.ns, name), &v1alpha1.FlyteWorkflow{}) + + if obj == nil { + return nil, err + } + return obj.(*v1alpha1.FlyteWorkflow), err +} + +// List takes label and field selectors, and returns the list of FlyteWorkflows that match those selectors. +func (c *FakeFlyteWorkflows) List(opts v1.ListOptions) (result *v1alpha1.FlyteWorkflowList, err error) { + obj, err := c.Fake. 
+ Invokes(testing.NewListAction(flyteworkflowsResource, flyteworkflowsKind, c.ns, opts), &v1alpha1.FlyteWorkflowList{}) + + if obj == nil { + return nil, err + } + + label, _, _ := testing.ExtractFromListOptions(opts) + if label == nil { + label = labels.Everything() + } + list := &v1alpha1.FlyteWorkflowList{ListMeta: obj.(*v1alpha1.FlyteWorkflowList).ListMeta} + for _, item := range obj.(*v1alpha1.FlyteWorkflowList).Items { + if label.Matches(labels.Set(item.Labels)) { + list.Items = append(list.Items, item) + } + } + return list, err +} + +// Watch returns a watch.Interface that watches the requested flyteWorkflows. +func (c *FakeFlyteWorkflows) Watch(opts v1.ListOptions) (watch.Interface, error) { + return c.Fake. + InvokesWatch(testing.NewWatchAction(flyteworkflowsResource, c.ns, opts)) + +} + +// Create takes the representation of a flyteWorkflow and creates it. Returns the server's representation of the flyteWorkflow, and an error, if there is any. +func (c *FakeFlyteWorkflows) Create(flyteWorkflow *v1alpha1.FlyteWorkflow) (result *v1alpha1.FlyteWorkflow, err error) { + obj, err := c.Fake. + Invokes(testing.NewCreateAction(flyteworkflowsResource, c.ns, flyteWorkflow), &v1alpha1.FlyteWorkflow{}) + + if obj == nil { + return nil, err + } + return obj.(*v1alpha1.FlyteWorkflow), err +} + +// Update takes the representation of a flyteWorkflow and updates it. Returns the server's representation of the flyteWorkflow, and an error, if there is any. +func (c *FakeFlyteWorkflows) Update(flyteWorkflow *v1alpha1.FlyteWorkflow) (result *v1alpha1.FlyteWorkflow, err error) { + obj, err := c.Fake. + Invokes(testing.NewUpdateAction(flyteworkflowsResource, c.ns, flyteWorkflow), &v1alpha1.FlyteWorkflow{}) + + if obj == nil { + return nil, err + } + return obj.(*v1alpha1.FlyteWorkflow), err +} + +// UpdateStatus was generated because the type contains a Status member. +// Add a +genclient:noStatus comment above the type to avoid generating UpdateStatus(). 
+func (c *FakeFlyteWorkflows) UpdateStatus(flyteWorkflow *v1alpha1.FlyteWorkflow) (*v1alpha1.FlyteWorkflow, error) { + obj, err := c.Fake. + Invokes(testing.NewUpdateSubresourceAction(flyteworkflowsResource, "status", c.ns, flyteWorkflow), &v1alpha1.FlyteWorkflow{}) + + if obj == nil { + return nil, err + } + return obj.(*v1alpha1.FlyteWorkflow), err +} + +// Delete takes name of the flyteWorkflow and deletes it. Returns an error if one occurs. +func (c *FakeFlyteWorkflows) Delete(name string, options *v1.DeleteOptions) error { + _, err := c.Fake. + Invokes(testing.NewDeleteAction(flyteworkflowsResource, c.ns, name), &v1alpha1.FlyteWorkflow{}) + + return err +} + +// DeleteCollection deletes a collection of objects. +func (c *FakeFlyteWorkflows) DeleteCollection(options *v1.DeleteOptions, listOptions v1.ListOptions) error { + action := testing.NewDeleteCollectionAction(flyteworkflowsResource, c.ns, listOptions) + + _, err := c.Fake.Invokes(action, &v1alpha1.FlyteWorkflowList{}) + return err +} + +// Patch applies the patch and returns the patched flyteWorkflow. +func (c *FakeFlyteWorkflows) Patch(name string, pt types.PatchType, data []byte, subresources ...string) (result *v1alpha1.FlyteWorkflow, err error) { + obj, err := c.Fake. + Invokes(testing.NewPatchSubresourceAction(flyteworkflowsResource, c.ns, name, pt, data, subresources...), &v1alpha1.FlyteWorkflow{}) + + if obj == nil { + return nil, err + } + return obj.(*v1alpha1.FlyteWorkflow), err +} diff --git a/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow_client.go b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow_client.go new file mode 100644 index 000000000..11460605c --- /dev/null +++ b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow_client.go @@ -0,0 +1,24 @@ +// Code generated by client-gen. DO NOT EDIT. 
+ +package fake + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + rest "k8s.io/client-go/rest" + testing "k8s.io/client-go/testing" +) + +type FakeFlyteworkflowV1alpha1 struct { + *testing.Fake +} + +func (c *FakeFlyteworkflowV1alpha1) FlyteWorkflows(namespace string) v1alpha1.FlyteWorkflowInterface { + return &FakeFlyteWorkflows{c, namespace} +} + +// RESTClient returns a RESTClient that is used to communicate +// with API server by this client implementation. +func (c *FakeFlyteworkflowV1alpha1) RESTClient() rest.Interface { + var ret *rest.RESTClient + return ret +} diff --git a/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow.go b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow.go new file mode 100644 index 000000000..6b2dc62c1 --- /dev/null +++ b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow.go @@ -0,0 +1,175 @@ +// Code generated by client-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + "time" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + scheme "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/scheme" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" + rest "k8s.io/client-go/rest" +) + +// FlyteWorkflowsGetter has a method to return a FlyteWorkflowInterface. +// A group's client should implement this interface. +type FlyteWorkflowsGetter interface { + FlyteWorkflows(namespace string) FlyteWorkflowInterface +} + +// FlyteWorkflowInterface has methods to work with FlyteWorkflow resources. 
+type FlyteWorkflowInterface interface { + Create(*v1alpha1.FlyteWorkflow) (*v1alpha1.FlyteWorkflow, error) + Update(*v1alpha1.FlyteWorkflow) (*v1alpha1.FlyteWorkflow, error) + UpdateStatus(*v1alpha1.FlyteWorkflow) (*v1alpha1.FlyteWorkflow, error) + Delete(name string, options *v1.DeleteOptions) error + DeleteCollection(options *v1.DeleteOptions, listOptions v1.ListOptions) error + Get(name string, options v1.GetOptions) (*v1alpha1.FlyteWorkflow, error) + List(opts v1.ListOptions) (*v1alpha1.FlyteWorkflowList, error) + Watch(opts v1.ListOptions) (watch.Interface, error) + Patch(name string, pt types.PatchType, data []byte, subresources ...string) (result *v1alpha1.FlyteWorkflow, err error) + FlyteWorkflowExpansion +} + +// flyteWorkflows implements FlyteWorkflowInterface +type flyteWorkflows struct { + client rest.Interface + ns string +} + +// newFlyteWorkflows returns a FlyteWorkflows +func newFlyteWorkflows(c *FlyteworkflowV1alpha1Client, namespace string) *flyteWorkflows { + return &flyteWorkflows{ + client: c.RESTClient(), + ns: namespace, + } +} + +// Get takes name of the flyteWorkflow, and returns the corresponding flyteWorkflow object, and an error if there is any. +func (c *flyteWorkflows) Get(name string, options v1.GetOptions) (result *v1alpha1.FlyteWorkflow, err error) { + result = &v1alpha1.FlyteWorkflow{} + err = c.client.Get(). + Namespace(c.ns). + Resource("flyteworkflows"). + Name(name). + VersionedParams(&options, scheme.ParameterCodec). + Do(). + Into(result) + return +} + +// List takes label and field selectors, and returns the list of FlyteWorkflows that match those selectors. +func (c *flyteWorkflows) List(opts v1.ListOptions) (result *v1alpha1.FlyteWorkflowList, err error) { + var timeout time.Duration + if opts.TimeoutSeconds != nil { + timeout = time.Duration(*opts.TimeoutSeconds) * time.Second + } + result = &v1alpha1.FlyteWorkflowList{} + err = c.client.Get(). + Namespace(c.ns). + Resource("flyteworkflows"). 
+ VersionedParams(&opts, scheme.ParameterCodec). + Timeout(timeout). + Do(). + Into(result) + return +} + +// Watch returns a watch.Interface that watches the requested flyteWorkflows. +func (c *flyteWorkflows) Watch(opts v1.ListOptions) (watch.Interface, error) { + var timeout time.Duration + if opts.TimeoutSeconds != nil { + timeout = time.Duration(*opts.TimeoutSeconds) * time.Second + } + opts.Watch = true + return c.client.Get(). + Namespace(c.ns). + Resource("flyteworkflows"). + VersionedParams(&opts, scheme.ParameterCodec). + Timeout(timeout). + Watch() +} + +// Create takes the representation of a flyteWorkflow and creates it. Returns the server's representation of the flyteWorkflow, and an error, if there is any. +func (c *flyteWorkflows) Create(flyteWorkflow *v1alpha1.FlyteWorkflow) (result *v1alpha1.FlyteWorkflow, err error) { + result = &v1alpha1.FlyteWorkflow{} + err = c.client.Post(). + Namespace(c.ns). + Resource("flyteworkflows"). + Body(flyteWorkflow). + Do(). + Into(result) + return +} + +// Update takes the representation of a flyteWorkflow and updates it. Returns the server's representation of the flyteWorkflow, and an error, if there is any. +func (c *flyteWorkflows) Update(flyteWorkflow *v1alpha1.FlyteWorkflow) (result *v1alpha1.FlyteWorkflow, err error) { + result = &v1alpha1.FlyteWorkflow{} + err = c.client.Put(). + Namespace(c.ns). + Resource("flyteworkflows"). + Name(flyteWorkflow.Name). + Body(flyteWorkflow). + Do(). + Into(result) + return +} + +// UpdateStatus was generated because the type contains a Status member. +// Add a +genclient:noStatus comment above the type to avoid generating UpdateStatus(). + +func (c *flyteWorkflows) UpdateStatus(flyteWorkflow *v1alpha1.FlyteWorkflow) (result *v1alpha1.FlyteWorkflow, err error) { + result = &v1alpha1.FlyteWorkflow{} + err = c.client.Put(). + Namespace(c.ns). + Resource("flyteworkflows"). + Name(flyteWorkflow.Name). + SubResource("status"). + Body(flyteWorkflow). + Do(). 
+ Into(result) + return +} + +// Delete takes name of the flyteWorkflow and deletes it. Returns an error if one occurs. +func (c *flyteWorkflows) Delete(name string, options *v1.DeleteOptions) error { + return c.client.Delete(). + Namespace(c.ns). + Resource("flyteworkflows"). + Name(name). + Body(options). + Do(). + Error() +} + +// DeleteCollection deletes a collection of objects. +func (c *flyteWorkflows) DeleteCollection(options *v1.DeleteOptions, listOptions v1.ListOptions) error { + var timeout time.Duration + if listOptions.TimeoutSeconds != nil { + timeout = time.Duration(*listOptions.TimeoutSeconds) * time.Second + } + return c.client.Delete(). + Namespace(c.ns). + Resource("flyteworkflows"). + VersionedParams(&listOptions, scheme.ParameterCodec). + Timeout(timeout). + Body(options). + Do(). + Error() +} + +// Patch applies the patch and returns the patched flyteWorkflow. +func (c *flyteWorkflows) Patch(name string, pt types.PatchType, data []byte, subresources ...string) (result *v1alpha1.FlyteWorkflow, err error) { + result = &v1alpha1.FlyteWorkflow{} + err = c.client.Patch(pt). + Namespace(c.ns). + Resource("flyteworkflows"). + SubResource(subresources...). + Name(name). + Body(data). + Do(). + Into(result) + return +} diff --git a/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow_client.go b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow_client.go new file mode 100644 index 000000000..2d7414f23 --- /dev/null +++ b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow_client.go @@ -0,0 +1,74 @@ +// Code generated by client-gen. DO NOT EDIT. 
+ +package v1alpha1 + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/scheme" + serializer "k8s.io/apimachinery/pkg/runtime/serializer" + rest "k8s.io/client-go/rest" +) + +type FlyteworkflowV1alpha1Interface interface { + RESTClient() rest.Interface + FlyteWorkflowsGetter +} + +// FlyteworkflowV1alpha1Client is used to interact with features provided by the flyteworkflow.flyte.net group. +type FlyteworkflowV1alpha1Client struct { + restClient rest.Interface +} + +func (c *FlyteworkflowV1alpha1Client) FlyteWorkflows(namespace string) FlyteWorkflowInterface { + return newFlyteWorkflows(c, namespace) +} + +// NewForConfig creates a new FlyteworkflowV1alpha1Client for the given config. +func NewForConfig(c *rest.Config) (*FlyteworkflowV1alpha1Client, error) { + config := *c + if err := setConfigDefaults(&config); err != nil { + return nil, err + } + client, err := rest.RESTClientFor(&config) + if err != nil { + return nil, err + } + return &FlyteworkflowV1alpha1Client{client}, nil +} + +// NewForConfigOrDie creates a new FlyteworkflowV1alpha1Client for the given config and +// panics if there is an error in the config. +func NewForConfigOrDie(c *rest.Config) *FlyteworkflowV1alpha1Client { + client, err := NewForConfig(c) + if err != nil { + panic(err) + } + return client +} + +// New creates a new FlyteworkflowV1alpha1Client for the given RESTClient. 
+func New(c rest.Interface) *FlyteworkflowV1alpha1Client { + return &FlyteworkflowV1alpha1Client{c} +} + +func setConfigDefaults(config *rest.Config) error { + gv := v1alpha1.SchemeGroupVersion + config.GroupVersion = &gv + config.APIPath = "/apis" + config.NegotiatedSerializer = serializer.DirectCodecFactory{CodecFactory: scheme.Codecs} + + if config.UserAgent == "" { + config.UserAgent = rest.DefaultKubernetesUserAgent() + } + + return nil +} + +// RESTClient returns a RESTClient that is used to communicate +// with API server by this client implementation. +func (c *FlyteworkflowV1alpha1Client) RESTClient() rest.Interface { + if c == nil { + return nil + } + return c.restClient +} diff --git a/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/generated_expansion.go b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/generated_expansion.go new file mode 100644 index 000000000..eb8294c16 --- /dev/null +++ b/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/generated_expansion.go @@ -0,0 +1,5 @@ +// Code generated by client-gen. DO NOT EDIT. + +package v1alpha1 + +type FlyteWorkflowExpansion interface{} diff --git a/pkg/client/informers/externalversions/factory.go b/pkg/client/informers/externalversions/factory.go new file mode 100644 index 000000000..2a094285f --- /dev/null +++ b/pkg/client/informers/externalversions/factory.go @@ -0,0 +1,164 @@ +// Code generated by informer-gen. DO NOT EDIT. 
+ +package externalversions + +import ( + reflect "reflect" + sync "sync" + time "time" + + versioned "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + flyteworkflow "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/flyteworkflow" + internalinterfaces "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/internalinterfaces" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + runtime "k8s.io/apimachinery/pkg/runtime" + schema "k8s.io/apimachinery/pkg/runtime/schema" + cache "k8s.io/client-go/tools/cache" +) + +// SharedInformerOption defines the functional option type for SharedInformerFactory. +type SharedInformerOption func(*sharedInformerFactory) *sharedInformerFactory + +type sharedInformerFactory struct { + client versioned.Interface + namespace string + tweakListOptions internalinterfaces.TweakListOptionsFunc + lock sync.Mutex + defaultResync time.Duration + customResync map[reflect.Type]time.Duration + + informers map[reflect.Type]cache.SharedIndexInformer + // startedInformers is used for tracking which informers have been started. + // This allows Start() to be called multiple times safely. + startedInformers map[reflect.Type]bool +} + +// WithCustomResyncConfig sets a custom resync period for the specified informer types. +func WithCustomResyncConfig(resyncConfig map[v1.Object]time.Duration) SharedInformerOption { + return func(factory *sharedInformerFactory) *sharedInformerFactory { + for k, v := range resyncConfig { + factory.customResync[reflect.TypeOf(k)] = v + } + return factory + } +} + +// WithTweakListOptions sets a custom filter on all listers of the configured SharedInformerFactory. 
+func WithTweakListOptions(tweakListOptions internalinterfaces.TweakListOptionsFunc) SharedInformerOption { + return func(factory *sharedInformerFactory) *sharedInformerFactory { + factory.tweakListOptions = tweakListOptions + return factory + } +} + +// WithNamespace limits the SharedInformerFactory to the specified namespace. +func WithNamespace(namespace string) SharedInformerOption { + return func(factory *sharedInformerFactory) *sharedInformerFactory { + factory.namespace = namespace + return factory + } +} + +// NewSharedInformerFactory constructs a new instance of sharedInformerFactory for all namespaces. +func NewSharedInformerFactory(client versioned.Interface, defaultResync time.Duration) SharedInformerFactory { + return NewSharedInformerFactoryWithOptions(client, defaultResync) +} + +// NewFilteredSharedInformerFactory constructs a new instance of sharedInformerFactory. +// Listers obtained via this SharedInformerFactory will be subject to the same filters +// as specified here. +// Deprecated: Please use NewSharedInformerFactoryWithOptions instead +func NewFilteredSharedInformerFactory(client versioned.Interface, defaultResync time.Duration, namespace string, tweakListOptions internalinterfaces.TweakListOptionsFunc) SharedInformerFactory { + return NewSharedInformerFactoryWithOptions(client, defaultResync, WithNamespace(namespace), WithTweakListOptions(tweakListOptions)) +} + +// NewSharedInformerFactoryWithOptions constructs a new instance of a SharedInformerFactory with additional options. 
+func NewSharedInformerFactoryWithOptions(client versioned.Interface, defaultResync time.Duration, options ...SharedInformerOption) SharedInformerFactory { + factory := &sharedInformerFactory{ + client: client, + namespace: v1.NamespaceAll, + defaultResync: defaultResync, + informers: make(map[reflect.Type]cache.SharedIndexInformer), + startedInformers: make(map[reflect.Type]bool), + customResync: make(map[reflect.Type]time.Duration), + } + + // Apply all options + for _, opt := range options { + factory = opt(factory) + } + + return factory +} + +// Start initializes all requested informers. +func (f *sharedInformerFactory) Start(stopCh <-chan struct{}) { + f.lock.Lock() + defer f.lock.Unlock() + + for informerType, informer := range f.informers { + if !f.startedInformers[informerType] { + go informer.Run(stopCh) + f.startedInformers[informerType] = true + } + } +} + +// WaitForCacheSync waits for all started informers' cache were synced. +func (f *sharedInformerFactory) WaitForCacheSync(stopCh <-chan struct{}) map[reflect.Type]bool { + informers := func() map[reflect.Type]cache.SharedIndexInformer { + f.lock.Lock() + defer f.lock.Unlock() + + informers := map[reflect.Type]cache.SharedIndexInformer{} + for informerType, informer := range f.informers { + if f.startedInformers[informerType] { + informers[informerType] = informer + } + } + return informers + }() + + res := map[reflect.Type]bool{} + for informType, informer := range informers { + res[informType] = cache.WaitForCacheSync(stopCh, informer.HasSynced) + } + return res +} + +// InternalInformerFor returns the SharedIndexInformer for obj using an internal +// client. 
+func (f *sharedInformerFactory) InformerFor(obj runtime.Object, newFunc internalinterfaces.NewInformerFunc) cache.SharedIndexInformer { + f.lock.Lock() + defer f.lock.Unlock() + + informerType := reflect.TypeOf(obj) + informer, exists := f.informers[informerType] + if exists { + return informer + } + + resyncPeriod, exists := f.customResync[informerType] + if !exists { + resyncPeriod = f.defaultResync + } + + informer = newFunc(f.client, resyncPeriod) + f.informers[informerType] = informer + + return informer +} + +// SharedInformerFactory provides shared informers for resources in all known +// API group versions. +type SharedInformerFactory interface { + internalinterfaces.SharedInformerFactory + ForResource(resource schema.GroupVersionResource) (GenericInformer, error) + WaitForCacheSync(stopCh <-chan struct{}) map[reflect.Type]bool + + Flyteworkflow() flyteworkflow.Interface +} + +func (f *sharedInformerFactory) Flyteworkflow() flyteworkflow.Interface { + return flyteworkflow.New(f, f.namespace, f.tweakListOptions) +} diff --git a/pkg/client/informers/externalversions/flyteworkflow/interface.go b/pkg/client/informers/externalversions/flyteworkflow/interface.go new file mode 100644 index 000000000..b8410c168 --- /dev/null +++ b/pkg/client/informers/externalversions/flyteworkflow/interface.go @@ -0,0 +1,30 @@ +// Code generated by informer-gen. DO NOT EDIT. + +package flyteworkflow + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/v1alpha1" + internalinterfaces "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/internalinterfaces" +) + +// Interface provides access to each of this group's versions. +type Interface interface { + // V1alpha1 provides access to shared informers for resources in V1alpha1. 
+ V1alpha1() v1alpha1.Interface +} + +type group struct { + factory internalinterfaces.SharedInformerFactory + namespace string + tweakListOptions internalinterfaces.TweakListOptionsFunc +} + +// New returns a new Interface. +func New(f internalinterfaces.SharedInformerFactory, namespace string, tweakListOptions internalinterfaces.TweakListOptionsFunc) Interface { + return &group{factory: f, namespace: namespace, tweakListOptions: tweakListOptions} +} + +// V1alpha1 returns a new v1alpha1.Interface. +func (g *group) V1alpha1() v1alpha1.Interface { + return v1alpha1.New(g.factory, g.namespace, g.tweakListOptions) +} diff --git a/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/flyteworkflow.go b/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/flyteworkflow.go new file mode 100644 index 000000000..3ea918b5d --- /dev/null +++ b/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/flyteworkflow.go @@ -0,0 +1,73 @@ +// Code generated by informer-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + time "time" + + flyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + versioned "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + internalinterfaces "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/internalinterfaces" + v1alpha1 "github.com/lyft/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + runtime "k8s.io/apimachinery/pkg/runtime" + watch "k8s.io/apimachinery/pkg/watch" + cache "k8s.io/client-go/tools/cache" +) + +// FlyteWorkflowInformer provides access to a shared informer and lister for +// FlyteWorkflows. 
+type FlyteWorkflowInformer interface { + Informer() cache.SharedIndexInformer + Lister() v1alpha1.FlyteWorkflowLister +} + +type flyteWorkflowInformer struct { + factory internalinterfaces.SharedInformerFactory + tweakListOptions internalinterfaces.TweakListOptionsFunc + namespace string +} + +// NewFlyteWorkflowInformer constructs a new informer for FlyteWorkflow type. +// Always prefer using an informer factory to get a shared informer instead of getting an independent +// one. This reduces memory footprint and number of connections to the server. +func NewFlyteWorkflowInformer(client versioned.Interface, namespace string, resyncPeriod time.Duration, indexers cache.Indexers) cache.SharedIndexInformer { + return NewFilteredFlyteWorkflowInformer(client, namespace, resyncPeriod, indexers, nil) +} + +// NewFilteredFlyteWorkflowInformer constructs a new informer for FlyteWorkflow type. +// Always prefer using an informer factory to get a shared informer instead of getting an independent +// one. This reduces memory footprint and number of connections to the server. 
+func NewFilteredFlyteWorkflowInformer(client versioned.Interface, namespace string, resyncPeriod time.Duration, indexers cache.Indexers, tweakListOptions internalinterfaces.TweakListOptionsFunc) cache.SharedIndexInformer { + return cache.NewSharedIndexInformer( + &cache.ListWatch{ + ListFunc: func(options v1.ListOptions) (runtime.Object, error) { + if tweakListOptions != nil { + tweakListOptions(&options) + } + return client.FlyteworkflowV1alpha1().FlyteWorkflows(namespace).List(options) + }, + WatchFunc: func(options v1.ListOptions) (watch.Interface, error) { + if tweakListOptions != nil { + tweakListOptions(&options) + } + return client.FlyteworkflowV1alpha1().FlyteWorkflows(namespace).Watch(options) + }, + }, + &flyteworkflowv1alpha1.FlyteWorkflow{}, + resyncPeriod, + indexers, + ) +} + +func (f *flyteWorkflowInformer) defaultInformer(client versioned.Interface, resyncPeriod time.Duration) cache.SharedIndexInformer { + return NewFilteredFlyteWorkflowInformer(client, f.namespace, resyncPeriod, cache.Indexers{cache.NamespaceIndex: cache.MetaNamespaceIndexFunc}, f.tweakListOptions) +} + +func (f *flyteWorkflowInformer) Informer() cache.SharedIndexInformer { + return f.factory.InformerFor(&flyteworkflowv1alpha1.FlyteWorkflow{}, f.defaultInformer) +} + +func (f *flyteWorkflowInformer) Lister() v1alpha1.FlyteWorkflowLister { + return v1alpha1.NewFlyteWorkflowLister(f.Informer().GetIndexer()) +} diff --git a/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/interface.go b/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/interface.go new file mode 100644 index 000000000..c4425cb4c --- /dev/null +++ b/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/interface.go @@ -0,0 +1,29 @@ +// Code generated by informer-gen. DO NOT EDIT. 
+ +package v1alpha1 + +import ( + internalinterfaces "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/internalinterfaces" +) + +// Interface provides access to all the informers in this group version. +type Interface interface { + // FlyteWorkflows returns a FlyteWorkflowInformer. + FlyteWorkflows() FlyteWorkflowInformer +} + +type version struct { + factory internalinterfaces.SharedInformerFactory + namespace string + tweakListOptions internalinterfaces.TweakListOptionsFunc +} + +// New returns a new Interface. +func New(f internalinterfaces.SharedInformerFactory, namespace string, tweakListOptions internalinterfaces.TweakListOptionsFunc) Interface { + return &version{factory: f, namespace: namespace, tweakListOptions: tweakListOptions} +} + +// FlyteWorkflows returns a FlyteWorkflowInformer. +func (v *version) FlyteWorkflows() FlyteWorkflowInformer { + return &flyteWorkflowInformer{factory: v.factory, namespace: v.namespace, tweakListOptions: v.tweakListOptions} +} diff --git a/pkg/client/informers/externalversions/generic.go b/pkg/client/informers/externalversions/generic.go new file mode 100644 index 000000000..3d1564aa5 --- /dev/null +++ b/pkg/client/informers/externalversions/generic.go @@ -0,0 +1,46 @@ +// Code generated by informer-gen. DO NOT EDIT. + +package externalversions + +import ( + "fmt" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + schema "k8s.io/apimachinery/pkg/runtime/schema" + cache "k8s.io/client-go/tools/cache" +) + +// GenericInformer is type of SharedIndexInformer which will locate and delegate to other +// sharedInformers based on type +type GenericInformer interface { + Informer() cache.SharedIndexInformer + Lister() cache.GenericLister +} + +type genericInformer struct { + informer cache.SharedIndexInformer + resource schema.GroupResource +} + +// Informer returns the SharedIndexInformer. 
+func (f *genericInformer) Informer() cache.SharedIndexInformer { + return f.informer +} + +// Lister returns the GenericLister. +func (f *genericInformer) Lister() cache.GenericLister { + return cache.NewGenericLister(f.Informer().GetIndexer(), f.resource) +} + +// ForResource gives generic access to a shared informer of the matching type +// TODO extend this to unknown resources with a client pool +func (f *sharedInformerFactory) ForResource(resource schema.GroupVersionResource) (GenericInformer, error) { + switch resource { + // Group=flyteworkflow.flyte.net, Version=v1alpha1 + case v1alpha1.SchemeGroupVersion.WithResource("flyteworkflows"): + return &genericInformer{resource: resource.GroupResource(), informer: f.Flyteworkflow().V1alpha1().FlyteWorkflows().Informer()}, nil + + } + + return nil, fmt.Errorf("no informer found for %v", resource) +} diff --git a/pkg/client/informers/externalversions/internalinterfaces/factory_interfaces.go b/pkg/client/informers/externalversions/internalinterfaces/factory_interfaces.go new file mode 100644 index 000000000..147b4a34c --- /dev/null +++ b/pkg/client/informers/externalversions/internalinterfaces/factory_interfaces.go @@ -0,0 +1,24 @@ +// Code generated by informer-gen. DO NOT EDIT. + +package internalinterfaces + +import ( + time "time" + + versioned "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + runtime "k8s.io/apimachinery/pkg/runtime" + cache "k8s.io/client-go/tools/cache" +) + +// NewInformerFunc takes versioned.Interface and time.Duration to return a SharedIndexInformer. 
+type NewInformerFunc func(versioned.Interface, time.Duration) cache.SharedIndexInformer + +// SharedInformerFactory a small interface to allow for adding an informer without an import cycle +type SharedInformerFactory interface { + Start(stopCh <-chan struct{}) + InformerFor(obj runtime.Object, newFunc NewInformerFunc) cache.SharedIndexInformer +} + +// TweakListOptionsFunc is a function that transforms a v1.ListOptions. +type TweakListOptionsFunc func(*v1.ListOptions) diff --git a/pkg/client/listers/flyteworkflow/v1alpha1/expansion_generated.go b/pkg/client/listers/flyteworkflow/v1alpha1/expansion_generated.go new file mode 100644 index 000000000..74ad85548 --- /dev/null +++ b/pkg/client/listers/flyteworkflow/v1alpha1/expansion_generated.go @@ -0,0 +1,11 @@ +// Code generated by lister-gen. DO NOT EDIT. + +package v1alpha1 + +// FlyteWorkflowListerExpansion allows custom methods to be added to +// FlyteWorkflowLister. +type FlyteWorkflowListerExpansion interface{} + +// FlyteWorkflowNamespaceListerExpansion allows custom methods to be added to +// FlyteWorkflowNamespaceLister. +type FlyteWorkflowNamespaceListerExpansion interface{} diff --git a/pkg/client/listers/flyteworkflow/v1alpha1/flyteworkflow.go b/pkg/client/listers/flyteworkflow/v1alpha1/flyteworkflow.go new file mode 100644 index 000000000..1ddbf256b --- /dev/null +++ b/pkg/client/listers/flyteworkflow/v1alpha1/flyteworkflow.go @@ -0,0 +1,78 @@ +// Code generated by lister-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/client-go/tools/cache" +) + +// FlyteWorkflowLister helps list FlyteWorkflows. +type FlyteWorkflowLister interface { + // List lists all FlyteWorkflows in the indexer. + List(selector labels.Selector) (ret []*v1alpha1.FlyteWorkflow, err error) + // FlyteWorkflows returns an object that can list and get FlyteWorkflows. 
+ FlyteWorkflows(namespace string) FlyteWorkflowNamespaceLister + FlyteWorkflowListerExpansion +} + +// flyteWorkflowLister implements the FlyteWorkflowLister interface. +type flyteWorkflowLister struct { + indexer cache.Indexer +} + +// NewFlyteWorkflowLister returns a new FlyteWorkflowLister. +func NewFlyteWorkflowLister(indexer cache.Indexer) FlyteWorkflowLister { + return &flyteWorkflowLister{indexer: indexer} +} + +// List lists all FlyteWorkflows in the indexer. +func (s *flyteWorkflowLister) List(selector labels.Selector) (ret []*v1alpha1.FlyteWorkflow, err error) { + err = cache.ListAll(s.indexer, selector, func(m interface{}) { + ret = append(ret, m.(*v1alpha1.FlyteWorkflow)) + }) + return ret, err +} + +// FlyteWorkflows returns an object that can list and get FlyteWorkflows. +func (s *flyteWorkflowLister) FlyteWorkflows(namespace string) FlyteWorkflowNamespaceLister { + return flyteWorkflowNamespaceLister{indexer: s.indexer, namespace: namespace} +} + +// FlyteWorkflowNamespaceLister helps list and get FlyteWorkflows. +type FlyteWorkflowNamespaceLister interface { + // List lists all FlyteWorkflows in the indexer for a given namespace. + List(selector labels.Selector) (ret []*v1alpha1.FlyteWorkflow, err error) + // Get retrieves the FlyteWorkflow from the indexer for a given namespace and name. + Get(name string) (*v1alpha1.FlyteWorkflow, error) + FlyteWorkflowNamespaceListerExpansion +} + +// flyteWorkflowNamespaceLister implements the FlyteWorkflowNamespaceLister +// interface. +type flyteWorkflowNamespaceLister struct { + indexer cache.Indexer + namespace string +} + +// List lists all FlyteWorkflows in the indexer for a given namespace. 
+func (s flyteWorkflowNamespaceLister) List(selector labels.Selector) (ret []*v1alpha1.FlyteWorkflow, err error) { + err = cache.ListAllByNamespace(s.indexer, s.namespace, selector, func(m interface{}) { + ret = append(ret, m.(*v1alpha1.FlyteWorkflow)) + }) + return ret, err +} + +// Get retrieves the FlyteWorkflow from the indexer for a given namespace and name. +func (s flyteWorkflowNamespaceLister) Get(name string) (*v1alpha1.FlyteWorkflow, error) { + obj, exists, err := s.indexer.GetByKey(s.namespace + "/" + name) + if err != nil { + return nil, err + } + if !exists { + return nil, errors.NewNotFound(v1alpha1.Resource("flyteworkflow"), name) + } + return obj.(*v1alpha1.FlyteWorkflow), nil +} diff --git a/pkg/compiler/builders.go b/pkg/compiler/builders.go new file mode 100755 index 000000000..70f320525 --- /dev/null +++ b/pkg/compiler/builders.go @@ -0,0 +1,136 @@ +package compiler + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" +) + +type flyteTask = core.TaskTemplate +type flyteWorkflow = core.CompiledWorkflow +type flyteNode = core.Node + +// A builder object for the Graph struct. This contains information the compiler uses while building the final Graph +// struct. +type workflowBuilder struct { + CoreWorkflow *flyteWorkflow + LaunchPlans map[c.WorkflowIDKey]c.InterfaceProvider + Tasks c.TaskIndex + downstreamNodes c.StringAdjacencyList + upstreamNodes c.StringAdjacencyList + Nodes c.NodeIndex + + // These are references to all subgraphs and tasks passed to CompileWorkflow. They will be passed around but will + // not show in their entirety in the final Graph. The required subset of these will be added to each subgraph as + // the compile traverses them. 
+ allLaunchPlans map[string]c.InterfaceProvider + allTasks c.TaskIndex + allSubWorkflows c.WorkflowIndex +} + +func (w workflowBuilder) GetFailureNode() c.Node { + if w.GetCoreWorkflow() != nil && w.GetCoreWorkflow().GetTemplate() != nil && w.GetCoreWorkflow().GetTemplate().FailureNode != nil { + return w.NewNodeBuilder(w.GetCoreWorkflow().GetTemplate().FailureNode) + } + + return nil +} + +func (w workflowBuilder) GetNodes() c.NodeIndex { + return w.Nodes +} + +func (w workflowBuilder) GetTasks() c.TaskIndex { + return w.Tasks +} + +func (w workflowBuilder) GetDownstreamNodes() c.StringAdjacencyList { + return w.downstreamNodes +} + +func (w workflowBuilder) GetUpstreamNodes() c.StringAdjacencyList { + return w.upstreamNodes +} + +func (w workflowBuilder) NewNodeBuilder(n *flyteNode) c.NodeBuilder { + return &nodeBuilder{flyteNode: n} +} + +func (w workflowBuilder) GetNode(id c.NodeID) (node c.NodeBuilder, found bool) { + node, found = w.Nodes[id] + return +} + +func (w workflowBuilder) GetTask(id c.TaskID) (task c.Task, found bool) { + task, found = w.Tasks[id.String()] + return +} + +func (w workflowBuilder) GetLaunchPlan(id c.LaunchPlanID) (wf c.InterfaceProvider, found bool) { + wf, found = w.LaunchPlans[id.String()] + return +} + +func (w workflowBuilder) GetSubWorkflow(id c.WorkflowID) (wf *core.CompiledWorkflow, found bool) { + wf, found = w.allSubWorkflows[id.String()] + return +} + +func (w workflowBuilder) GetCoreWorkflow() *flyteWorkflow { + return w.CoreWorkflow +} + +// A wrapper around core.nodeBuilder to augment with computed fields during compilation +type nodeBuilder struct { + *flyteNode + subWorkflow c.Workflow + Task c.Task + Iface *core.TypedInterface +} + +func (n nodeBuilder) GetTask() c.Task { + return n.Task +} + +func (n *nodeBuilder) SetTask(task c.Task) { + n.Task = task +} + +func (n nodeBuilder) GetSubWorkflow() c.Workflow { + return n.subWorkflow +} + +func (n nodeBuilder) GetCoreNode() *core.Node { + return n.flyteNode +} + +func (n 
nodeBuilder) GetInterface() *core.TypedInterface { + return n.Iface +} + +func (n *nodeBuilder) SetInterface(iface *core.TypedInterface) { + n.Iface = iface +} + +func (n *nodeBuilder) SetSubWorkflow(wf c.Workflow) { + n.subWorkflow = wf +} + +func (n *nodeBuilder) SetInputs(inputs []*core.Binding) { + n.Inputs = inputs +} + +type taskBuilder struct { + *flyteTask +} + +func (t taskBuilder) GetCoreTask() *core.TaskTemplate { + return t.flyteTask +} + +func (t taskBuilder) GetID() c.Identifier { + if t.flyteTask.Id != nil { + return *t.flyteTask.Id + } + + return c.Identifier{} +} diff --git a/pkg/compiler/common/builder.go b/pkg/compiler/common/builder.go new file mode 100644 index 000000000..032b9c02e --- /dev/null +++ b/pkg/compiler/common/builder.go @@ -0,0 +1,32 @@ +// This package defines the intermediate layer that the compiler builds and transformers accept. +package common + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/errors" +) + +const ( + StartNodeID = "start-node" + EndNodeID = "end-node" +) + +//go:generate mockery -all -output=mocks -case=underscore + +// A mutable workflow used during the build of the intermediate layer. +type WorkflowBuilder interface { + Workflow + AddExecutionEdge(nodeFrom, nodeTo NodeID) + AddNode(n NodeBuilder, errs errors.CompileErrors) (node NodeBuilder, ok bool) + ValidateWorkflow(fg *core.CompiledWorkflow, errs errors.CompileErrors) (Workflow, bool) + NewNodeBuilder(n *core.Node) NodeBuilder +} + +// A mutable node used during the build of the intermediate layer. 
+type NodeBuilder interface { + Node + SetInterface(iface *core.TypedInterface) + SetInputs(inputs []*core.Binding) + SetSubWorkflow(wf Workflow) + SetTask(task Task) +} diff --git a/pkg/compiler/common/id_set.go b/pkg/compiler/common/id_set.go new file mode 100644 index 000000000..8489f10f0 --- /dev/null +++ b/pkg/compiler/common/id_set.go @@ -0,0 +1,99 @@ +package common + +import ( + "sort" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type Empty struct{} +type Identifier = core.Identifier +type IdentifierSet map[string]Identifier + +// NewString creates a String from a list of values. +func NewIdentifierSet(items ...Identifier) IdentifierSet { + ss := IdentifierSet{} + ss.Insert(items...) + return ss +} + +// Insert adds items to the set. +func (s IdentifierSet) Insert(items ...Identifier) { + for _, item := range items { + s[item.String()] = item + } +} + +// Delete removes all items from the set. +func (s IdentifierSet) Delete(items ...Identifier) { + for _, item := range items { + delete(s, item.String()) + } +} + +// Has returns true if and only if item is contained in the set. +func (s IdentifierSet) Has(item Identifier) bool { + _, contained := s[item.String()] + return contained +} + +// HasAll returns true if and only if all items are contained in the set. +func (s IdentifierSet) HasAll(items ...Identifier) bool { + for _, item := range items { + if !s.Has(item) { + return false + } + } + return true +} + +// HasAny returns true if any items are contained in the set. 
+func (s IdentifierSet) HasAny(items ...Identifier) bool { + for _, item := range items { + if s.Has(item) { + return true + } + } + return false +} + +type sortableSliceOfString []Identifier + +func (s sortableSliceOfString) Len() int { return len(s) } +func (s sortableSliceOfString) Less(i, j int) bool { + first, second := s[i], s[j] + if first.ResourceType != second.ResourceType { + return first.ResourceType < second.ResourceType + } + + if first.Project != second.Project { + return first.Project < second.Project + } + + if first.Domain != second.Domain { + return first.Domain < second.Domain + } + + if first.Name != second.Name { + return first.Name < second.Name + } + + if first.Version != second.Version { + return first.Version < second.Version + } + + return false +} + +func (s sortableSliceOfString) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// List returns the contents as a sorted Identifier slice. +func (s IdentifierSet) List() []Identifier { + res := make(sortableSliceOfString, 0, len(s)) + for _, value := range s { + res = append(res, value) + } + + sort.Sort(res) + return []Identifier(res) +} diff --git a/pkg/compiler/common/index.go b/pkg/compiler/common/index.go new file mode 100644 index 000000000..e445616fb --- /dev/null +++ b/pkg/compiler/common/index.go @@ -0,0 +1,71 @@ +package common + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "k8s.io/apimachinery/pkg/util/sets" +) + +// Defines an index of nodebuilders based on the id. +type NodeIndex map[NodeID]NodeBuilder + +// Defines an index of tasks based on the id. +type TaskIndex map[TaskIDKey]Task + +type WorkflowIndex map[WorkflowIDKey]*core.CompiledWorkflow + +// Defines a string adjacency list. +type AdjacencyList map[string]IdentifierSet + +type StringAdjacencyList map[string]sets.String + +// Converts the sets in the adjacency list to sorted arrays. 
+func (l AdjacencyList) ToMapOfLists() map[string][]Identifier { + res := make(map[string][]Identifier, len(l)) + for key, set := range l { + res[key] = set.List() + } + + return res +} + +// Creates a new TaskIndex. +func NewTaskIndex(tasks ...Task) TaskIndex { + res := make(TaskIndex, len(tasks)) + for _, task := range tasks { + id := task.GetID() + res[(&id).String()] = task + } + + return res +} + +// Creates a new NodeIndex +func NewNodeIndex(nodes ...NodeBuilder) NodeIndex { + res := make(NodeIndex, len(nodes)) + for _, task := range nodes { + res[task.GetId()] = task + } + + return res +} + +func NewWorkflowIndex(workflows []*core.CompiledWorkflow, errs errors.CompileErrors) (index WorkflowIndex, ok bool) { + ok = true + index = make(WorkflowIndex, len(workflows)) + for _, wf := range workflows { + if wf.Template.Id == nil { + // TODO: Log/Return error + return nil, false + } + + if _, found := index[wf.Template.Id.String()]; found { + errs.Collect(errors.NewDuplicateIDFoundErr(wf.Template.Id.String())) + ok = false + } else { + index[wf.Template.Id.String()] = wf + } + } + + return +} diff --git a/pkg/compiler/common/mocks/interface_provider.go b/pkg/compiler/common/mocks/interface_provider.go new file mode 100644 index 000000000..d7f776ffc --- /dev/null +++ b/pkg/compiler/common/mocks/interface_provider.go @@ -0,0 +1,59 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// InterfaceProvider is an autogenerated mock type for the InterfaceProvider type +type InterfaceProvider struct { + mock.Mock +} + +// GetExpectedInputs provides a mock function with given fields: +func (_m *InterfaceProvider) GetExpectedInputs() *core.ParameterMap { + ret := _m.Called() + + var r0 *core.ParameterMap + if rf, ok := ret.Get(0).(func() *core.ParameterMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.ParameterMap) + } + } + + return r0 +} + +// GetExpectedOutputs provides a mock function with given fields: +func (_m *InterfaceProvider) GetExpectedOutputs() *core.VariableMap { + ret := _m.Called() + + var r0 *core.VariableMap + if rf, ok := ret.Get(0).(func() *core.VariableMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.VariableMap) + } + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *InterfaceProvider) GetID() *core.Identifier { + ret := _m.Called() + + var r0 *core.Identifier + if rf, ok := ret.Get(0).(func() *core.Identifier); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.Identifier) + } + } + + return r0 +} diff --git a/pkg/compiler/common/mocks/node.go b/pkg/compiler/common/mocks/node.go new file mode 100644 index 000000000..eebbeb509 --- /dev/null +++ b/pkg/compiler/common/mocks/node.go @@ -0,0 +1,202 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import common "github.com/lyft/flytepropeller/pkg/compiler/common" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// Node is an autogenerated mock type for the Node type +type Node struct { + mock.Mock +} + +// GetBranchNode provides a mock function with given fields: +func (_m *Node) GetBranchNode() *core.BranchNode { + ret := _m.Called() + + var r0 *core.BranchNode + if rf, ok := ret.Get(0).(func() *core.BranchNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.BranchNode) + } + } + + return r0 +} + +// GetCoreNode provides a mock function with given fields: +func (_m *Node) GetCoreNode() *core.Node { + ret := _m.Called() + + var r0 *core.Node + if rf, ok := ret.Get(0).(func() *core.Node); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.Node) + } + } + + return r0 +} + +// GetId provides a mock function with given fields: +func (_m *Node) GetId() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetInputs provides a mock function with given fields: +func (_m *Node) GetInputs() []*core.Binding { + ret := _m.Called() + + var r0 []*core.Binding + if rf, ok := ret.Get(0).(func() []*core.Binding); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*core.Binding) + } + } + + return r0 +} + +// GetInterface provides a mock function with given fields: +func (_m *Node) GetInterface() *core.TypedInterface { + ret := _m.Called() + + var r0 *core.TypedInterface + if rf, ok := ret.Get(0).(func() *core.TypedInterface); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TypedInterface) + } + } + + return r0 +} + +// GetMetadata provides a mock function with given fields: +func (_m *Node) GetMetadata() *core.NodeMetadata { + ret := _m.Called() + + var r0 
*core.NodeMetadata + if rf, ok := ret.Get(0).(func() *core.NodeMetadata); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.NodeMetadata) + } + } + + return r0 +} + +// GetOutputAliases provides a mock function with given fields: +func (_m *Node) GetOutputAliases() []*core.Alias { + ret := _m.Called() + + var r0 []*core.Alias + if rf, ok := ret.Get(0).(func() []*core.Alias); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*core.Alias) + } + } + + return r0 +} + +// GetSubWorkflow provides a mock function with given fields: +func (_m *Node) GetSubWorkflow() common.Workflow { + ret := _m.Called() + + var r0 common.Workflow + if rf, ok := ret.Get(0).(func() common.Workflow); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Workflow) + } + } + + return r0 +} + +// GetTask provides a mock function with given fields: +func (_m *Node) GetTask() common.Task { + ret := _m.Called() + + var r0 common.Task + if rf, ok := ret.Get(0).(func() common.Task); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Task) + } + } + + return r0 +} + +// GetTaskNode provides a mock function with given fields: +func (_m *Node) GetTaskNode() *core.TaskNode { + ret := _m.Called() + + var r0 *core.TaskNode + if rf, ok := ret.Get(0).(func() *core.TaskNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TaskNode) + } + } + + return r0 +} + +// GetUpstreamNodeIds provides a mock function with given fields: +func (_m *Node) GetUpstreamNodeIds() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// GetWorkflowNode provides a mock function with given fields: +func (_m *Node) GetWorkflowNode() *core.WorkflowNode { + ret := _m.Called() + + var r0 *core.WorkflowNode + if rf, ok := ret.Get(0).(func() 
*core.WorkflowNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.WorkflowNode) + } + } + + return r0 +} diff --git a/pkg/compiler/common/mocks/node_builder.go b/pkg/compiler/common/mocks/node_builder.go new file mode 100644 index 000000000..2a164b806 --- /dev/null +++ b/pkg/compiler/common/mocks/node_builder.go @@ -0,0 +1,222 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import common "github.com/lyft/flytepropeller/pkg/compiler/common" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// NodeBuilder is an autogenerated mock type for the NodeBuilder type +type NodeBuilder struct { + mock.Mock +} + +// GetBranchNode provides a mock function with given fields: +func (_m *NodeBuilder) GetBranchNode() *core.BranchNode { + ret := _m.Called() + + var r0 *core.BranchNode + if rf, ok := ret.Get(0).(func() *core.BranchNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.BranchNode) + } + } + + return r0 +} + +// GetCoreNode provides a mock function with given fields: +func (_m *NodeBuilder) GetCoreNode() *core.Node { + ret := _m.Called() + + var r0 *core.Node + if rf, ok := ret.Get(0).(func() *core.Node); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.Node) + } + } + + return r0 +} + +// GetId provides a mock function with given fields: +func (_m *NodeBuilder) GetId() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetInputs provides a mock function with given fields: +func (_m *NodeBuilder) GetInputs() []*core.Binding { + ret := _m.Called() + + var r0 []*core.Binding + if rf, ok := ret.Get(0).(func() []*core.Binding); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*core.Binding) + } + } + + return r0 +} + +// GetInterface provides a mock 
function with given fields: +func (_m *NodeBuilder) GetInterface() *core.TypedInterface { + ret := _m.Called() + + var r0 *core.TypedInterface + if rf, ok := ret.Get(0).(func() *core.TypedInterface); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TypedInterface) + } + } + + return r0 +} + +// GetMetadata provides a mock function with given fields: +func (_m *NodeBuilder) GetMetadata() *core.NodeMetadata { + ret := _m.Called() + + var r0 *core.NodeMetadata + if rf, ok := ret.Get(0).(func() *core.NodeMetadata); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.NodeMetadata) + } + } + + return r0 +} + +// GetOutputAliases provides a mock function with given fields: +func (_m *NodeBuilder) GetOutputAliases() []*core.Alias { + ret := _m.Called() + + var r0 []*core.Alias + if rf, ok := ret.Get(0).(func() []*core.Alias); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*core.Alias) + } + } + + return r0 +} + +// GetSubWorkflow provides a mock function with given fields: +func (_m *NodeBuilder) GetSubWorkflow() common.Workflow { + ret := _m.Called() + + var r0 common.Workflow + if rf, ok := ret.Get(0).(func() common.Workflow); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Workflow) + } + } + + return r0 +} + +// GetTask provides a mock function with given fields: +func (_m *NodeBuilder) GetTask() common.Task { + ret := _m.Called() + + var r0 common.Task + if rf, ok := ret.Get(0).(func() common.Task); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Task) + } + } + + return r0 +} + +// GetTaskNode provides a mock function with given fields: +func (_m *NodeBuilder) GetTaskNode() *core.TaskNode { + ret := _m.Called() + + var r0 *core.TaskNode + if rf, ok := ret.Get(0).(func() *core.TaskNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TaskNode) + } + } + + return r0 +} + +// GetUpstreamNodeIds 
provides a mock function with given fields: +func (_m *NodeBuilder) GetUpstreamNodeIds() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// GetWorkflowNode provides a mock function with given fields: +func (_m *NodeBuilder) GetWorkflowNode() *core.WorkflowNode { + ret := _m.Called() + + var r0 *core.WorkflowNode + if rf, ok := ret.Get(0).(func() *core.WorkflowNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.WorkflowNode) + } + } + + return r0 +} + +// SetInputs provides a mock function with given fields: inputs +func (_m *NodeBuilder) SetInputs(inputs []*core.Binding) { + _m.Called(inputs) +} + +// SetInterface provides a mock function with given fields: iface +func (_m *NodeBuilder) SetInterface(iface *core.TypedInterface) { + _m.Called(iface) +} + +// SetSubWorkflow provides a mock function with given fields: wf +func (_m *NodeBuilder) SetSubWorkflow(wf common.Workflow) { + _m.Called(wf) +} + +// SetTask provides a mock function with given fields: task +func (_m *NodeBuilder) SetTask(task common.Task) { + _m.Called(task) +} diff --git a/pkg/compiler/common/mocks/task.go b/pkg/compiler/common/mocks/task.go new file mode 100644 index 000000000..476961874 --- /dev/null +++ b/pkg/compiler/common/mocks/task.go @@ -0,0 +1,57 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// Task is an autogenerated mock type for the Task type +type Task struct { + mock.Mock +} + +// GetCoreTask provides a mock function with given fields: +func (_m *Task) GetCoreTask() *core.TaskTemplate { + ret := _m.Called() + + var r0 *core.TaskTemplate + if rf, ok := ret.Get(0).(func() *core.TaskTemplate); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TaskTemplate) + } + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *Task) GetID() core.Identifier { + ret := _m.Called() + + var r0 core.Identifier + if rf, ok := ret.Get(0).(func() core.Identifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(core.Identifier) + } + + return r0 +} + +// GetInterface provides a mock function with given fields: +func (_m *Task) GetInterface() *core.TypedInterface { + ret := _m.Called() + + var r0 *core.TypedInterface + if rf, ok := ret.Get(0).(func() *core.TypedInterface); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TypedInterface) + } + } + + return r0 +} diff --git a/pkg/compiler/common/mocks/workflow.go b/pkg/compiler/common/mocks/workflow.go new file mode 100644 index 000000000..2e1a3dc09 --- /dev/null +++ b/pkg/compiler/common/mocks/workflow.go @@ -0,0 +1,200 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import common "github.com/lyft/flytepropeller/pkg/compiler/common" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// Workflow is an autogenerated mock type for the Workflow type +type Workflow struct { + mock.Mock +} + +// GetCoreWorkflow provides a mock function with given fields: +func (_m *Workflow) GetCoreWorkflow() *core.CompiledWorkflow { + ret := _m.Called() + + var r0 *core.CompiledWorkflow + if rf, ok := ret.Get(0).(func() *core.CompiledWorkflow); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.CompiledWorkflow) + } + } + + return r0 +} + +// GetDownstreamNodes provides a mock function with given fields: +func (_m *Workflow) GetDownstreamNodes() common.StringAdjacencyList { + ret := _m.Called() + + var r0 common.StringAdjacencyList + if rf, ok := ret.Get(0).(func() common.StringAdjacencyList); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.StringAdjacencyList) + } + } + + return r0 +} + +// GetFailureNode provides a mock function with given fields: +func (_m *Workflow) GetFailureNode() common.Node { + ret := _m.Called() + + var r0 common.Node + if rf, ok := ret.Get(0).(func() common.Node); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Node) + } + } + + return r0 +} + +// GetLaunchPlan provides a mock function with given fields: id +func (_m *Workflow) GetLaunchPlan(id core.Identifier) (common.InterfaceProvider, bool) { + ret := _m.Called(id) + + var r0 common.InterfaceProvider + if rf, ok := ret.Get(0).(func(core.Identifier) common.InterfaceProvider); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.InterfaceProvider) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNode provides a mock function with given fields: id +func (_m 
*Workflow) GetNode(id string) (common.NodeBuilder, bool) { + ret := _m.Called(id) + + var r0 common.NodeBuilder + if rf, ok := ret.Get(0).(func(string) common.NodeBuilder); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeBuilder) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNodes provides a mock function with given fields: +func (_m *Workflow) GetNodes() common.NodeIndex { + ret := _m.Called() + + var r0 common.NodeIndex + if rf, ok := ret.Get(0).(func() common.NodeIndex); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeIndex) + } + } + + return r0 +} + +// GetSubWorkflow provides a mock function with given fields: id +func (_m *Workflow) GetSubWorkflow(id core.Identifier) (*core.CompiledWorkflow, bool) { + ret := _m.Called(id) + + var r0 *core.CompiledWorkflow + if rf, ok := ret.Get(0).(func(core.Identifier) *core.CompiledWorkflow); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.CompiledWorkflow) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetTask provides a mock function with given fields: id +func (_m *Workflow) GetTask(id core.Identifier) (common.Task, bool) { + ret := _m.Called(id) + + var r0 common.Task + if rf, ok := ret.Get(0).(func(core.Identifier) common.Task); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Task) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetTasks provides a mock function with given fields: +func (_m *Workflow) GetTasks() common.TaskIndex { + ret := _m.Called() + + var r0 common.TaskIndex + if rf, ok := ret.Get(0).(func() common.TaskIndex); ok { + 
r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.TaskIndex) + } + } + + return r0 +} + +// GetUpstreamNodes provides a mock function with given fields: +func (_m *Workflow) GetUpstreamNodes() common.StringAdjacencyList { + ret := _m.Called() + + var r0 common.StringAdjacencyList + if rf, ok := ret.Get(0).(func() common.StringAdjacencyList); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.StringAdjacencyList) + } + } + + return r0 +} diff --git a/pkg/compiler/common/mocks/workflow_builder.go b/pkg/compiler/common/mocks/workflow_builder.go new file mode 100644 index 000000000..7816e7624 --- /dev/null +++ b/pkg/compiler/common/mocks/workflow_builder.go @@ -0,0 +1,268 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import common "github.com/lyft/flytepropeller/pkg/compiler/common" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import errors "github.com/lyft/flytepropeller/pkg/compiler/errors" +import mock "github.com/stretchr/testify/mock" + +// WorkflowBuilder is an autogenerated mock type for the WorkflowBuilder type +type WorkflowBuilder struct { + mock.Mock +} + +// AddExecutionEdge provides a mock function with given fields: nodeFrom, nodeTo +func (_m *WorkflowBuilder) AddExecutionEdge(nodeFrom string, nodeTo string) { + _m.Called(nodeFrom, nodeTo) +} + +// AddNode provides a mock function with given fields: n, errs +func (_m *WorkflowBuilder) AddNode(n common.NodeBuilder, errs errors.CompileErrors) (common.NodeBuilder, bool) { + ret := _m.Called(n, errs) + + var r0 common.NodeBuilder + if rf, ok := ret.Get(0).(func(common.NodeBuilder, errors.CompileErrors) common.NodeBuilder); ok { + r0 = rf(n, errs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeBuilder) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(common.NodeBuilder, errors.CompileErrors) bool); ok { + r1 = rf(n, errs) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// 
GetCoreWorkflow provides a mock function with given fields: +func (_m *WorkflowBuilder) GetCoreWorkflow() *core.CompiledWorkflow { + ret := _m.Called() + + var r0 *core.CompiledWorkflow + if rf, ok := ret.Get(0).(func() *core.CompiledWorkflow); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.CompiledWorkflow) + } + } + + return r0 +} + +// GetDownstreamNodes provides a mock function with given fields: +func (_m *WorkflowBuilder) GetDownstreamNodes() common.StringAdjacencyList { + ret := _m.Called() + + var r0 common.StringAdjacencyList + if rf, ok := ret.Get(0).(func() common.StringAdjacencyList); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.StringAdjacencyList) + } + } + + return r0 +} + +// GetFailureNode provides a mock function with given fields: +func (_m *WorkflowBuilder) GetFailureNode() common.Node { + ret := _m.Called() + + var r0 common.Node + if rf, ok := ret.Get(0).(func() common.Node); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Node) + } + } + + return r0 +} + +// GetLaunchPlan provides a mock function with given fields: id +func (_m *WorkflowBuilder) GetLaunchPlan(id core.Identifier) (common.InterfaceProvider, bool) { + ret := _m.Called(id) + + var r0 common.InterfaceProvider + if rf, ok := ret.Get(0).(func(core.Identifier) common.InterfaceProvider); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.InterfaceProvider) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNode provides a mock function with given fields: id +func (_m *WorkflowBuilder) GetNode(id string) (common.NodeBuilder, bool) { + ret := _m.Called(id) + + var r0 common.NodeBuilder + if rf, ok := ret.Get(0).(func(string) common.NodeBuilder); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeBuilder) + } + } + 
+ var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNodes provides a mock function with given fields: +func (_m *WorkflowBuilder) GetNodes() common.NodeIndex { + ret := _m.Called() + + var r0 common.NodeIndex + if rf, ok := ret.Get(0).(func() common.NodeIndex); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeIndex) + } + } + + return r0 +} + +// GetSubWorkflow provides a mock function with given fields: id +func (_m *WorkflowBuilder) GetSubWorkflow(id core.Identifier) (*core.CompiledWorkflow, bool) { + ret := _m.Called(id) + + var r0 *core.CompiledWorkflow + if rf, ok := ret.Get(0).(func(core.Identifier) *core.CompiledWorkflow); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.CompiledWorkflow) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetTask provides a mock function with given fields: id +func (_m *WorkflowBuilder) GetTask(id core.Identifier) (common.Task, bool) { + ret := _m.Called(id) + + var r0 common.Task + if rf, ok := ret.Get(0).(func(core.Identifier) common.Task); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Task) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetTasks provides a mock function with given fields: +func (_m *WorkflowBuilder) GetTasks() common.TaskIndex { + ret := _m.Called() + + var r0 common.TaskIndex + if rf, ok := ret.Get(0).(func() common.TaskIndex); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.TaskIndex) + } + } + + return r0 +} + +// GetUpstreamNodes provides a mock function with given fields: +func (_m *WorkflowBuilder) GetUpstreamNodes() common.StringAdjacencyList { 
+ ret := _m.Called() + + var r0 common.StringAdjacencyList + if rf, ok := ret.Get(0).(func() common.StringAdjacencyList); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.StringAdjacencyList) + } + } + + return r0 +} + +// NewNodeBuilder provides a mock function with given fields: n +func (_m *WorkflowBuilder) NewNodeBuilder(n *core.Node) common.NodeBuilder { + ret := _m.Called(n) + + var r0 common.NodeBuilder + if rf, ok := ret.Get(0).(func(*core.Node) common.NodeBuilder); ok { + r0 = rf(n) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeBuilder) + } + } + + return r0 +} + +// ValidateWorkflow provides a mock function with given fields: fg, errs +func (_m *WorkflowBuilder) ValidateWorkflow(fg *core.CompiledWorkflow, errs errors.CompileErrors) (common.Workflow, bool) { + ret := _m.Called(fg, errs) + + var r0 common.Workflow + if rf, ok := ret.Get(0).(func(*core.CompiledWorkflow, errors.CompileErrors) common.Workflow); ok { + r0 = rf(fg, errs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Workflow) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(*core.CompiledWorkflow, errors.CompileErrors) bool); ok { + r1 = rf(fg, errs) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} diff --git a/pkg/compiler/common/reader.go b/pkg/compiler/common/reader.go new file mode 100644 index 000000000..2edd098da --- /dev/null +++ b/pkg/compiler/common/reader.go @@ -0,0 +1,55 @@ +package common + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type NodeID = string +type TaskID = Identifier +type WorkflowID = Identifier +type LaunchPlanID = Identifier +type TaskIDKey = string +type WorkflowIDKey = string + +// An immutable workflow that represents the final output of the compiler. 
+type Workflow interface { + GetNode(id NodeID) (node NodeBuilder, found bool) + GetTask(id TaskID) (task Task, found bool) + GetLaunchPlan(id LaunchPlanID) (wf InterfaceProvider, found bool) + GetSubWorkflow(id WorkflowID) (wf *core.CompiledWorkflow, found bool) + GetCoreWorkflow() *core.CompiledWorkflow + GetFailureNode() Node + GetNodes() NodeIndex + GetTasks() TaskIndex + GetDownstreamNodes() StringAdjacencyList + GetUpstreamNodes() StringAdjacencyList +} + +// An immutable Node that represents the final output of the compiler. +type Node interface { + GetId() NodeID + GetInterface() *core.TypedInterface + GetInputs() []*core.Binding + GetWorkflowNode() *core.WorkflowNode + GetOutputAliases() []*core.Alias + GetUpstreamNodeIds() []string + GetCoreNode() *core.Node + GetBranchNode() *core.BranchNode + GetTaskNode() *core.TaskNode + GetMetadata() *core.NodeMetadata + GetTask() Task + GetSubWorkflow() Workflow +} + +// An immutable task that represents the final output of the compiler. 
+type Task interface { + GetID() TaskID + GetCoreTask() *core.TaskTemplate + GetInterface() *core.TypedInterface +} + +type InterfaceProvider interface { + GetID() *core.Identifier + GetExpectedInputs() *core.ParameterMap + GetExpectedOutputs() *core.VariableMap +} diff --git a/pkg/compiler/errors/compiler_error_test.go b/pkg/compiler/errors/compiler_error_test.go new file mode 100644 index 000000000..a3d1eb755 --- /dev/null +++ b/pkg/compiler/errors/compiler_error_test.go @@ -0,0 +1,52 @@ +package errors + +import ( + "testing" + + "github.com/magiconair/properties/assert" + "github.com/pkg/errors" +) + +func mustErrorCode(t *testing.T, compileError *CompileError, code ErrorCode) { + assert.Equal(t, code, compileError.Code()) +} + +func TestErrorCodes(t *testing.T) { + testCases := map[ErrorCode]*CompileError{ + CycleDetected: NewCycleDetectedInWorkflowErr("", ""), + BranchNodeIDNotFound: NewBranchNodeNotSpecified(""), + BranchNodeHasNoCondition: NewBranchNodeHasNoCondition(""), + ValueRequired: NewValueRequiredErr("", ""), + NodeReferenceNotFound: NewNodeReferenceNotFoundErr("", ""), + TaskReferenceNotFound: NewTaskReferenceNotFoundErr("", ""), + WorkflowReferenceNotFound: NewWorkflowReferenceNotFoundErr("", ""), + VariableNameNotFound: NewVariableNameNotFoundErr("", "", ""), + DuplicateAlias: NewDuplicateAliasErr("", ""), + DuplicateNodeID: NewDuplicateIDFoundErr(""), + MismatchingTypes: NewMismatchingTypesErr("", "", "", ""), + MismatchingInterfaces: NewMismatchingInterfacesErr("", ""), + InconsistentTypes: NewInconsistentTypesErr("", "", ""), + ParameterBoundMoreThanOnce: NewParameterBoundMoreThanOnceErr("", ""), + ParameterNotBound: NewParameterNotBoundErr("", ""), + NoEntryNodeFound: NewWorkflowHasNoEntryNodeErr(""), + UnreachableNodes: NewUnreachableNodesErr("", ""), + UnrecognizedValue: NewUnrecognizedValueErr("", ""), + WorkflowBuildError: NewWorkflowBuildError(errors.New("")), + } + + for key, value := range testCases { + t.Run(string(key), func(t 
*testing.T) { + mustErrorCode(t, value, key) + }) + } +} + +func TestIncludeSource(t *testing.T) { + e := NewCycleDetectedInWorkflowErr("", "") + assert.Equal(t, e.source, "") + + SetConfig(Config{IncludeSource: true}) + e = NewCycleDetectedInWorkflowErr("", "") + assert.Equal(t, e.source, "compiler_error_test.go:49") + SetConfig(Config{}) +} diff --git a/pkg/compiler/errors/compiler_errors.go b/pkg/compiler/errors/compiler_errors.go new file mode 100755 index 000000000..86236c894 --- /dev/null +++ b/pkg/compiler/errors/compiler_errors.go @@ -0,0 +1,272 @@ +package errors + +import ( + "fmt" + "runtime" + "strings" +) + +const ( + // A cycle is detected in the Workflow, the error description should detail the nodes involved. + CycleDetected ErrorCode = "CycleDetected" + + // BranchNode is missing a case with ThenNode populated. + BranchNodeIDNotFound ErrorCode = "BranchNodeIdNotFound" + + // BranchNode is missing a condition. + BranchNodeHasNoCondition ErrorCode = "BranchNodeHasNoCondition" + + // An expected field isn't populated. + ValueRequired ErrorCode = "ValueRequired" + + // A nodeBuilder referenced by an edge doesn't belong to the Workflow. + NodeReferenceNotFound ErrorCode = "NodeReferenceNotFound" + + // A Task referenced by a node wasn't found. + TaskReferenceNotFound ErrorCode = "TaskReferenceNotFound" + + // A Workflow referenced by a node wasn't found. + WorkflowReferenceNotFound ErrorCode = "WorkflowReferenceNotFound" + + // A referenced variable (in a parameter or a condition) wasn't found. + VariableNameNotFound ErrorCode = "VariableNameNotFound" + + // An alias existed twice. + DuplicateAlias ErrorCode = "DuplicateAlias" + + // An Id existed twice. + DuplicateNodeID ErrorCode = "DuplicateId" + + // Two types expected to be compatible but aren't. + MismatchingTypes ErrorCode = "MismatchingTypes" + + // A binding is attempted via a list or map syntax, but the underlying type isn't a list or map. 
+ MismatchingBindings ErrorCode = "MismatchingBindings" + + // Two interfaces expected to be compatible but aren't. + MismatchingInterfaces ErrorCode = "MismatchingInterfaces" + + // Expected types to be consistent. + InconsistentTypes ErrorCode = "InconsistentTypes" + + // An input/output parameter was assigned a value through an edge more than once. + ParameterBoundMoreThanOnce ErrorCode = "ParameterBoundMoreThanOnce" + + // One of the required input parameters or a Workflow output parameter wasn't bound. + ParameterNotBound ErrorCode = "ParameterNotBound" + + // When we couldn't assign an entry point to the Workflow. + NoEntryNodeFound ErrorCode = "NoEntryNodeFound" + + // When one or more unreachable nodes are detected. + UnreachableNodes ErrorCode = "UnreachableNodes" + + // A value doesn't fall within the expected range. + UnrecognizedValue ErrorCode = "UnrecognizedValue" + + // An unknown error occurred while building the workflow. + WorkflowBuildError ErrorCode = "WorkflowBuildError" + + // A value is expected to be unique but wasn't. + ValueCollision ErrorCode = "ValueCollision" + + // A value isn't in the right syntax. 
+ SyntaxError ErrorCode = "SyntaxError" +) + +func NewBranchNodeNotSpecified(branchNodeID string) *CompileError { + return newError( + BranchNodeIDNotFound, + fmt.Sprintf("BranchNode not assigned"), + branchNodeID, + ) +} + +func NewBranchNodeHasNoCondition(branchNodeID string) *CompileError { + return newError( + BranchNodeHasNoCondition, + "One of the branches on the node doesn't have a condition.", + branchNodeID, + ) +} + +func NewValueRequiredErr(nodeID, paramName string) *CompileError { + return newError( + ValueRequired, + fmt.Sprintf("Value required [%v].", paramName), + nodeID, + ) +} + +func NewParameterNotBoundErr(nodeID, paramName string) *CompileError { + return newError( + ParameterNotBound, + fmt.Sprintf("Parameter not bound [%v].", paramName), + nodeID, + ) +} + +func NewNodeReferenceNotFoundErr(nodeID, referenceID string) *CompileError { + return newError( + NodeReferenceNotFound, + fmt.Sprintf("Referenced node [%v] not found.", referenceID), + nodeID, + ) +} + +func NewWorkflowReferenceNotFoundErr(nodeID, referenceID string) *CompileError { + return newError( + WorkflowReferenceNotFound, + fmt.Sprintf("Referenced Workflow [%v] not found.", referenceID), + nodeID, + ) +} + +func NewTaskReferenceNotFoundErr(nodeID, referenceID string) *CompileError { + return newError( + TaskReferenceNotFound, + fmt.Sprintf("Referenced Task [%v] not found.", referenceID), + nodeID, + ) +} + +func NewVariableNameNotFoundErr(nodeID, referenceID, variableName string) *CompileError { + return newError( + VariableNameNotFound, + fmt.Sprintf("Variable [%v] not found on node [%v].", variableName, referenceID), + nodeID, + ) +} + +func NewParameterBoundMoreThanOnceErr(nodeID, paramName string) *CompileError { + return newError( + ParameterBoundMoreThanOnce, + fmt.Sprintf("Input [%v] is bound more than once.", paramName), + nodeID, + ) +} + +func NewDuplicateAliasErr(nodeID, alias string) *CompileError { + return newError( + DuplicateAlias, + fmt.Sprintf("Duplicate alias 
[%v] found. An output alias can only be used once in the Workflow.", alias), + nodeID, + ) +} + +func NewDuplicateIDFoundErr(nodeID string) *CompileError { + return newError( + DuplicateNodeID, + "Trying to insert two nodes with the same id.", + nodeID, + ) +} + +func NewMismatchingTypesErr(nodeID, fromVar, fromType, toType string) *CompileError { + return newError( + MismatchingTypes, + fmt.Sprintf("Variable [%v] (type [%v]) doesn't match expected type [%v].", fromVar, fromType, + toType), + nodeID, + ) +} + +func NewMismatchingBindingsErr(nodeID, sinkParam, expectedType, receivedType string) *CompileError { + return newError( + MismatchingBindings, + fmt.Sprintf("Input [%v] on node [%v] expects bindings of type [%v]. Received [%v]", sinkParam, nodeID, expectedType, receivedType), + nodeID, + ) +} + +func NewMismatchingInterfacesErr(nodeID1, nodeID2 string) *CompileError { + return newError( + MismatchingInterfaces, + fmt.Sprintf("Interfaces of nodes [%v] and [%v] do not match.", nodeID1, nodeID2), + nodeID1, + ) +} + +func NewInconsistentTypesErr(nodeID, expectedType, actualType string) *CompileError { + return newError( + InconsistentTypes, + fmt.Sprintf("Expected type: %v but found %v", expectedType, actualType), + nodeID, + ) +} + +func NewWorkflowHasNoEntryNodeErr(graphID string) *CompileError { + return newError( + NoEntryNodeFound, + fmt.Sprintf("Can't find a node to start executing Workflow [%v].", graphID), + graphID, + ) +} + +func NewCycleDetectedInWorkflowErr(nodeID, cycle string) *CompileError { + return newError( + CycleDetected, + fmt.Sprintf("A cycle has been detected while traversing the Workflow [%v].", cycle), + nodeID, + ) +} + +func NewUnreachableNodesErr(nodeID, nodes string) *CompileError { + return newError( + UnreachableNodes, + fmt.Sprintf("The Workflow contain unreachable nodes [%v].", nodes), + nodeID, + ) +} + +func NewUnrecognizedValueErr(nodeID, value string) *CompileError { + return newError( + UnrecognizedValue, + 
fmt.Sprintf("Unrecognized value [%v].", value), + nodeID, + ) +} + +func NewWorkflowBuildError(err error) *CompileError { + return newError(WorkflowBuildError, err.Error(), "") +} + +func NewValueCollisionError(nodeID string, valueName, value string) *CompileError { + return newError( + ValueCollision, + fmt.Sprintf("%v is expected to be unique. %v already exists.", valueName, value), + nodeID, + ) +} + +func NewSyntaxError(nodeID string, element string, err error) *CompileError { + return newError(SyntaxError, + fmt.Sprintf("Failed to parse element [%v].", element), + nodeID, + ) +} + +func newError(code ErrorCode, description, nodeID string) (err *CompileError) { + err = &CompileError{ + code: code, + description: description, + nodeID: nodeID, + } + + if GetConfig().IncludeSource { + _, file, line, ok := runtime.Caller(2) + if !ok { + file = "???" + line = 1 + } else { + slash := strings.LastIndex(file, "/") + if slash >= 0 { + file = file[slash+1:] + } + } + + err.source = fmt.Sprintf("%v:%v", file, line) + } + + return +} diff --git a/pkg/compiler/errors/config.go b/pkg/compiler/errors/config.go new file mode 100644 index 000000000..7cde12c91 --- /dev/null +++ b/pkg/compiler/errors/config.go @@ -0,0 +1,27 @@ +package errors + +// Represents error config that can change the behavior of how errors collection/reporting is handled. +type Config struct { + // Indicates that a panic should be issued as soon as the first error is collected. + PanicOnError bool + + // Indicates that errors should include source code information when collected. There is an associated performance + // penalty with this behavior. + IncludeSource bool +} + +var config = Config{} + +// Sets global config. +func SetConfig(cfg Config) { + config = cfg +} + +// Gets global config. 
+func GetConfig() Config { + return config +} + +func SetIncludeSource() { + config.IncludeSource = true +} diff --git a/pkg/compiler/errors/error.go b/pkg/compiler/errors/error.go new file mode 100755 index 000000000..735e42619 --- /dev/null +++ b/pkg/compiler/errors/error.go @@ -0,0 +1,125 @@ +// This package is a central repository of all compile errors that can be reported. It contains ways to collect and format +// errors to make it easy to find and correct workflow spec problems. +package errors + +import "fmt" + +type ErrorCode string + +// Represents a compile error for coreWorkflow. +type CompileError struct { + code ErrorCode + nodeID string + description string + source string +} + +// Represents a compile error with a root cause. +type CompileErrorWithCause struct { + *CompileError + cause error +} + +// A set of Compile errors. +type CompileErrors interface { + error + Collect(e ...*CompileError) + NewScope() CompileErrors + Errors() *compileErrorSet + HasErrors() bool + ErrorCount() int +} + +type compileErrors struct { + errorSet *compileErrorSet + parent CompileErrors + errorCountInScope int +} + +// Gets the Compile Error code +func (err CompileError) Code() ErrorCode { + return err.code +} + +// Gets a readable/formatted string explaining the compile error as well as at which node it occurred. +func (err CompileError) Error() string { + source := "" + if err.source != "" { + source = fmt.Sprintf("[%v] ", err.source) + } + + return fmt.Sprintf("%vCode: %s, Node Id: %s, Description: %s", source, err.code, err.nodeID, err.description) +} + +// Gets a readable/formatted string explaining the compile error as well as at which node it occurred. +func (err CompileErrorWithCause) Error() string { + cause := "" + if err.cause != nil { + cause = fmt.Sprintf(", Cause: %v", err.cause.Error()) + } + + return fmt.Sprintf("%v%v", err.CompileError.Error(), cause) +} + +// Exposes the set of unique errors. 
+func (errs *compileErrors) Errors() *compileErrorSet { + return errs.errorSet +} + +// Appends a compile error to the set. +func (errs *compileErrors) Collect(e ...*CompileError) { + if e != nil { + if GetConfig().PanicOnError { + panic(e) + } + + if errs.parent != nil { + errs.parent.Collect(e...) + errs.errorCountInScope += len(e) + } else { + for _, err := range e { + if err != nil { + errs.errorSet.Put(*err) + errs.errorCountInScope++ + } + } + } + } +} + +// Creates a new scope for compile errors. Parent scope will always automatically collect errors reported in any of its +// child scopes. +func (errs *compileErrors) NewScope() CompileErrors { + return &compileErrors{parent: errs} +} + +// Gets a formatted string of all compile errors collected. +func (errs *compileErrors) Error() (err string) { + if errs.parent != nil { + return errs.parent.Error() + } + + err = fmt.Sprintf("Collected Errors: %v\n", len(*errs.Errors())) + i := 0 + for _, e := range errs.Errors().List() { + err += fmt.Sprintf("\tError %d: %s\n", i, e.Error()) + i++ + } + + return err +} + +// Gets a value indicating whether there are any errors collected within current scope and all of its children. +func (errs *compileErrors) HasErrors() bool { + return errs.errorCountInScope > 0 +} + +// Gets the number of errors collected within current scope and all of its children. 
+func (errs *compileErrors) ErrorCount() int { + return errs.errorCountInScope +} + +// Creates a new empty compile errors +func NewCompileErrors() CompileErrors { + return &compileErrors{errorSet: &compileErrorSet{}} +} diff --git a/pkg/compiler/errors/error_test.go b/pkg/compiler/errors/error_test.go new file mode 100755 index 000000000..3756501fb --- /dev/null +++ b/pkg/compiler/errors/error_test.go @@ -0,0 +1,43 @@ +package errors + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func addError(errs CompileErrors) { + errs.Collect(NewValueRequiredErr("node", "param")) +} + +func TestCompileErrors_Collect(t *testing.T) { + errs := NewCompileErrors() + assert.False(t, errs.HasErrors()) + addError(errs) + assert.True(t, errs.HasErrors()) +} + +func TestCompileErrors_NewScope(t *testing.T) { + errs := NewCompileErrors() + addError(errs.NewScope().NewScope()) + assert.True(t, errs.HasErrors()) + assert.Equal(t, 1, errs.ErrorCount()) +} + +func TestCompileErrors_Errors(t *testing.T) { + errs := NewCompileErrors() + addError(errs.NewScope().NewScope()) + addError(errs.NewScope().NewScope()) + assert.True(t, errs.HasErrors()) + assert.Equal(t, 2, errs.ErrorCount()) + + set := errs.Errors() + assert.Equal(t, 1, len(*set)) +} + +func TestCompileErrors_Error(t *testing.T) { + errs := NewCompileErrors() + addError(errs.NewScope().NewScope()) + addError(errs.NewScope().NewScope()) + assert.NotEqual(t, "", errs.Error()) +} diff --git a/pkg/compiler/errors/sets.go b/pkg/compiler/errors/sets.go new file mode 100755 index 000000000..c03c7259d --- /dev/null +++ b/pkg/compiler/errors/sets.go @@ -0,0 +1,44 @@ +package errors + +import ( + "sort" + "strings" +) + +var keyExists = struct{}{} + +type compileErrorSet map[CompileError]struct{} + +func (s compileErrorSet) Put(key CompileError) { + s[key] = keyExists +} + +func (s compileErrorSet) Contains(key CompileError) bool { + _, ok := s[key] + return ok +} + +func (s compileErrorSet) Remove(key CompileError) { 
+ delete(s, key) +} + +func refCompileError(x CompileError) *CompileError { + return &x +} + +func (s compileErrorSet) List() []*CompileError { + res := make([]*CompileError, 0, len(s)) + for key := range s { + res = append(res, refCompileError(key)) + } + + sort.SliceStable(res, func(i, j int) bool { + if res[i].Code() == res[j].Code() { + return res[i].Error() < res[j].Error() + } + + return strings.Compare(string(res[i].Code()), string(res[j].Code())) < 0 + }) + + return res +} diff --git a/pkg/compiler/errors/sets_test.go b/pkg/compiler/errors/sets_test.go new file mode 100644 index 000000000..d487d3b56 --- /dev/null +++ b/pkg/compiler/errors/sets_test.go @@ -0,0 +1,20 @@ +package errors + +import ( + "testing" + + "github.com/magiconair/properties/assert" +) + +func TestCompileErrorSet_List(t *testing.T) { + set := compileErrorSet{} + set.Put(*NewValueRequiredErr("node1", "param")) + set.Put(*NewWorkflowHasNoEntryNodeErr("graph1")) + set.Put(*NewWorkflowHasNoEntryNodeErr("graph1")) + assert.Equal(t, len(set), 2) + + lst := set.List() + assert.Equal(t, len(lst), 2) + assert.Equal(t, lst[0].Code(), NoEntryNodeFound) + assert.Equal(t, lst[1].Code(), ValueRequired) +} diff --git a/pkg/compiler/requirements.go b/pkg/compiler/requirements.go new file mode 100755 index 000000000..989ecb403 --- /dev/null +++ b/pkg/compiler/requirements.go @@ -0,0 +1,88 @@ +package compiler + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" +) + +type TaskIdentifier = common.Identifier +type LaunchPlanRefIdentifier = common.Identifier + +// Represents the set of required resources for a given Workflow's execution. All of the resources should be loaded before +// hand and passed to the compiler. +type WorkflowExecutionRequirements struct { + taskIds []TaskIdentifier + launchPlanIds []LaunchPlanRefIdentifier +} + +// Gets a slice of required Task ids to load. 
+func (g WorkflowExecutionRequirements) GetRequiredTaskIds() []TaskIdentifier { + return g.taskIds +} + +// Gets a slice of required Workflow ids to load. +func (g WorkflowExecutionRequirements) GetRequiredLaunchPlanIds() []LaunchPlanRefIdentifier { + return g.launchPlanIds +} + +// Computes requirements for a given Workflow. +func GetRequirements(fg *core.WorkflowTemplate, subWfs []*core.WorkflowTemplate) (reqs WorkflowExecutionRequirements, err error) { + errs := errors.NewCompileErrors() + compiledSubWfs := toCompiledWorkflows(subWfs...) + + index, ok := common.NewWorkflowIndex(compiledSubWfs, errs) + + if ok { + return getRequirements(fg, index, true, errs), nil + } + + return WorkflowExecutionRequirements{}, errs +} + +func getRequirements(fg *core.WorkflowTemplate, subWfs common.WorkflowIndex, followSubworkflows bool, + errs errors.CompileErrors) (reqs WorkflowExecutionRequirements) { + + taskIds := common.NewIdentifierSet() + launchPlanIds := common.NewIdentifierSet() + updateWorkflowRequirements(fg, subWfs, taskIds, launchPlanIds, followSubworkflows, errs) + + reqs.taskIds = taskIds.List() + reqs.launchPlanIds = launchPlanIds.List() + + return +} + +// Augments taskIds and launchPlanIds with referenced tasks/workflows within coreWorkflow nodes +func updateWorkflowRequirements(workflow *core.WorkflowTemplate, subWfs common.WorkflowIndex, + taskIds, workflowIds common.IdentifierSet, followSubworkflows bool, errs errors.CompileErrors) { + + for _, node := range workflow.Nodes { + updateNodeRequirements(node, subWfs, taskIds, workflowIds, followSubworkflows, errs) + } +} + +func updateNodeRequirements(node *flyteNode, subWfs common.WorkflowIndex, taskIds, workflowIds common.IdentifierSet, + followSubworkflows bool, errs errors.CompileErrors) (ok bool) { + + if taskN := node.GetTaskNode(); taskN != nil && taskN.GetReferenceId() != nil { + taskIds.Insert(*taskN.GetReferenceId()) + } else if workflowNode := node.GetWorkflowNode(); workflowNode != nil { + if 
workflowNode.GetLaunchplanRef() != nil { + workflowIds.Insert(*workflowNode.GetLaunchplanRef()) + } else if workflowNode.GetSubWorkflowRef() != nil && followSubworkflows { + if subWf, found := subWfs[workflowNode.GetSubWorkflowRef().String()]; !found { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr(node.Id, workflowNode.GetSubWorkflowRef().String())) + } else { + updateWorkflowRequirements(subWf.Template, subWfs, taskIds, workflowIds, followSubworkflows, errs) + } + } + } else if branchN := node.GetBranchNode(); branchN != nil { + updateNodeRequirements(branchN.IfElse.Case.ThenNode, subWfs, taskIds, workflowIds, followSubworkflows, errs) + for _, otherCase := range branchN.IfElse.Other { + updateNodeRequirements(otherCase.ThenNode, subWfs, taskIds, workflowIds, followSubworkflows, errs) + } + } + + return !errs.HasErrors() +} diff --git a/pkg/compiler/requirements_test.go b/pkg/compiler/requirements_test.go new file mode 100755 index 000000000..fdc7eaa72 --- /dev/null +++ b/pkg/compiler/requirements_test.go @@ -0,0 +1,125 @@ +package compiler + +import ( + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestGetRequirements(t *testing.T) { + g := &core.WorkflowTemplate{ + Nodes: []*core.Node{ + { + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "Task_1"}, + }, + }, + }, + }, + { + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "Task_2"}, + }, + }, + }, + }, + { + Target: &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_LaunchplanRef{ + LaunchplanRef: &core.Identifier{Name: "Graph_1"}, + }, + }, + }, + }, + { + Target: &core.Node_BranchNode{ + BranchNode: &core.BranchNode{ + IfElse: &core.IfElseBlock{ + Case: &core.IfBlock{ + ThenNode: &core.Node{ + Target: 
&core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_LaunchplanRef{ + LaunchplanRef: &core.Identifier{Name: "Graph_1"}, + }, + }, + }, + }, + }, + Other: []*core.IfBlock{ + { + ThenNode: &core.Node{ + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "Task_3"}, + }, + }, + }, + }, + }, + { + ThenNode: &core.Node{ + Target: &core.Node_BranchNode{ + BranchNode: &core.BranchNode{ + IfElse: &core.IfElseBlock{ + Case: &core.IfBlock{ + ThenNode: &core.Node{ + Target: &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_LaunchplanRef{ + LaunchplanRef: &core.Identifier{Name: "Graph_2"}, + }, + }, + }, + }, + }, + Other: []*core.IfBlock{ + { + ThenNode: &core.Node{ + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "Task_4"}, + }, + }, + }, + }, + }, + { + ThenNode: &core.Node{ + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "Task_5"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + subWorkflows := make([]*core.WorkflowTemplate, 0) + reqs, err := GetRequirements(g, subWorkflows) + assert.NoError(t, err) + assert.Equal(t, 5, len(reqs.GetRequiredTaskIds())) + assert.Equal(t, 2, len(reqs.GetRequiredLaunchPlanIds())) +} diff --git a/pkg/compiler/task_compiler.go b/pkg/compiler/task_compiler.go new file mode 100644 index 000000000..08f3b8806 --- /dev/null +++ b/pkg/compiler/task_compiler.go @@ -0,0 +1,98 @@ +package compiler + +import ( + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "k8s.io/apimachinery/pkg/api/resource" +) + +func validateResource(resourceName 
core.Resources_ResourceName, resourceVal string, errs errors.CompileErrors) (ok bool) {
	// A resource value must parse as a k8s resource quantity (e.g. "5", "100Gi").
	if _, err := resource.ParseQuantity(resourceVal); err != nil {
		errs.Collect(errors.NewUnrecognizedValueErr(fmt.Sprintf("resources.%v", resourceName), resourceVal))
		// BUG FIX: previously returned true on a parse failure and false on
		// success, inverting the (ok bool) contract. Visible callers rely on
		// errs.HasErrors() rather than this return value, so correcting it
		// does not change compilation outcomes.
		return false
	}

	return true
}

// Validates that every resource entry carries a parsable quantity. Failures
// are collected into errs; ok reflects whether any error was recorded.
func validateKnownResources(resources []*core.Resources_ResourceEntry, errs errors.CompileErrors) (ok bool) {
	for _, r := range resources {
		validateResource(r.Name, r.Value, errs.NewScope())
	}

	return !errs.HasErrors()
}

// Validates both the requests and the limits of a container's resources.
func validateResources(resources *core.Resources, errs errors.CompileErrors) (ok bool) {
	// Validate known resource keys.
	validateKnownResources(resources.Requests, errs.NewScope())
	validateKnownResources(resources.Limits, errs.NewScope())

	return !errs.HasErrors()
}

// Validates that a task whose interface declares inputs or outputs also
// declares a container command and/or args. Tasks with no interface, or an
// empty one, are trivially valid.
func validateContainerCommand(task *core.TaskTemplate, errs errors.CompileErrors) (ok bool) {
	if task.Interface == nil {
		// Nothing to validate. FIX: the naked return here reported ok=false
		// for a valid task; an absent interface is not an error.
		return true
	}

	hasInputs := task.Interface.Inputs != nil && len(task.Interface.GetInputs().Variables) > 0
	hasOutputs := task.Interface.Outputs != nil && len(task.Interface.GetOutputs().Variables) > 0
	if !(hasInputs || hasOutputs) {
		// Nothing to validate; an empty interface needs no command.
		return true
	}

	if task.GetContainer().Command == nil && task.GetContainer().Args == nil {
		// When an interface with inputs or outputs is defined, the container command + args together must not be empty.
+ errs.Collect(errors.NewValueRequiredErr("container", "command")) + } + + return !errs.HasErrors() +} + +func validateContainer(task *core.TaskTemplate, errs errors.CompileErrors) (ok bool) { + if task.GetContainer() == nil { + errs.Collect(errors.NewValueRequiredErr("root", "container")) + return + } + + validateContainerCommand(task, errs) + + container := task.GetContainer() + if container.Image == "" { + errs.Collect(errors.NewValueRequiredErr("container", "image")) + } + + if container.Resources != nil { + validateResources(container.Resources, errs.NewScope()) + } + + return !errs.HasErrors() +} + +func compileTaskInternal(task *core.TaskTemplate, errs errors.CompileErrors) (common.Task, bool) { + if task.Id == nil { + errs.Collect(errors.NewValueRequiredErr("root", "Id")) + } + + switch task.GetTarget().(type) { + case *core.TaskTemplate_Container: + validateContainer(task, errs.NewScope()) + } + + return taskBuilder{flyteTask: task}, !errs.HasErrors() +} + +// Task compiler compiles a given Task into an executable Task. It validates all required parameters and ensures a Task +// is well-formed. 
+func CompileTask(task *core.TaskTemplate) (*core.CompiledTask, error) { + errs := errors.NewCompileErrors() + t, _ := compileTaskInternal(task, errs.NewScope()) + if errs.HasErrors() { + return nil, errs + } + + return &core.CompiledTask{Template: t.GetCoreTask()}, nil +} diff --git a/pkg/compiler/task_compiler_test.go b/pkg/compiler/task_compiler_test.go new file mode 100644 index 000000000..ff878509b --- /dev/null +++ b/pkg/compiler/task_compiler_test.go @@ -0,0 +1,82 @@ +package compiler + +import ( + "testing" + + "github.com/lyft/flytepropeller/pkg/compiler/errors" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func MakeResource(name core.Resources_ResourceName, v string) *core.Resources_ResourceEntry { + return &core.Resources_ResourceEntry{ + Name: name, + Value: v, + } +} + +func TestValidateContainerCommand(t *testing.T) { + task := core.TaskTemplate{ + Id: &core.Identifier{Name: "task_123"}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "foo": {}, + }), + Outputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "image://", + }, + }, + } + errs := errors.NewCompileErrors() + assert.False(t, validateContainerCommand(&task, errs)) + assert.Contains(t, errs.Error(), "Node Id: container, Description: Value required [command]") + + task.GetContainer().Command = []string{"cmd"} + errs = errors.NewCompileErrors() + assert.True(t, validateContainerCommand(&task, errs)) + assert.False(t, errs.HasErrors()) +} + +func TestCompileTask(t *testing.T) { + task, err := CompileTask(&core.TaskTemplate{ + Id: &core.Identifier{Name: "task_123"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + Outputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "image://", + Command: []string{"cmd"}, + Args: []string{"args"}, + Resources: 
&core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + MakeResource(core.Resources_CPU, "5"), + }, + Limits: []*core.Resources_ResourceEntry{ + MakeResource(core.Resources_MEMORY, "100Gi"), + }, + }, + Env: []*core.KeyValuePair{ + { + Key: "Env_Var", + Value: "Env_Val", + }, + }, + Config: []*core.KeyValuePair{ + { + Key: "config_key", + Value: "config_value", + }, + }, + }, + }, + }) + + assert.NoError(t, err) + assert.NotNil(t, task) +} diff --git a/pkg/compiler/test/compiler_test.go b/pkg/compiler/test/compiler_test.go new file mode 100644 index 000000000..f9bc36067 --- /dev/null +++ b/pkg/compiler/test/compiler_test.go @@ -0,0 +1,247 @@ +package test + +import ( + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/ghodss/yaml" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/lyft/flytepropeller/pkg/compiler/transformers/k8s" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/lyft/flytepropeller/pkg/visualize" + "github.com/stretchr/testify/assert" +) + +var update = flag.Bool("update", false, "Update .golden files") +var reverse = flag.Bool("reverse", false, "Reverse .golden files") + +func makeDefaultInputs(iface *core.TypedInterface) *core.LiteralMap { + if iface == nil || iface.GetInputs() == nil { + return nil + } + + res := make(map[string]*core.Literal, len(iface.GetInputs().Variables)) + for inputName, inputVar := range iface.GetInputs().Variables { + val := utils.MustMakeDefaultLiteralForType(inputVar.Type) + res[inputName] = val + } + + return &core.LiteralMap{ + Literals: res, + } +} + +func setDefaultFields(task *core.TaskTemplate) { + if container := task.GetContainer(); container != nil { + if container.Config == nil { + 
container.Config = []*core.KeyValuePair{}
		}

		container.Config = append(container.Config, &core.KeyValuePair{
			Key:   "testKey1",
			Value: "testValue1",
		})
		container.Config = append(container.Config, &core.KeyValuePair{
			Key:   "testKey2",
			Value: "testValue2",
		})
		container.Config = append(container.Config, &core.KeyValuePair{
			Key:   "testKey3",
			Value: "testValue3",
		})
	}
}

// mustCompileTasks compiles every task template, failing the test immediately
// on the first compilation error.
func mustCompileTasks(t *testing.T, tasks []*core.TaskTemplate) []*core.CompiledTask {
	compiledTasks := make([]*core.CompiledTask, 0, len(tasks))
	for _, inputTask := range tasks {
		setDefaultFields(inputTask)
		task, err := compiler.CompileTask(inputTask)
		compiledTasks = append(compiledTasks, task)
		assert.NoError(t, err)
		if err != nil {
			assert.FailNow(t, err.Error())
		}
	}

	return compiledTasks
}

// marshalProto writes p next to filename twice: once as a binary .pb file and
// once as a human-readable .yaml file (via jsonpb -> json -> yaml).
func marshalProto(t *testing.T, filename string, p proto.Message) {
	marshaller := &jsonpb.Marshaler{}
	s, err := marshaller.MarshalToString(p)
	assert.NoError(t, err)

	if err != nil {
		return
	}

	originalRaw, err := proto.Marshal(p)
	assert.NoError(t, err)
	assert.NoError(t, ioutil.WriteFile(strings.Replace(filename, filepath.Ext(filename), ".pb", 1), originalRaw, os.ModePerm))

	m := map[string]interface{}{}
	err = json.Unmarshal([]byte(s), &m)
	assert.NoError(t, err)

	b, err := yaml.Marshal(m)
	assert.NoError(t, err)
	assert.NoError(t, ioutil.WriteFile(strings.Replace(filename, filepath.Ext(filename), ".yaml", 1), b, os.ModePerm))
}

// Walks testdata, compiles every .yaml workflow closure, builds the k8s
// resource for it, and (with -reverse) regenerates the golden files.
func TestReverseEngineerFromYaml(t *testing.T) {
	root := "testdata"
	errors.SetConfig(errors.Config{IncludeSource: true})
	assert.NoError(t, filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}

		if info.IsDir() {
			return nil
		}

		if !strings.HasSuffix(path, ".yaml") {
			return nil
		}

		if strings.HasSuffix(path, "-inputs.yaml") {
			return nil
		}

		ext := ".yaml"

		// BUG FIX: was strings.TrimLeft(path, root) — TrimLeft treats its
		// second argument as a character *set*, not a prefix, so it could
		// strip leading 't','e','s','d','a' characters from the file name
		// itself. TrimPrefix removes exactly the leading "testdata".
		testName := strings.TrimPrefix(path, root)
		testName = strings.Trim(testName, string(os.PathSeparator))
		testName = strings.TrimSuffix(testName, ext)
		testName = strings.Replace(testName, string(os.PathSeparator), "_", -1)

		t.Run(testName, func(t *testing.T) {
			t.Log("Reading from file")
			raw, err := ioutil.ReadFile(path)
			assert.NoError(t, err)

			raw, err = yaml.YAMLToJSON(raw)
			assert.NoError(t, err)

			t.Log("Unmarshalling Workflow Closure")
			wf := &core.WorkflowClosure{}
			err = jsonpb.UnmarshalString(string(raw), wf)
			assert.NoError(t, err)
			assert.NotNil(t, wf)
			if err != nil {
				return
			}

			t.Log("Compiling Workflow")
			compiledWf, err := compiler.CompileWorkflow(wf.Workflow, []*core.WorkflowTemplate{}, mustCompileTasks(t, wf.Tasks), []common.InterfaceProvider{})
			assert.NoError(t, err)
			if err != nil {
				return
			}

			inputs := makeDefaultInputs(compiledWf.Primary.Template.GetInterface())
			if *reverse {
				marshalProto(t, strings.Replace(path, ext, fmt.Sprintf("-inputs%v", ext), -1), inputs)
			}

			t.Log("Building k8s resource")
			_, err = k8s.BuildFlyteWorkflow(compiledWf, inputs, nil, "")
			assert.NoError(t, err)
			if err != nil {
				return
			}

			dotFormat := visualize.ToGraphViz(compiledWf.Primary)
			t.Logf("GraphViz Dot: %v\n", dotFormat)

			if *reverse {
				marshalProto(t, path, wf)
			}
		})

		return nil
	}))
}

// Walks testdata, compiles every .pb workflow closure, builds the k8s
// resource for it, and (with -update) regenerates the golden files.
func TestCompileAndBuild(t *testing.T) {
	root := "testdata"
	errors.SetConfig(errors.Config{IncludeSource: true})
	assert.NoError(t, filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}

		if info.IsDir() {
			return nil
		}

		if ext := filepath.Ext(path); ext != ".pb" {
			return nil
		}

		if strings.HasSuffix(path, "-inputs.pb") {
			return nil
		}

		// BUG FIX: was strings.TrimLeft(path, root) — TrimLeft/Trim interpret
		// the second argument as a character set. TrimPrefix removes exactly
		// the leading "testdata".
		testName := strings.TrimPrefix(path, root)
		testName = strings.Trim(testName, string(os.PathSeparator))
		// BUG FIX: was strings.Trim(testName, filepath.Ext(testName)) — that
		// strips any run of '.', 'p', 'b' characters from BOTH ends of the
		// name; TrimSuffix removes just the trailing extension.
		testName = strings.TrimSuffix(testName, filepath.Ext(testName))
		testName = strings.Replace(testName, string(os.PathSeparator), "_", -1)

		t.Run(testName,
func(t *testing.T) { + t.Log("Reading from file") + raw, err := ioutil.ReadFile(path) + assert.NoError(t, err) + + t.Log("Unmarshalling Workflow Closure") + wf := &core.WorkflowClosure{} + err = proto.Unmarshal(raw, wf) + assert.NoError(t, err) + assert.NotNil(t, wf) + if err != nil { + return + } + + t.Log("Compiling Workflow") + compiledWf, err := compiler.CompileWorkflow(wf.Workflow, []*core.WorkflowTemplate{}, mustCompileTasks(t, wf.Tasks), []common.InterfaceProvider{}) + assert.NoError(t, err) + if err != nil { + return + } + + inputs := makeDefaultInputs(compiledWf.Primary.Template.GetInterface()) + if *update { + marshalProto(t, strings.Replace(path, filepath.Ext(path), fmt.Sprintf("-inputs%v", filepath.Ext(path)), -1), inputs) + } + + t.Log("Building k8s resource") + _, err = k8s.BuildFlyteWorkflow(compiledWf, inputs, nil, "") + assert.NoError(t, err) + if err != nil { + return + } + + dotFormat := visualize.ToGraphViz(compiledWf.Primary) + t.Logf("GraphViz Dot: %v\n", dotFormat) + + if *update { + marshalProto(t, path, wf) + } + }) + + return nil + })) +} diff --git a/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.pb b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.pb new file mode 100755 index 000000000..e69de29bb diff --git a/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.yaml b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.yaml new file mode 100755 index 000000000..3893cfb77 --- /dev/null +++ b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.yaml @@ -0,0 +1 @@ +literals: {} diff --git a/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.pb b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.pb new file mode 100755 index 000000000..5f6557e40 Binary files /dev/null and b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.pb differ diff --git 
a/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.yaml b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.yaml new file mode 100644 index 000000000..a9c3f1f96 --- /dev/null +++ b/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.yaml @@ -0,0 +1,79 @@ +tasks: +- container: + args: + - --task-module + - app.workflows.work + - --task-name + - find_odd_numbers + - --inputs + - '{{$input}}' + - --output-prefix + - '{{$outputPrefix}}' + command: + - pyflyte-execute + config: + - key: testKey1 + value: testValue1 + - key: testKey2 + value: testValue2 + - key: testKey3 + value: testValue3 + image: myflytecontainer:abc123 + resources: {} + id: + name: app.workflows.work.find_odd_numbers + interface: + inputs: + variables: + list_of_nums: + type: + collectionType: + simple: INTEGER + outputs: + variables: + are_num_odd: + type: + collectionType: + simple: BOOLEAN + metadata: + discoveryVersion: "1" + retries: {} + runtime: + flavor: python + type: FLYTE_SDK + version: 0.0.1a0 + timeout: 0s + type: python-task +workflow: + id: + name: app-workflows-work-one-python-task-w-f + interface: + inputs: {} + outputs: {} + metadata: {} + nodes: + - id: odd-nums-task + inputs: + - binding: + collection: + bindings: + - scalar: + primitive: + integer: "2" + - scalar: + primitive: + integer: "3" + - scalar: + primitive: + integer: "4" + - scalar: + primitive: + integer: "7" + var: list_of_nums + metadata: + name: DEADBEEF + retries: {} + timeout: 0s + taskNode: + referenceId: + name: app.workflows.work.find_odd_numbers diff --git a/pkg/compiler/testdata/beta-one-second-functional-test.dot.golden b/pkg/compiler/testdata/beta-one-second-functional-test.dot.golden new file mode 100755 index 000000000..8f800bde2 --- /dev/null +++ b/pkg/compiler/testdata/beta-one-second-functional-test.dot.golden @@ -0,0 +1 @@ +digraph G {rankdir=TB;workflow[label="Workflow Id: beta-one-second-functional-test"];node[style=filled];"start-node(start)" 
[shape=Msquare];} \ No newline at end of file diff --git a/pkg/compiler/transformers/k8s/builder_mock_test.go b/pkg/compiler/transformers/k8s/builder_mock_test.go new file mode 100644 index 000000000..c92bb9406 --- /dev/null +++ b/pkg/compiler/transformers/k8s/builder_mock_test.go @@ -0,0 +1,140 @@ +package k8s + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" +) + +type mockWorkflow struct { + *core.CompiledWorkflow + nodes common.NodeIndex + tasks common.TaskIndex + wfs map[common.WorkflowIDKey]common.InterfaceProvider + failureNode *mockNode + downstream common.StringAdjacencyList + upstream common.StringAdjacencyList +} + +func (m mockWorkflow) GetSubWorkflow(id common.WorkflowID) (wf *core.CompiledWorkflow, found bool) { + panic("method invocation not expected") +} + +func (m mockWorkflow) GetNode(id common.NodeID) (node common.NodeBuilder, found bool) { + node, found = m.nodes[id] + return +} + +func (m mockWorkflow) GetTask(id common.TaskID) (task common.Task, found bool) { + task, found = m.tasks[id.String()] + return +} + +func (m mockWorkflow) GetLaunchPlan(id common.LaunchPlanID) (wf common.InterfaceProvider, found bool) { + wf, found = m.wfs[id.String()] + return +} + +func (m mockWorkflow) GetCoreWorkflow() *core.CompiledWorkflow { + return m.CompiledWorkflow +} + +func (m mockWorkflow) GetFailureNode() common.Node { + if m.failureNode == nil { + return nil + } + + return m.failureNode +} + +func (m mockWorkflow) GetNodes() common.NodeIndex { + return m.nodes +} + +func (m mockWorkflow) GetTasks() common.TaskIndex { + return m.tasks +} + +func (m mockWorkflow) GetDownstreamNodes() common.StringAdjacencyList { + return m.downstream +} + +func (m mockWorkflow) GetUpstreamNodes() common.StringAdjacencyList { + return m.upstream +} + +type mockNode struct { + *core.Node + id common.NodeID + iface *core.TypedInterface + inputs []*core.Binding + aliases []*core.Alias + upstream []string + 
task common.Task + subWF common.Workflow +} + +func (n mockNode) GetID() common.NodeID { + return n.id +} + +func (n mockNode) GetInterface() *core.TypedInterface { + return n.iface +} + +func (n mockNode) GetInputs() []*core.Binding { + return n.inputs +} + +func (n mockNode) GetOutputAliases() []*core.Alias { + return n.aliases +} + +func (n mockNode) GetUpstreamNodeIds() []string { + return n.upstream +} + +func (n mockNode) GetCoreNode() *core.Node { + return n.Node +} + +func (n mockNode) GetTask() common.Task { + return n.task +} + +func (n *mockNode) SetTask(task common.Task) { + n.task = task +} + +func (n mockNode) GetSubWorkflow() common.Workflow { + return n.subWF +} + +func (n *mockNode) SetInterface(iface *core.TypedInterface) { + n.iface = iface +} + +func (n *mockNode) SetInputs(inputs []*core.Binding) { + n.inputs = inputs +} + +func (n *mockNode) SetSubWorkflow(wf common.Workflow) { + n.subWF = wf +} + +type mockTask struct { + id common.TaskID + task *core.TaskTemplate + iface *core.TypedInterface +} + +func (m mockTask) GetID() common.TaskID { + return m.id +} + +func (m mockTask) GetCoreTask() *core.TaskTemplate { + return m.task +} + +func (m mockTask) GetInterface() *core.TypedInterface { + return m.iface +} diff --git a/pkg/compiler/transformers/k8s/inputs.go b/pkg/compiler/transformers/k8s/inputs.go new file mode 100644 index 000000000..964e2608c --- /dev/null +++ b/pkg/compiler/transformers/k8s/inputs.go @@ -0,0 +1,53 @@ +package k8s + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/lyft/flytepropeller/pkg/compiler/validators" + "k8s.io/apimachinery/pkg/util/sets" +) + +func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs core.LiteralMap, errs errors.CompileErrors) (ok bool) { + if iface == nil { + errs.Collect(errors.NewValueRequiredErr(nodeID, "interface")) + return + } + + if 
iface.Inputs == nil { + errs.Collect(errors.NewValueRequiredErr(nodeID, "interface.InputsRef")) + return + } + + varMap := make(map[string]*core.Variable, len(iface.Inputs.Variables)) + requiredInputsSet := sets.String{} + for name, v := range iface.Inputs.Variables { + varMap[name] = v + requiredInputsSet.Insert(name) + } + + boundInputsSet := sets.String{} + for inputVar, inputVal := range inputs.Literals { + v, exists := varMap[inputVar] + if !exists { + errs.Collect(errors.NewVariableNameNotFoundErr(nodeID, "", inputVar)) + continue + } + + inputType := validators.LiteralTypeForLiteral(inputVal) + if !validators.AreTypesCastable(inputType, v.Type) { + errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String())) + continue + } + + boundInputsSet.Insert(inputVar) + } + + if diff := requiredInputsSet.Difference(boundInputsSet); len(diff) > 0 { + for param := range diff { + errs.Collect(errors.NewParameterNotBoundErr(nodeID, param)) + } + } + + return !errs.HasErrors() +} diff --git a/pkg/compiler/transformers/k8s/inputs_test.go b/pkg/compiler/transformers/k8s/inputs_test.go new file mode 100644 index 000000000..01d667c35 --- /dev/null +++ b/pkg/compiler/transformers/k8s/inputs_test.go @@ -0,0 +1 @@ +package k8s diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go new file mode 100644 index 000000000..fe31b6e59 --- /dev/null +++ b/pkg/compiler/transformers/k8s/node.go @@ -0,0 +1,169 @@ +package k8s + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/lyft/flytepropeller/pkg/utils" +) + +// Gets the compiled subgraph if this node contains an inline-declared coreWorkflow. Otherwise nil. 
+func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.CompileErrors) (*v1alpha1.NodeSpec, bool) { + if n == nil { + errs.Collect(errors.NewValueRequiredErr("root", "node")) + return nil, !errs.HasErrors() + } + + if n.GetId() != common.StartNodeID && n.GetId() != common.EndNodeID && + n.GetTarget() == nil { + + errs.Collect(errors.NewValueRequiredErr(n.GetId(), "target")) + return nil, !errs.HasErrors() + } + + var task *core.TaskTemplate + if n.GetTaskNode() != nil { + taskID := n.GetTaskNode().GetReferenceId().String() + // TODO: Use task index for quick lookup + for _, t := range tasks { + if t.Template.Id.String() == taskID { + task = t.Template + break + } + } + + if task == nil { + errs.Collect(errors.NewTaskReferenceNotFoundErr(n.GetId(), taskID)) + return nil, !errs.HasErrors() + } + } + + res, err := utils.ToK8sResourceRequirements(getResources(task)) + if err != nil { + errs.Collect(errors.NewWorkflowBuildError(err)) + return nil, false + } + + nodeSpec := &v1alpha1.NodeSpec{ + ID: n.GetId(), + RetryStrategy: computeRetryStrategy(n, task), + Resources: res, + OutputAliases: toAliasValueArray(n.GetOutputAliases()), + InputBindings: toBindingValueArray(n.GetInputs()), + ActiveDeadlineSeconds: computeActiveDeadlineSeconds(n, task), + } + + switch v := n.GetTarget().(type) { + case *core.Node_TaskNode: + nodeSpec.Kind = v1alpha1.NodeKindTask + nodeSpec.TaskRef = refStr(n.GetTaskNode().GetReferenceId().String()) + case *core.Node_WorkflowNode: + if n.GetWorkflowNode().Reference == nil { + errs.Collect(errors.NewValueRequiredErr(n.GetId(), "WorkflowNode.Reference")) + return nil, !errs.HasErrors() + } + + switch n.GetWorkflowNode().Reference.(type) { + case *core.WorkflowNode_LaunchplanRef: + nodeSpec.Kind = v1alpha1.NodeKindWorkflow + nodeSpec.WorkflowNode = &v1alpha1.WorkflowNodeSpec{ + LaunchPlanRefID: &v1alpha1.LaunchPlanRefID{Identifier: n.GetWorkflowNode().GetLaunchplanRef()}, + } + case *core.WorkflowNode_SubWorkflowRef: + 
nodeSpec.Kind = v1alpha1.NodeKindWorkflow + if v.WorkflowNode.GetSubWorkflowRef() != nil { + nodeSpec.WorkflowNode = &v1alpha1.WorkflowNodeSpec{ + SubWorkflowReference: refStr(v.WorkflowNode.GetSubWorkflowRef().String()), + } + } else if v.WorkflowNode.GetLaunchplanRef() != nil { + nodeSpec.WorkflowNode = &v1alpha1.WorkflowNodeSpec{ + LaunchPlanRefID: &v1alpha1.LaunchPlanRefID{Identifier: n.GetWorkflowNode().GetLaunchplanRef()}, + } + } else { + errs.Collect(errors.NewValueRequiredErr(n.GetId(), "WorkflowNode.WorkflowTemplate")) + return nil, !errs.HasErrors() + } + } + case *core.Node_BranchNode: + nodeSpec.Kind = v1alpha1.NodeKindBranch + nodeSpec.BranchNode = buildBranchNodeSpec(n.GetBranchNode(), errs.NewScope()) + default: + if n.GetId() == v1alpha1.StartNodeID { + nodeSpec.Kind = v1alpha1.NodeKindStart + } else if n.GetId() == v1alpha1.EndNodeID { + nodeSpec.Kind = v1alpha1.NodeKindEnd + } + } + + return nodeSpec, !errs.HasErrors() +} + +func buildIfBlockSpec(block *core.IfBlock, _ errors.CompileErrors) *v1alpha1.IfBlock { + return &v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{BooleanExpression: block.Condition}, + ThenNode: refStr(block.ThenNode.Id), + } +} + +func buildBranchNodeSpec(branch *core.BranchNode, errs errors.CompileErrors) *v1alpha1.BranchNodeSpec { + if branch == nil { + return nil + } + + res := &v1alpha1.BranchNodeSpec{ + If: *buildIfBlockSpec(branch.IfElse.Case, errs.NewScope()), + } + + switch branch.IfElse.GetDefault().(type) { + case *core.IfElseBlock_ElseNode: + res.Else = refStr(branch.IfElse.GetElseNode().Id) + case *core.IfElseBlock_Error: + res.ElseFail = &v1alpha1.Error{Error: branch.IfElse.GetError()} + } + + other := make([]*v1alpha1.IfBlock, 0, len(branch.IfElse.Other)) + for _, block := range branch.IfElse.Other { + other = append(other, buildIfBlockSpec(block, errs.NewScope())) + } + + res.ElseIf = other + + return res +} + +func buildNodes(nodes []*core.Node, tasks []*core.CompiledTask, errs errors.CompileErrors) 
(map[common.NodeID]*v1alpha1.NodeSpec, bool) { + res := make(map[common.NodeID]*v1alpha1.NodeSpec, len(nodes)) + for _, nodeBuidler := range nodes { + n, ok := buildNodeSpec(nodeBuidler, tasks, errs.NewScope()) + if !ok { + return nil, ok + } + + if _, exists := res[n.ID]; exists { + errs.Collect(errors.NewValueCollisionError(nodeBuidler.GetId(), "Id", n.ID)) + } + + res[n.ID] = n + } + + return res, !errs.HasErrors() +} + +func buildTasks(tasks []*core.CompiledTask, errs errors.CompileErrors) map[common.TaskIDKey]*v1alpha1.TaskSpec { + res := make(map[common.TaskIDKey]*v1alpha1.TaskSpec, len(tasks)) + for _, flyteTask := range tasks { + if flyteTask == nil { + errs.Collect(errors.NewValueRequiredErr("root", "coreTask")) + } else { + taskID := flyteTask.Template.Id.String() + if _, exists := res[taskID]; exists { + errs.Collect(errors.NewValueCollisionError(taskID, "Id", taskID)) + } + + res[taskID] = &v1alpha1.TaskSpec{TaskTemplate: flyteTask.Template} + } + } + + return res +} diff --git a/pkg/compiler/transformers/k8s/node_test.go b/pkg/compiler/transformers/k8s/node_test.go new file mode 100644 index 000000000..abe13f6bc --- /dev/null +++ b/pkg/compiler/transformers/k8s/node_test.go @@ -0,0 +1,181 @@ +package k8s + +import ( + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "k8s.io/apimachinery/pkg/api/resource" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/stretchr/testify/assert" +) + +func createNodeWithTask() *core.Node { + return &core.Node{ + Id: "n_1", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "ref_1"}, + }, + }, + }, + } +} + +func TestBuildNodeSpec(t *testing.T) { + n := mockNode{ + id: "n_1", + Node: &core.Node{}, + } + + tasks := []*core.CompiledTask{ + { + Template: &core.TaskTemplate{ + Id: 
&core.Identifier{Name: "ref_1"},
			},
		},
		{
			Template: &core.TaskTemplate{
				Id: &core.Identifier{Name: "ref_2"},
				Target: &core.TaskTemplate_Container{
					Container: &core.Container{
						Resources: &core.Resources{
							Requests: []*core.Resources_ResourceEntry{
								{
									// Test fixture value; "10Mi" is parsed as a resource.Quantity below.
									Name:  core.Resources_CPU,
									Value: "10Mi",
								},
							},
						},
					},
				},
			},
		},
	}

	errors.SetConfig(errors.Config{IncludeSource: true})
	errs := errors.NewCompileErrors()

	// mustBuild asserts that the node compiles cleanly and returns the spec.
	mustBuild := func(n common.Node, errs errors.CompileErrors) *v1alpha1.NodeSpec {
		spec, ok := buildNodeSpec(n.GetCoreNode(), tasks, errs)
		assert.False(t, errs.HasErrors())
		assert.True(t, ok)
		assert.NotNil(t, spec)

		if errs.HasErrors() {
			assert.Fail(t, errs.Error())
		}

		return spec
	}

	t.Run("Task", func(t *testing.T) {
		n.Node.Target = &core.Node_TaskNode{
			TaskNode: &core.TaskNode{
				Reference: &core.TaskNode_ReferenceId{
					ReferenceId: &core.Identifier{Name: "ref_1"},
				},
			},
		}

		mustBuild(n, errs.NewScope())
	})

	t.Run("Task with resources", func(t *testing.T) {
		// Verifies the task template's resource requests are propagated to the node spec.
		expectedCPU := resource.MustParse("10Mi")
		n.Node.Target = &core.Node_TaskNode{
			TaskNode: &core.TaskNode{
				Reference: &core.TaskNode_ReferenceId{
					ReferenceId: &core.Identifier{Name: "ref_2"},
				},
			},
		}

		spec := mustBuild(n, errs.NewScope())
		assert.NotNil(t, spec.Resources)
		assert.NotNil(t, spec.Resources.Requests.Cpu())
		assert.Equal(t, expectedCPU.Value(), spec.Resources.Requests.Cpu().Value())
	})

	t.Run("LaunchPlanRef", func(t *testing.T) {
		n.Node.Target = &core.Node_WorkflowNode{
			WorkflowNode: &core.WorkflowNode{
				Reference: &core.WorkflowNode_LaunchplanRef{
					LaunchplanRef: &core.Identifier{Name: "ref_1"},
				},
			},
		}

		mustBuild(n, errs.NewScope())
	})

	t.Run("Workflow", func(t *testing.T) {
		n.subWF = createSampleMockWorkflow()
		n.Node.Target = &core.Node_WorkflowNode{
			WorkflowNode: &core.WorkflowNode{
				Reference: &core.WorkflowNode_SubWorkflowRef{
					SubWorkflowRef: n.subWF.GetCoreWorkflow().Template.Id,
				},
			},
		}

		mustBuild(n, errs.NewScope())
	})

	t.Run("Branch", func(t *testing.T) {
		// A branch whose single case compares 123 == 123 and whose default is an error.
		n.Node.Target = &core.Node_BranchNode{
			BranchNode: &core.BranchNode{
				IfElse: &core.IfElseBlock{
					Other: []*core.IfBlock{},
					Default: &core.IfElseBlock_Error{
						Error: &core.Error{
							Message: "failed",
						},
					},
					Case: &core.IfBlock{
						ThenNode: &core.Node{
							Target: &core.Node_TaskNode{
								TaskNode: &core.TaskNode{
									Reference: &core.TaskNode_ReferenceId{
										ReferenceId: &core.Identifier{Name: "ref_1"},
									},
								},
							},
						},
						Condition: &core.BooleanExpression{
							Expr: &core.BooleanExpression_Comparison{
								Comparison: &core.ComparisonExpression{
									Operator: core.ComparisonExpression_EQ,
									LeftValue: &core.Operand{
										Val: &core.Operand_Primitive{
											Primitive: &core.Primitive{
												Value: &core.Primitive_Integer{
													Integer: 123,
												},
											},
										},
									},
									RightValue: &core.Operand{
										Val: &core.Operand_Primitive{
											Primitive: &core.Primitive{
												Value: &core.Primitive_Integer{
													Integer: 123,
												},
											},
										},
									},
								},
							},
						},
					},
				},
			},
		}

		mustBuild(n, errs.NewScope())
	})

}
diff --git a/pkg/compiler/transformers/k8s/utils.go b/pkg/compiler/transformers/k8s/utils.go
new file mode 100644
index 000000000..1836944ea
--- /dev/null
+++ b/pkg/compiler/transformers/k8s/utils.go
package k8s

import (
	"math"

	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
)

// refInt returns a pointer to a copy of i.
func refInt(i int) *int {
	return &i
}

// refStr returns a pointer to a copy of s.
func refStr(s string) *string {
	return &s
}

// computeRetryStrategy derives the retry strategy for a node. Node-level
// retries take precedence over task-level retries; MinAttempts is retries+1
// (the first attempt plus the allowed retries). Returns nil when neither is set.
func computeRetryStrategy(n *core.Node, t *core.TaskTemplate) *v1alpha1.RetryStrategy {
	if n.GetMetadata() != nil && n.GetMetadata().GetRetries() != nil {
		return &v1alpha1.RetryStrategy{
			MinAttempts: refInt(int(n.GetMetadata().GetRetries().Retries + 1)),
		}
	}

	if t != nil && t.GetMetadata() != nil && t.GetMetadata().GetRetries() != nil {
		return &v1alpha1.RetryStrategy{
MinAttempts: refInt(int(t.GetMetadata().GetRetries().Retries + 1)),
		}
	}

	return nil
}

// computeActiveDeadlineSeconds picks the timeout for a node: node metadata
// wins over task metadata; nil when neither specifies a timeout.
// NOTE(review): this returns a pointer into the proto message's Seconds field,
// so callers share storage with the source message.
func computeActiveDeadlineSeconds(n *core.Node, t *core.TaskTemplate) *int64 {
	if n.GetMetadata() != nil && n.GetMetadata().Timeout != nil {
		return &n.GetMetadata().Timeout.Seconds
	}

	if t != nil && t.GetMetadata() != nil && t.GetMetadata().Timeout != nil {
		return &t.GetMetadata().Timeout.Seconds
	}

	return nil
}

// getResources returns the container resources of the task, or nil if the
// task (or its container target) is absent.
func getResources(task *core.TaskTemplate) *core.Resources {
	if task == nil {
		return nil
	}

	if task.GetContainer() == nil {
		return nil
	}

	return task.GetContainer().Resources
}

// toAliasValueArray wraps core aliases in their v1alpha1 counterparts;
// preserves nil for a nil input.
func toAliasValueArray(aliases []*core.Alias) []v1alpha1.Alias {
	if aliases == nil {
		return nil
	}

	res := make([]v1alpha1.Alias, 0, len(aliases))
	for _, alias := range aliases {
		res = append(res, v1alpha1.Alias{Alias: *alias})
	}

	return res
}

// toBindingValueArray wraps core bindings in their v1alpha1 counterparts;
// preserves nil for a nil input.
func toBindingValueArray(bindings []*core.Binding) []*v1alpha1.Binding {
	if bindings == nil {
		return nil
	}

	res := make([]*v1alpha1.Binding, 0, len(bindings))
	for _, binding := range bindings {
		res = append(res, &v1alpha1.Binding{Binding: binding})
	}

	return res
}

// minInt returns the smaller of i and j.
// NOTE(review): goes through float64 via math.Min, which is only exact for
// |values| < 2^53 — fine for the string-length uses in this package.
func minInt(i, j int) int {
	return int(math.Min(float64(i), float64(j)))
}
diff --git a/pkg/compiler/transformers/k8s/utils_test.go b/pkg/compiler/transformers/k8s/utils_test.go
new file mode 100644
index 000000000..371b2531a
--- /dev/null
+++ b/pkg/compiler/transformers/k8s/utils_test.go
package k8s

import (
	"testing"

	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/stretchr/testify/assert"
)

func TestComputeRetryStrategy(t *testing.T) {

	// Table verifies precedence: node retries win over task retries, and
	// MinAttempts is always retries+1.
	tests := []struct {
		name            string
		nodeRetries     int
		taskRetries     int
		expectedRetries int
	}{
		{"node-only", 1, 0, 2},
		{"task-only", 0, 1, 2},
		{"node-task", 2, 3, 3},
		{"no-retries", 0, 0, 0},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {

			var node *core.Node
			if test.nodeRetries
!= 0 {
				node = &core.Node{
					Metadata: &core.NodeMetadata{
						Retries: &core.RetryStrategy{
							Retries: uint32(test.nodeRetries),
						},
					},
				}
			}

			var tmpl *core.TaskTemplate
			if test.taskRetries != 0 {
				tmpl = &core.TaskTemplate{
					Metadata: &core.TaskMetadata{
						Retries: &core.RetryStrategy{
							Retries: uint32(test.taskRetries),
						},
					},
				}
			}

			r := computeRetryStrategy(node, tmpl)
			if test.expectedRetries != 0 {
				assert.NotNil(t, r)
				assert.Equal(t, test.expectedRetries, *r.MinAttempts)
			} else {
				assert.Nil(t, r)
			}
		})
	}

}
diff --git a/pkg/compiler/transformers/k8s/workflow.go b/pkg/compiler/transformers/k8s/workflow.go
new file mode 100644
index 000000000..868e36bbf
--- /dev/null
+++ b/pkg/compiler/transformers/k8s/workflow.go
// This package converts the output of the compiler into a K8s resource for propeller to execute.
package k8s

import (
	"fmt"
	"strings"

	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
	"github.com/lyft/flytepropeller/pkg/compiler/common"
	"github.com/lyft/flytepropeller/pkg/compiler/errors"
	v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// Labels stamped onto the generated FlyteWorkflow CRD object.
const ExecutionIDLabel = "execution-id"
const WorkflowIDLabel = "workflow-id"

// requiresInputs reports whether the workflow template declares at least one
// input variable (guarding each level of the interface against nil).
func requiresInputs(w *core.WorkflowTemplate) bool {
	if w == nil || w.GetInterface() == nil || w.GetInterface().GetInputs() == nil ||
		w.GetInterface().GetInputs().Variables == nil {

		return false
	}

	return len(w.GetInterface().GetInputs().Variables) > 0
}

// WorkflowIDAsString renders an identifier as "project:domain:name" (version
// is intentionally omitted; empty components are kept as empty segments).
// NOTE(review): strings.Builder's Write* methods always return a nil error per
// the Go docs, so the error branches below are unreachable dead code.
func WorkflowIDAsString(id *core.Identifier) string {
	b := strings.Builder{}
	_, err := b.WriteString(id.Project)
	if err != nil {
		return ""
	}

	_, err = b.WriteRune(':')
	if err != nil {
		return ""
	}

	_, err = b.WriteString(id.Domain)
	if err != nil {
		return ""
	}

	_, err = b.WriteRune(':')
	if err != nil {
		return ""
	}

	_, err = b.WriteString(id.Name)
	if err != nil {
		return ""
	}

	return b.String()
}

// buildFlyteWorkflowSpec converts one compiled workflow into a WorkflowSpec:
// nodes, failure node, connections, and output bindings. ok is false when any
// scoped error was collected.
func buildFlyteWorkflowSpec(wf *core.CompiledWorkflow, tasks []*core.CompiledTask, errs errors.CompileErrors) (
	spec *v1alpha1.WorkflowSpec, ok bool) {
	var failureN *v1alpha1.NodeSpec
	if n := wf.Template.GetFailureNode(); n != nil {
		// Errors from the failure node are collected into errs by the scope.
		failureN, _ = buildNodeSpec(n, tasks, errs.NewScope())
	}

	nodes, _ := buildNodes(wf.Template.GetNodes(), tasks, errs.NewScope())

	if errs.HasErrors() {
		return nil, !errs.HasErrors()
	}

	outputBindings := make([]*v1alpha1.Binding, 0, len(wf.Template.Outputs))
	for _, b := range wf.Template.Outputs {
		outputBindings = append(outputBindings, &v1alpha1.Binding{
			Binding: b,
		})
	}

	var outputs *v1alpha1.OutputVarMap
	if wf.Template.GetInterface() != nil {
		outputs = &v1alpha1.OutputVarMap{VariableMap: wf.Template.GetInterface().Outputs}
	} else {
		outputs = &v1alpha1.OutputVarMap{VariableMap: &core.VariableMap{}}
	}

	return &v1alpha1.WorkflowSpec{
		ID:             WorkflowIDAsString(wf.Template.Id),
		OnFailure:      failureN,
		Nodes:          nodes,
		Connections:    buildConnections(wf),
		Outputs:        outputs,
		OutputBindings: outputBindings,
	}, !errs.HasErrors()
}

// withSeparatorIfNotEmpty appends a trailing '-' to non-empty values so parts
// can be concatenated without doubling separators.
func withSeparatorIfNotEmpty(value string) string {
	if len(value) > 0 {
		return fmt.Sprintf("%v-", value)
	}

	return ""
}

// generateName computes the k8s object name/generateName and execution label.
// With an execID, the execution name is used verbatim; otherwise a
// "project-domain-name" prefix (truncated to 32 chars) feeds GenerateName.
func generateName(wfID *core.Identifier, execID *core.WorkflowExecutionIdentifier) (
	name string, generateName string, label string, err error) {

	if execID != nil {
		return execID.Name, "", execID.Name, nil
	} else if wfID != nil {
		wid := fmt.Sprintf("%v%v%v",
			withSeparatorIfNotEmpty(wfID.Project),
			withSeparatorIfNotEmpty(wfID.Domain),
			wfID.Name,
		)

		// TODO: this is a hack until we figure out how to restrict generated names. K8s has a limitation of 63 chars
		wid = wid[:minInt(32, len(wid))]
		return "", fmt.Sprintf("%v-", wid), wid, nil
	} else {
		return "", "", "", fmt.Errorf("expected param not set. wfID or execID must be non-nil values")
	}
}

// Builds v1alpha1.FlyteWorkflow resource.
Returned error, if not nil, is of type errors.CompilerErrors. +func BuildFlyteWorkflow(wfClosure *core.CompiledWorkflowClosure, inputs *core.LiteralMap, + executionID *core.WorkflowExecutionIdentifier, namespace string) (*v1alpha1.FlyteWorkflow, error) { + + errs := errors.NewCompileErrors() + if wfClosure == nil { + errs.Collect(errors.NewValueRequiredErr("root", "wfClosure")) + return nil, errs + } + + primarySpec, _ := buildFlyteWorkflowSpec(wfClosure.Primary, wfClosure.Tasks, errs.NewScope()) + subwfs := make(map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec, len(wfClosure.SubWorkflows)) + for _, subWf := range wfClosure.SubWorkflows { + spec, _ := buildFlyteWorkflowSpec(wfClosure.Primary, wfClosure.Tasks, errs.NewScope()) + subwfs[subWf.Template.Id.String()] = spec + } + + wf := wfClosure.Primary.Template + tasks := wfClosure.Tasks + // Fill in inputs in the start node. + if inputs != nil { + if ok := validateInputs(common.StartNodeID, wf.GetInterface(), *inputs, errs.NewScope()); !ok { + return nil, errs + } + } else if requiresInputs(wf) { + errs.Collect(errors.NewValueRequiredErr("root", "inputs")) + return nil, errs + } + + obj := &v1alpha1.FlyteWorkflow{ + TypeMeta: v1.TypeMeta{ + Kind: v1alpha1.FlyteWorkflowKind, + APIVersion: v1alpha1.SchemeGroupVersion.String(), + }, + ObjectMeta: v1.ObjectMeta{ + Namespace: namespace, + Labels: map[string]string{}, + }, + Inputs: &v1alpha1.Inputs{LiteralMap: inputs}, + WorkflowSpec: primarySpec, + SubWorkflows: subwfs, + Tasks: buildTasks(tasks, errs.NewScope()), + } + + var err error + obj.ObjectMeta.Name, obj.ObjectMeta.GenerateName, obj.ObjectMeta.Labels[ExecutionIDLabel], err = + generateName(wf.GetId(), executionID) + + if err != nil { + errs.Collect(errors.NewWorkflowBuildError(err)) + } + + if obj.Nodes == nil || obj.Connections.DownstreamEdges == nil { + // If we come here, we'd better have an error generated earlier. Otherwise, add one to make sure build fails. 
+ if !errs.HasErrors() { + errs.Collect(errors.NewWorkflowBuildError(fmt.Errorf("failed to build workflow for unknown reason." + + " Make sure to pass this workflow through the compiler first"))) + } + } else if startingNodes, err := obj.FromNode(v1alpha1.StartNodeID); err == nil && len(startingNodes) == 0 { + errs.Collect(errors.NewWorkflowHasNoEntryNodeErr(wf.GetId().String())) + } else if err != nil { + errs.Collect(errors.NewWorkflowBuildError(err)) + } + + if errs.HasErrors() { + return nil, errs + } + + return obj, nil +} + +func toMapOfLists(connections map[string]*core.ConnectionSet_IdList) map[string][]string { + res := make(map[string][]string, len(connections)) + for key, val := range connections { + res[key] = val.Ids + } + + return res +} + +func buildConnections(w *core.CompiledWorkflow) v1alpha1.Connections { + res := v1alpha1.Connections{} + res.DownstreamEdges = toMapOfLists(w.GetConnections().GetDownstream()) + res.UpstreamEdges = toMapOfLists(w.GetConnections().GetUpstream()) + return res +} diff --git a/pkg/compiler/transformers/k8s/workflow_test.go b/pkg/compiler/transformers/k8s/workflow_test.go new file mode 100644 index 000000000..098dab245 --- /dev/null +++ b/pkg/compiler/transformers/k8s/workflow_test.go @@ -0,0 +1,238 @@ +package k8s + +import ( + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" +) + +func createSampleMockWorkflow() *mockWorkflow { + return &mockWorkflow{ + tasks: common.TaskIndex{ + "task_1": &mockTask{ + task: &core.TaskTemplate{ + Id: &core.Identifier{Name: "task_1"}, + }, + }, + }, + nodes: common.NodeIndex{ + "node_1": &mockNode{ + upstream: []string{common.StartNodeID}, + inputs: []*core.Binding{}, + task: &mockTask{}, + id: "node_1", + Node: createNodeWithTask(), + }, + 
common.StartNodeID: &mockNode{
				id:   common.StartNodeID,
				Node: &core.Node{},
			},
		},
		//failureNode: &mockNode{
		//	id: "node_1",
		//},
		downstream: common.StringAdjacencyList{
			common.StartNodeID: sets.NewString("node_1"),
		},
		upstream: common.StringAdjacencyList{
			"node_1": sets.NewString(common.StartNodeID),
		},
		CompiledWorkflow: &core.CompiledWorkflow{
			Template: &core.WorkflowTemplate{
				Id: &core.Identifier{Name: "wf_1"},
				Nodes: []*core.Node{
					createNodeWithTask(),
					{
						Id: common.StartNodeID,
					},
				},
			},
			Connections: &core.ConnectionSet{
				Downstream: map[string]*core.ConnectionSet_IdList{
					common.StartNodeID: {
						Ids: []string{"node_1"},
					},
				},
			},
		},
	}
}

func TestWorkflowIDAsString(t *testing.T) {
	// Version is never part of the rendered ID; empty components collapse to
	// empty segments between the ':' separators.
	assert.Equal(t, "project:domain:name", WorkflowIDAsString(&core.Identifier{
		Project: "project",
		Domain:  "domain",
		Name:    "name",
		Version: "v1",
	}))

	assert.Equal(t, ":domain:name", WorkflowIDAsString(&core.Identifier{
		Domain:  "domain",
		Name:    "name",
		Version: "v1",
	}))

	assert.Equal(t, "project::name", WorkflowIDAsString(&core.Identifier{
		Project: "project",
		Name:    "name",
		Version: "v1",
	}))

	assert.Equal(t, "project:domain:name", WorkflowIDAsString(&core.Identifier{
		Project: "project",
		Domain:  "domain",
		Name:    "name",
	}))
}

func TestBuildFlyteWorkflow(t *testing.T) {
	w := createSampleMockWorkflow()

	errors.SetConfig(errors.Config{IncludeSource: true})
	wf, err := BuildFlyteWorkflow(
		&core.CompiledWorkflowClosure{
			Primary: w.GetCoreWorkflow(),
			Tasks: []*core.CompiledTask{
				{
					Template: &core.TaskTemplate{
						Id: &core.Identifier{Name: "ref_1"},
					},
				},
			},
		},
		nil, nil, "")
	assert.NoError(t, err)
	assert.NotNil(t, wf)
	errors.SetConfig(errors.Config{})
}

func TestBuildFlyteWorkflow_withInputs(t *testing.T) {
	w := createSampleMockWorkflow()

	startNode := w.GetNodes()[common.StartNodeID].(*mockNode)
	vars := []*core.Variable{
		{
			Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
		},
		{
			Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}},
		},
	}

	// The workflow declares inputs x (int) and y (string)...
	w.Template.Interface = &core.TypedInterface{
		Inputs: &core.VariableMap{
			Variables: map[string]*core.Variable{
				"x": vars[0],
				"y": vars[1],
			},
		},
	}

	// ...which the start node exposes as outputs to downstream nodes.
	startNode.iface = &core.TypedInterface{
		Outputs: &core.VariableMap{
			Variables: map[string]*core.Variable{
				"x": vars[0],
				"y": vars[1],
			},
		},
	}

	intLiteral, err := utils.MakePrimitiveLiteral(123)
	assert.NoError(t, err)
	stringLiteral, err := utils.MakePrimitiveLiteral("hello")
	assert.NoError(t, err)
	inputs := &core.LiteralMap{
		Literals: map[string]*core.Literal{
			"x": intLiteral,
			"y": stringLiteral,
		},
	}

	errors.SetConfig(errors.Config{IncludeSource: true})
	wf, err := BuildFlyteWorkflow(
		&core.CompiledWorkflowClosure{
			Primary: w.GetCoreWorkflow(),
			Tasks: []*core.CompiledTask{
				{
					Template: &core.TaskTemplate{
						Id: &core.Identifier{Name: "ref_1"},
					},
				},
			},
		},
		inputs, nil, "")
	assert.NoError(t, err)
	assert.NotNil(t, wf)
	errors.SetConfig(errors.Config{})

	assert.Equal(t, 2, len(wf.Inputs.Literals))
	assert.Equal(t, int64(123), wf.Inputs.Literals["x"].GetScalar().GetPrimitive().GetInteger())
}

func TestGenerateName(t *testing.T) {
	t.Run("Invalid params", func(t *testing.T) {
		_, _, _, err := generateName(nil, nil)
		assert.Error(t, err)
	})

	t.Run("wfID full", func(t *testing.T) {
		name, generateName, _, err := generateName(&core.Identifier{
			Name:    "myworkflow",
			Project: "myproject",
			Domain:  "development",
		}, nil)

		assert.NoError(t, err)
		assert.Empty(t, name)
		assert.Equal(t, "myproject-development-myworkflow-", generateName)
	})

	t.Run("wfID missing project domain", func(t *testing.T) {
		name, generateName, _, err := generateName(&core.Identifier{
			Name: "myworkflow",
		}, nil)

		assert.NoError(t, err)
assert.Empty(t, name) + assert.Equal(t, "myworkflow-", generateName) + }) + + t.Run("wfID too long", func(t *testing.T) { + name, generateName, _, err := generateName(&core.Identifier{ + Name: "workflowsomethingsomethingsomething", + Project: "myproject", + Domain: "development", + }, nil) + + assert.NoError(t, err) + assert.Empty(t, name) + assert.Equal(t, "myproject-development-workflowso-", generateName) + }) + + t.Run("execID full", func(t *testing.T) { + name, generateName, _, err := generateName(nil, &core.WorkflowExecutionIdentifier{ + Name: "myexecution", + Project: "myproject", + Domain: "development", + }) + + assert.NoError(t, err) + assert.Empty(t, generateName) + assert.Equal(t, "myexecution", name) + }) + + t.Run("execID missing project domain", func(t *testing.T) { + name, generateName, _, err := generateName(nil, &core.WorkflowExecutionIdentifier{ + Name: "myexecution", + }) + + assert.NoError(t, err) + assert.Empty(t, generateName) + assert.Equal(t, "myexecution", name) + }) +} diff --git a/pkg/compiler/typing/variable.go b/pkg/compiler/typing/variable.go new file mode 100644 index 000000000..958d49e3c --- /dev/null +++ b/pkg/compiler/typing/variable.go @@ -0,0 +1,36 @@ +package typing + +import ( + "fmt" + "regexp" + "strconv" +) + +var arrayVarMatcher = regexp.MustCompile(`(\[(?P\d+)\]\.)?(?P\w+)`) + +type Variable struct { + Name string + Index *int +} + +// Parses var names +func ParseVarName(varName string) (v Variable, err error) { + allMatches := arrayVarMatcher.FindAllStringSubmatch(varName, -1) + if len(allMatches) != 1 { + return Variable{}, fmt.Errorf("unexpected number of matches [%v]", len(allMatches)) + } + + if len(allMatches[0]) != 4 { + return Variable{}, fmt.Errorf("unexpected number of groups [%v]", len(allMatches[0])) + } + + res := Variable{} + if len(allMatches[0][2]) > 0 { + index, convErr := strconv.Atoi(allMatches[0][2]) + err = convErr + res.Index = &index + } + + res.Name = allMatches[0][3] + return res, err +} diff --git 
a/pkg/compiler/utils.go b/pkg/compiler/utils.go
new file mode 100755
index 000000000..1ff24fc1b
--- /dev/null
+++ b/pkg/compiler/utils.go
package compiler

import (
	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/lyft/flytepropeller/pkg/compiler/common"
	"k8s.io/apimachinery/pkg/util/sets"
)

// toInterfaceProviderMap indexes interface providers by their string ID.
func toInterfaceProviderMap(tasks []common.InterfaceProvider) map[string]common.InterfaceProvider {
	res := make(map[string]common.InterfaceProvider, len(tasks))
	for _, task := range tasks {
		res[task.GetID().String()] = task
	}

	return res
}

// toSlice converts a string set to a slice (iteration order is unspecified).
func toSlice(s sets.String) []string {
	res := make([]string, 0, len(s))
	for str := range s {
		res = append(res, str)
	}

	return res
}

// toNodeIdsSet collects the keys of a node index into a set.
func toNodeIdsSet(nodes common.NodeIndex) sets.String {
	res := sets.NewString()
	for nodeID := range nodes {
		res.Insert(nodeID)
	}

	return res
}

// Runs a depth-first coreWorkflow traversal to detect any cycles in the coreWorkflow. It produces the first cycle found, as well as
// all visited nodes and a boolean indicating whether or not it found a cycle.
// NOTE(review): when a cycle IS detected, the second return value is the
// in-progress `visiting` set rather than the full `visited` set — confirm
// callers expect that; the doc comment above suggests otherwise.
func detectCycle(startNode string, neighbors func(nodeId string) sets.String) (cycle []common.NodeID, visited sets.String,
	detected bool) {

	// This is a set of nodes that were ever visited.
	visited = sets.NewString()
	// This is a set of in-progress visiting nodes.
	visiting := sets.NewString()
	var detector func(nodeId string) ([]common.NodeID, bool)
	detector = func(nodeId string) ([]common.NodeID, bool) {
		// Revisiting a node that is still on the current DFS path => cycle.
		if visiting.Has(nodeId) {
			return []common.NodeID{}, true
		}

		visiting.Insert(nodeId)
		visited.Insert(nodeId)

		for nextID := range neighbors(nodeId) {
			if path, detected := detector(nextID); detected {
				// Prepend the current hop to reconstruct the cycle path.
				return append([]common.NodeID{nextID}, path...), true
			}
		}

		visiting.Delete(nodeId)

		return []common.NodeID{}, false
	}

	if path, detected := detector(startNode); detected {
		return append([]common.NodeID{startNode}, path...), visiting, true
	}

	return []common.NodeID{}, visited, false
}

// toCompiledWorkflows wraps raw workflow templates in CompiledWorkflow shells.
func toCompiledWorkflows(wfs ...*core.WorkflowTemplate) []*core.CompiledWorkflow {
	compiledSubWfs := make([]*core.CompiledWorkflow, 0, len(wfs))
	for _, wf := range wfs {
		compiledSubWfs = append(compiledSubWfs, &core.CompiledWorkflow{Template: wf})
	}

	return compiledSubWfs
}
diff --git a/pkg/compiler/utils_test.go b/pkg/compiler/utils_test.go
new file mode 100644
index 000000000..6d7065f7a
--- /dev/null
+++ b/pkg/compiler/utils_test.go
package compiler

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"k8s.io/apimachinery/pkg/util/sets"
)

// neighbors adapts an adjacency-list map into the lookup function detectCycle expects.
func neighbors(adjList map[string][]string) func(nodeId string) sets.String {
	return func(nodeId string) sets.String {
		if lst, found := adjList[nodeId]; found {
			return sets.NewString(lst...)
		}

		return sets.NewString()
	}
}

// uniqueNodesCount counts distinct node IDs appearing as keys or values.
func uniqueNodesCount(adjList map[string][]string) int {
	uniqueNodeIds := sets.NewString()
	for key, value := range adjList {
		uniqueNodeIds.Insert(key)
		uniqueNodeIds.Insert(value...)
+ } + + return uniqueNodeIds.Len() +} + +func assertNoCycle(t *testing.T, startNode string, adjList map[string][]string) { + cycle, visited, detected := detectCycle(startNode, neighbors(adjList)) + assert.False(t, detected) + assert.Equal(t, uniqueNodesCount(adjList), len(visited)) + assert.Equal(t, 0, len(cycle)) +} + +func assertCycle(t *testing.T, startNode string, adjList map[string][]string) { + cycle, _, detected := detectCycle(startNode, neighbors(adjList)) + assert.True(t, detected) + assert.NotEqual(t, 0, len(cycle)) + t.Logf("Cycle: %v", strings.Join(cycle, ",")) +} + +func TestDetectCycle(t *testing.T) { + t.Run("Linear", func(t *testing.T) { + linear := map[string][]string{ + "1": {"2"}, + "2": {"3"}, + "3": {"4"}, + } + + assertNoCycle(t, "1", linear) + }) + + t.Run("Cycle", func(t *testing.T) { + cyclic := map[string][]string{ + "1": {"2", "3"}, + "2": {"3"}, + "3": {"1"}, + } + + assertCycle(t, "1", cyclic) + }) +} diff --git a/pkg/compiler/validators/bindings.go b/pkg/compiler/validators/bindings.go new file mode 100644 index 000000000..10cf0c705 --- /dev/null +++ b/pkg/compiler/validators/bindings.go @@ -0,0 +1,105 @@ +package validators + +import ( + "reflect" + + "github.com/lyft/flytepropeller/pkg/compiler/typing" + + flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "k8s.io/apimachinery/pkg/util/sets" +) + +func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, binding *flyte.BindingData, expectedType *flyte.LiteralType, errs errors.CompileErrors) ( + []c.NodeID, bool) { + + switch binding.GetValue().(type) { + case *flyte.BindingData_Collection: + if expectedType.GetCollectionType() != nil { + allNodeIds := make([]c.NodeID, 0, len(binding.GetMap().GetBindings())) + for _, v := range binding.GetCollection().GetBindings() { + if nodeIds, ok := validateBinding(w, nodeID, nodeParam, v, 
expectedType.GetCollectionType(), errs.NewScope()); ok {
					allNodeIds = append(allNodeIds, nodeIds...)
				}
			}
			return allNodeIds, !errs.HasErrors()
		}
		errs.Collect(errors.NewMismatchingBindingsErr(nodeID, nodeParam, expectedType.String(), binding.GetCollection().String()))
	case *flyte.BindingData_Map:
		if expectedType.GetMapValueType() != nil {
			allNodeIds := make([]c.NodeID, 0, len(binding.GetMap().GetBindings()))
			for _, v := range binding.GetMap().GetBindings() {
				// Recurse with the map's value type; keys are not validated here.
				if nodeIds, ok := validateBinding(w, nodeID, nodeParam, v, expectedType.GetMapValueType(), errs.NewScope()); ok {
					allNodeIds = append(allNodeIds, nodeIds...)
				}
			}
			return allNodeIds, !errs.HasErrors()
		}

		errs.Collect(errors.NewMismatchingBindingsErr(nodeID, nodeParam, expectedType.String(), binding.GetMap().String()))
	case *flyte.BindingData_Promise:
		// A promise binds this param to an upstream node's output variable.
		if upNode, found := validateNodeID(w, binding.GetPromise().NodeId, errs.NewScope()); found {
			v, err := typing.ParseVarName(binding.GetPromise().GetVar())
			if err != nil {
				errs.Collect(errors.NewSyntaxError(nodeID, binding.GetPromise().GetVar(), err))
				return nil, !errs.HasErrors()
			}

			if param, paramFound := validateOutputVar(upNode, v.Name, errs.NewScope()); paramFound {
				if AreTypesCastable(param.Type, expectedType) {
					// Normalize the promise's node ID and report the dependency.
					binding.GetPromise().NodeId = upNode.GetId()
					return []c.NodeID{binding.GetPromise().NodeId}, true
				}
				errs.Collect(errors.NewMismatchingTypesErr(nodeID, binding.GetPromise().Var, param.Type.String(), expectedType.String()))
			}
		}
	case *flyte.BindingData_Scalar:
		literalType := literalTypeForScalar(binding.GetScalar())
		if literalType == nil {
			errs.Collect(errors.NewUnrecognizedValueErr(nodeID, reflect.TypeOf(binding.GetScalar().GetValue()).String()))
		}

		if !AreTypesCastable(literalType, expectedType) {
			errs.Collect(errors.NewMismatchingTypesErr(nodeID, nodeParam, literalType.String(), expectedType.String()))
		}

		// Scalars introduce no upstream dependencies.
		return []c.NodeID{}, !errs.HasErrors()
	default:
		errs.Collect(errors.NewUnrecognizedValueErr(nodeID, reflect.TypeOf(binding.GetValue()).String()))
	}

	return nil, !errs.HasErrors()
}

// ValidateBindings checks a node's bindings against the declared params:
// every bound variable must exist, be bound exactly once, type-check, and
// every declared param must be bound. Valid promise bindings also add
// execution edges from the upstream nodes to this node.
func ValidateBindings(w c.WorkflowBuilder, node c.Node, bindings []*flyte.Binding, params *flyte.VariableMap,
	errs errors.CompileErrors) (ok bool) {

	providedBindings := sets.NewString()
	for _, binding := range bindings {
		if param, ok := findVariableByName(params, binding.GetVar()); !ok {
			errs.Collect(errors.NewVariableNameNotFoundErr(node.GetId(), node.GetId(), binding.GetVar()))
		} else if binding.GetBinding() == nil {
			errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Binding"))
		} else if providedBindings.Has(binding.GetVar()) {
			errs.Collect(errors.NewParameterBoundMoreThanOnceErr(node.GetId(), binding.GetVar()))
		} else {
			providedBindings.Insert(binding.GetVar())
			if upstreamNodes, bindingOk := validateBinding(w, node.GetId(), binding.GetVar(), binding.GetBinding(), param.Type, errs.NewScope()); bindingOk {
				for _, upNode := range upstreamNodes {
					// Add implicit Edges
					w.AddExecutionEdge(upNode, node.GetId())
				}
			}
		}
	}

	// If we missed binding some params, add errors
	for paramName := range params.Variables {
		if !providedBindings.Has(paramName) {
			errs.Collect(errors.NewParameterNotBoundErr(node.GetId(), paramName))
		}
	}

	return !errs.HasErrors()
}
diff --git a/pkg/compiler/validators/branch.go b/pkg/compiler/validators/branch.go
new file mode 100644
index 000000000..a85131078
--- /dev/null
+++ b/pkg/compiler/validators/branch.go
package validators

import (
	flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	c "github.com/lyft/flytepropeller/pkg/compiler/common"
	"github.com/lyft/flytepropeller/pkg/compiler/errors"
	"k8s.io/apimachinery/pkg/util/sets"
)

// validateBranchInterface computes the effective interface of a branch node:
// the intersection of the interfaces of all then/else branches. Returns nil
// and collects errors when required pieces are missing or interfaces mismatch.
func validateBranchInterface(w c.WorkflowBuilder, node c.NodeBuilder, errs errors.CompileErrors) (iface *flyte.TypedInterface, ok bool) {
	if branch := node.GetBranchNode(); branch ==
nil {
		errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Branch"))
		return
	}

	if ifBlock := node.GetBranchNode().IfElse; ifBlock == nil {
		errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Branch.IfElse"))
		return
	}

	if ifCase := node.GetBranchNode().IfElse.Case; ifCase == nil {
		errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Branch.IfElse.Case"))
		return
	}

	if thenNode := node.GetBranchNode().IfElse.Case.ThenNode; thenNode == nil {
		errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Branch.IfElse.Case.ThenNode"))
		return
	}

	// Running intersections of parameter names across all branch cases.
	finalInputParameterNames := sets.NewString()
	finalOutputParameterNames := sets.NewString()

	var inputs map[string]*flyte.Variable
	var outputs map[string]*flyte.Variable
	inputsSet := sets.NewString()
	outputsSet := sets.NewString()

	// validateIfaceMatch compares a case's interface against the first case's
	// interface and narrows the name intersections accordingly.
	validateIfaceMatch := func(nodeId string, iface2 *flyte.TypedInterface, errsScope errors.CompileErrors) (match bool) {
		inputs2, inputs2Set := buildVariablesIndex(iface2.Inputs)
		validateVarsSetMatch(nodeId, inputs, inputs2, inputsSet, inputs2Set, errsScope.NewScope())
		finalInputParameterNames = finalInputParameterNames.Intersection(inputs2Set)

		outputs2, outputs2Set := buildVariablesIndex(iface2.Outputs)
		validateVarsSetMatch(nodeId, outputs, outputs2, outputsSet, outputs2Set, errsScope.NewScope())
		finalOutputParameterNames = finalOutputParameterNames.Intersection(outputs2Set)

		return !errsScope.HasErrors()
	}

	cases := make([]*flyte.IfBlock, 0, len(node.GetBranchNode().IfElse.Other)+1)
	caseBlock := node.GetBranchNode().IfElse.Case
	cases = append(cases, caseBlock)

	if otherCases := node.GetBranchNode().IfElse.Other; otherCases != nil {
		cases = append(cases, otherCases...)
	}

	for _, block := range cases {
		if block.ThenNode == nil {
			errs.Collect(errors.NewValueRequiredErr(node.GetId(), "IfElse.Case.ThenNode"))
			continue
		}

		n := w.NewNodeBuilder(block.ThenNode)
		if iface == nil {
			// if this is the first node to validate, just assume all other nodes will match the interface
			if iface, ok = ValidateUnderlyingInterface(w, n, errs.NewScope()); ok {
				inputs, inputsSet = buildVariablesIndex(iface.Inputs)
				finalInputParameterNames = finalInputParameterNames.Union(inputsSet)

				outputs, outputsSet = buildVariablesIndex(iface.Outputs)
				finalOutputParameterNames = finalOutputParameterNames.Union(outputsSet)
			}
		} else {
			if iface2, ok2 := ValidateUnderlyingInterface(w, n, errs.NewScope()); ok2 {
				validateIfaceMatch(n.GetId(), iface2, errs.NewScope())
			}
		}
	}

	if !errs.HasErrors() {
		// Keep only the parameters common to every case.
		iface = &flyte.TypedInterface{
			Inputs:  filterVariables(iface.Inputs, finalInputParameterNames),
			Outputs: filterVariables(iface.Outputs, finalOutputParameterNames),
		}
	} else {
		iface = nil
	}

	return iface, !errs.HasErrors()
}
diff --git a/pkg/compiler/validators/condition.go b/pkg/compiler/validators/condition.go
new file mode 100644
index 000000000..0c800a9b8
--- /dev/null
+++ b/pkg/compiler/validators/condition.go
package validators

import (
	"fmt"

	flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	c "github.com/lyft/flytepropeller/pkg/compiler/common"
	"github.com/lyft/flytepropeller/pkg/compiler/errors"
)

// validateOperand resolves the literal type of a comparison operand: a
// primitive's intrinsic type, or the declared type of a referenced input var.
func validateOperand(node c.NodeBuilder, paramName string, operand *flyte.Operand,
	errs errors.CompileErrors) (literalType *flyte.LiteralType, ok bool) {
	if operand == nil {
		errs.Collect(errors.NewValueRequiredErr(node.GetId(), paramName))
	} else if operand.GetPrimitive() != nil {
		// no validation
		literalType = literalTypeForPrimitive(operand.GetPrimitive())
	} else if operand.GetVar() != "" {
		if node.GetInterface() != nil {
			if param, paramOk := validateInputVar(node, operand.GetVar(), errs.NewScope()); paramOk {
				literalType = param.GetType()
			}
		}
	} else {
		errs.Collect(errors.NewValueRequiredErr(node.GetId(), fmt.Sprintf("%v.%v", paramName, "Val")))
	}

	return literalType, !errs.HasErrors()
}

// ValidateBooleanExpression recursively validates a branch condition:
// comparisons must have two operands of identical type; conjunctions are
// validated on both sides.
func ValidateBooleanExpression(node c.NodeBuilder, expr *flyte.BooleanExpression, errs errors.CompileErrors) (ok bool) {
	if expr == nil {
		errs.Collect(errors.NewBranchNodeHasNoCondition(node.GetId()))
	} else {
		if expr.GetComparison() != nil {
			op1Type, op1Valid := validateOperand(node, "RightValue",
				expr.GetComparison().GetRightValue(), errs.NewScope())
			op2Type, op2Valid := validateOperand(node, "LeftValue",
				expr.GetComparison().GetLeftValue(), errs.NewScope())
			if op1Valid && op2Valid {
				// Types are compared by their string form.
				if op1Type.String() != op2Type.String() {
					errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), "RightValue",
						op1Type.String(), op2Type.String()))
				}
			}
		} else if expr.GetConjunction() != nil {
			ValidateBooleanExpression(node, expr.GetConjunction().LeftExpression, errs.NewScope())
			ValidateBooleanExpression(node, expr.GetConjunction().RightExpression, errs.NewScope())
		} else {
			errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Expr"))
		}
	}

	return !errs.HasErrors()
}
diff --git a/pkg/compiler/validators/interface.go b/pkg/compiler/validators/interface.go
new file mode 100644
index 000000000..f5cc3b8bf
--- /dev/null
+++ b/pkg/compiler/validators/interface.go
package validators

import (
	"fmt"

	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	c "github.com/lyft/flytepropeller/pkg/compiler/common"
	"github.com/lyft/flytepropeller/pkg/compiler/errors"
)

// Validate interface has its required attributes set
func ValidateInterface(nodeID c.NodeID, iface *core.TypedInterface, errs errors.CompileErrors) (
	typedInterface *core.TypedInterface, ok bool) {

	if iface == nil {
		iface = &core.TypedInterface{}
	}

	// validate
InputsRef/OutputsRef parameters required attributes are set + if iface.Inputs != nil && iface.Inputs.Variables != nil { + validateVariables(nodeID, iface.Inputs, errs.NewScope()) + } else { + iface.Inputs = &core.VariableMap{Variables: map[string]*core.Variable{}} + } + + if iface.Outputs != nil && iface.Outputs.Variables != nil { + validateVariables(nodeID, iface.Outputs, errs.NewScope()) + } else { + iface.Outputs = &core.VariableMap{Variables: map[string]*core.Variable{}} + } + + return iface, !errs.HasErrors() +} + +// Validates underlying interface of a node and returns the effective Typed Interface. +func ValidateUnderlyingInterface(w c.WorkflowBuilder, node c.NodeBuilder, errs errors.CompileErrors) (iface *core.TypedInterface, ok bool) { + switch node.GetCoreNode().GetTarget().(type) { + case *core.Node_TaskNode: + if node.GetTaskNode().GetReferenceId() == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "TaskNode.ReferenceId")) + } else if task, taskOk := w.GetTask(*node.GetTaskNode().GetReferenceId()); taskOk { + iface = task.GetInterface() + if iface == nil { + // Default value for no interface is nil, initialize an empty interface + iface = &core.TypedInterface{ + Inputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + Outputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + } + } + } else { + errs.Collect(errors.NewTaskReferenceNotFoundErr(node.GetId(), node.GetTaskNode().GetReferenceId().String())) + } + case *core.Node_WorkflowNode: + if node.GetWorkflowNode().GetLaunchplanRef().String() == w.GetCoreWorkflow().Template.Id.String() { + iface = w.GetCoreWorkflow().Template.Interface + if iface == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "WorkflowNode.Interface")) + } + } else if node.GetWorkflowNode().GetLaunchplanRef() != nil { + if launchPlan, launchPlanOk := w.GetLaunchPlan(*node.GetWorkflowNode().GetLaunchplanRef()); launchPlanOk { + inputs := launchPlan.GetExpectedInputs() + if inputs == 
nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "WorkflowNode.ExpectedInputs")) + } + + outputs := launchPlan.GetExpectedOutputs() + if outputs == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "WorkflowNode.ExpectedOutputs")) + } + + // Compute exposed inputs as the union of all required inputs and any input overwritten by the node. + exposedInputs := map[string]*core.Variable{} + for name, p := range inputs.Parameters { + if p.GetRequired() { + exposedInputs[name] = p.Var + } else if _, found := findBindingByVariableName(node.GetInputs(), name); found { + exposedInputs[name] = p.Var + } + // else, the param has a default value and is not being overwritten by the node + } + + iface = &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: exposedInputs, + }, + Outputs: outputs, + } + } else { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr( + node.GetId(), + fmt.Sprintf("%v", node.GetWorkflowNode().GetLaunchplanRef()))) + } + } else if node.GetWorkflowNode().GetSubWorkflowRef() != nil { + if wf, wfOk := w.GetSubWorkflow(*node.GetWorkflowNode().GetSubWorkflowRef()); wfOk { + if wf.Template == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "WorkflowNode.Template")) + } else { + iface = wf.Template.Interface + if iface == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "WorkflowNode.Template.Interface")) + } + } + } else { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr( + node.GetId(), + fmt.Sprintf("%v", node.GetWorkflowNode().GetSubWorkflowRef()))) + } + } else { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr( + node.GetId(), + fmt.Sprintf("%v/%v", node.GetWorkflowNode().GetLaunchplanRef(), node.GetWorkflowNode().GetSubWorkflowRef()))) + } + case *core.Node_BranchNode: + iface, _ = validateBranchInterface(w, node, errs.NewScope()) + default: + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Target")) + } + + if iface != nil { + ValidateInterface(node.GetId(), iface, 
errs.NewScope()) + } + + if !errs.HasErrors() { + node.SetInterface(iface) + } + + return iface, !errs.HasErrors() +} diff --git a/pkg/compiler/validators/interface_test.go b/pkg/compiler/validators/interface_test.go new file mode 100644 index 000000000..f85d4f7d9 --- /dev/null +++ b/pkg/compiler/validators/interface_test.go @@ -0,0 +1,282 @@ +package validators + +import ( + "testing" + + "github.com/lyft/flyteidl/clients/go/coreutils" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/common/mocks" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestValidateInterface(t *testing.T) { + t.Run("Happy path", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ok := ValidateInterface( + c.NodeID("node1"), + &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + }, + errs.NewScope(), + ) + + assertNonEmptyInterface(t, iface, ok, errs) + }) + + t.Run("Empty Inputs/Outputs", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ok := ValidateInterface( + c.NodeID("node1"), + &core.TypedInterface{}, + errs.NewScope(), + ) + + assertNonEmptyInterface(t, iface, ok, errs) + }) + + t.Run("Empty Interface", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ok := ValidateInterface( + c.NodeID("node1"), + nil, + errs.NewScope(), + ) + + assertNonEmptyInterface(t, iface, ok, errs) + }) +} + +func assertNonEmptyInterface(t testing.TB, iface *core.TypedInterface, ifaceOk bool, errs errors.CompileErrors) { + assert.True(t, ifaceOk) + assert.NotNil(t, iface) + assert.False(t, errs.HasErrors()) + if !ifaceOk { + t.Fatal(errs) + } + + assert.NotNil(t, iface.Inputs) + assert.NotNil(t, iface.Inputs.Variables) + assert.NotNil(t, 
iface.Outputs) + assert.NotNil(t, iface.Outputs.Variables) +} + +func TestValidateUnderlyingInterface(t *testing.T) { + t.Run("Invalid empty node", func(t *testing.T) { + wfBuilder := mocks.WorkflowBuilder{} + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetCoreNode").Return(&core.Node{}) + nodeBuilder.On("GetId").Return("node_1") + errs := errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assert.False(t, ifaceOk) + assert.Nil(t, iface) + assert.True(t, errs.HasErrors()) + }) + + t.Run("Task Node", func(t *testing.T) { + task := mocks.Task{} + task.On("GetInterface").Return(nil) + + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetTask", mock.MatchedBy(func(id core.Identifier) bool { + return id.String() == (&core.Identifier{ + Name: "Task_1", + }).String() + })).Return(&task, true) + + taskNode := &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{ + Name: "Task_1", + }, + }, + } + + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetCoreNode").Return(&core.Node{ + Target: &core.Node_TaskNode{ + TaskNode: taskNode, + }, + }) + + nodeBuilder.On("GetTaskNode").Return(taskNode) + nodeBuilder.On("GetId").Return("node_1") + nodeBuilder.On("SetInterface", mock.Anything).Return() + + errs := errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + + t.Run("Workflow Node", func(t *testing.T) { + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_1", + }, + }, + }) + workflowNode := &core.WorkflowNode{ + Reference: &core.WorkflowNode_LaunchplanRef{ + LaunchplanRef: &core.Identifier{ + Name: "Ref_1", + }, + }, + } + + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetCoreNode").Return(&core.Node{ + Target: 
&core.Node_WorkflowNode{ + WorkflowNode: workflowNode, + }, + }) + + nodeBuilder.On("GetWorkflowNode").Return(workflowNode) + nodeBuilder.On("GetId").Return("node_1") + nodeBuilder.On("SetInterface", mock.Anything).Return() + nodeBuilder.On("GetInputs").Return([]*core.Binding{}) + + t.Run("Self", func(t *testing.T) { + errs := errors.NewCompileErrors() + _, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assert.False(t, ifaceOk) + + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_1", + }, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + }, + }, + }) + + errs = errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + + t.Run("LP_Ref", func(t *testing.T) { + lp := mocks.InterfaceProvider{} + lp.On("GetID").Return(&core.Identifier{Name: "Ref_1"}) + lp.On("GetExpectedInputs").Return(&core.ParameterMap{ + Parameters: map[string]*core.Parameter{ + "required": { + Var: &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + Behavior: &core.Parameter_Required{ + Required: true, + }, + }, + "default_value": { + Var: &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + Behavior: &core.Parameter_Default{ + Default: coreutils.MustMakeLiteral(5), + }, + }, + }, + }) + lp.On("GetExpectedOutputs").Return(&core.VariableMap{}) + + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_2", + }, + }, + }) + + 
wfBuilder.On("GetLaunchPlan", mock.Anything).Return(nil, false) + + errs := errors.NewCompileErrors() + _, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assert.False(t, ifaceOk) + + wfBuilder = mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_2", + }, + }, + }) + + wfBuilder.On("GetLaunchPlan", matchIdentifier(core.Identifier{Name: "Ref_1"})).Return(&lp, true) + + errs = errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + + t.Run("Subwf", func(t *testing.T) { + subWf := core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{}, + Outputs: &core.VariableMap{}, + }, + }, + } + + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_2", + }, + }, + }) + + wfBuilder.On("GetLaunchPlan", mock.Anything).Return(nil, false) + + errs := errors.NewCompileErrors() + _, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assert.False(t, ifaceOk) + + wfBuilder = mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_2", + }, + }, + }) + + wfBuilder.On("GetSubWorkflow", matchIdentifier(core.Identifier{Name: "Ref_1"})).Return(&subWf, true) + + workflowNode.Reference = &core.WorkflowNode_SubWorkflowRef{ + SubWorkflowRef: &core.Identifier{Name: "Ref_1"}, + } + + errs = errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + }) +} + +func matchIdentifier(id core.Identifier) 
interface{} { + return mock.MatchedBy(func(arg core.Identifier) bool { + return arg.String() == id.String() + }) +} diff --git a/pkg/compiler/validators/node.go b/pkg/compiler/validators/node.go new file mode 100644 index 000000000..cf8f9a4b4 --- /dev/null +++ b/pkg/compiler/validators/node.go @@ -0,0 +1,116 @@ +// This package contains validators for all elements of the workflow spec (node, task, branch, interface, bindings... etc.) +package validators + +import ( + flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" +) + +// Computes output parameters after applying all aliases -if any-. +func validateEffectiveOutputParameters(n c.NodeBuilder, errs errors.CompileErrors) ( + params *flyte.VariableMap, ok bool) { + aliases := make(map[string]string, len(n.GetOutputAliases())) + for _, alias := range n.GetOutputAliases() { + if _, found := aliases[alias.Var]; found { + errs.Collect(errors.NewDuplicateAliasErr(n.GetId(), alias.Alias)) + } else { + aliases[alias.Var] = alias.Alias + } + } + + if n.GetInterface() != nil { + params = &flyte.VariableMap{ + Variables: make(map[string]*flyte.Variable, len(n.GetInterface().GetOutputs().Variables)), + } + + for paramName, param := range n.GetInterface().GetOutputs().Variables { + if alias, found := aliases[paramName]; found { + if newParam, paramOk := withVariableName(param); paramOk { + params.Variables[alias] = newParam + } else { + errs.Collect(errors.NewParameterNotBoundErr(n.GetId(), alias)) + } + + delete(aliases, paramName) + } else { + params.Variables[paramName] = param + } + } + + // If there are still more aliases at this point, they point to non-existent variables. 
+ for _, alias := range aliases { + errs.Collect(errors.NewParameterNotBoundErr(n.GetId(), alias)) + } + } + + return params, !errs.HasErrors() +} + +func validateBranchNode(w c.WorkflowBuilder, n c.NodeBuilder, errs errors.CompileErrors) bool { + cases := make([]*flyte.IfBlock, 0, len(n.GetBranchNode().IfElse.Other)+1) + cases = append(cases, n.GetBranchNode().IfElse.Case) + cases = append(cases, n.GetBranchNode().IfElse.Other...) + for _, block := range cases { + // Validate condition + ValidateBooleanExpression(n, block.Condition, errs.NewScope()) + + if block.GetThenNode() == nil { + errs.Collect(errors.NewBranchNodeNotSpecified(n.GetId())) + } else { + wrapperNode := w.NewNodeBuilder(block.GetThenNode()) + if ValidateNode(w, wrapperNode, errs.NewScope()) { + // Add to the global nodes to be able to reference it later + w.AddNode(wrapperNode, errs.NewScope()) + w.AddExecutionEdge(n.GetId(), block.GetThenNode().Id) + } + } + } + + return !errs.HasErrors() +} + +func validateNodeID(w c.WorkflowBuilder, nodeID string, errs errors.CompileErrors) (node c.NodeBuilder, ok bool) { + if nodeID == "" { + n, _ := w.GetNode(c.StartNodeID) + return n, !errs.HasErrors() + } else if node, ok = w.GetNode(nodeID); !ok { + errs.Collect(errors.NewNodeReferenceNotFoundErr(nodeID, nodeID)) + } + + return node, !errs.HasErrors() +} + +func ValidateNode(w c.WorkflowBuilder, n c.NodeBuilder, errs errors.CompileErrors) (ok bool) { + if n.GetId() == "" { + errs.Collect(errors.NewValueRequiredErr("", "Id")) + } + + if _, ifaceOk := ValidateUnderlyingInterface(w, n, errs.NewScope()); ifaceOk { + // Validate node output aliases + validateEffectiveOutputParameters(n, errs.NewScope()) + } + + // Validate branch node conditions and inner nodes. 
+ if n.GetBranchNode() != nil { + validateBranchNode(w, n, errs.NewScope()) + } else if workflowN := n.GetWorkflowNode(); workflowN != nil && workflowN.GetSubWorkflowRef() != nil { + if wf, wfOk := w.GetSubWorkflow(*workflowN.GetSubWorkflowRef()); wfOk { + if subWorkflow, workflowOk := w.ValidateWorkflow(wf, errs.NewScope()); workflowOk { + n.SetSubWorkflow(subWorkflow) + } + } else { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr(n.GetId(), workflowN.GetSubWorkflowRef().String())) + } + } else if taskN := n.GetTaskNode(); taskN != nil && taskN.GetReferenceId() != nil { + if task, found := w.GetTask(*taskN.GetReferenceId()); found { + n.SetTask(task) + } else if taskN.GetReferenceId() == nil { + errs.Collect(errors.NewValueRequiredErr(n.GetId(), "TaskNode.ReferenceId")) + } else { + errs.Collect(errors.NewTaskReferenceNotFoundErr(n.GetId(), taskN.GetReferenceId().String())) + } + } + + return !errs.HasErrors() +} diff --git a/pkg/compiler/validators/typing.go b/pkg/compiler/validators/typing.go new file mode 100644 index 000000000..f98023717 --- /dev/null +++ b/pkg/compiler/validators/typing.go @@ -0,0 +1,154 @@ +package validators + +import ( + structpb "github.com/golang/protobuf/ptypes/struct" + flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type typeChecker interface { + CastsFrom(*flyte.LiteralType) bool +} + +type trivialChecker struct { + literalType *flyte.LiteralType +} + +type voidChecker struct{} + +type mapTypeChecker struct { + literalType *flyte.LiteralType +} + +type collectionTypeChecker struct { + literalType *flyte.LiteralType +} + +type schemaTypeChecker struct { + literalType *flyte.LiteralType +} + +// The trivial type checker merely checks if types match exactly. +func (t trivialChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + // Everything is nullable currently + if isVoid(upstreamType) { + return true + } + // Ignore metadata when comparing types. 
+ upstreamTypeCopy := *upstreamType + downstreamTypeCopy := *t.literalType + upstreamTypeCopy.Metadata = &structpb.Struct{} + downstreamTypeCopy.Metadata = &structpb.Struct{} + return upstreamTypeCopy.String() == downstreamTypeCopy.String() +} + +// The void type matches everything +func (t voidChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + return true +} + +// For a map type checker, we need to ensure both the key types and value types match. +func (t mapTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + // Maps are nullable + if isVoid(upstreamType) { + return true + } + + mapLiteralType := upstreamType.GetMapValueType() + if mapLiteralType != nil { + return getTypeChecker(t.literalType.GetMapValueType()).CastsFrom(mapLiteralType) + } + return false +} + +// For a collection type, we need to ensure that the nesting is correct and the final sub-types match. +func (t collectionTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + // Collections are nullable + if isVoid(upstreamType) { + return true + } + + collectionType := upstreamType.GetCollectionType() + if collectionType != nil { + return getTypeChecker(t.literalType.GetCollectionType()).CastsFrom(collectionType) + } + return false +} + +// Schemas are more complex types in the Flyte ecosystem. A schema is considered castable in the following +// cases. +// +// 1. The downstream schema has no column types specified. In such a case, it accepts all schema input since it is +// generic. +// +// 2. The downstream schema has a subset of the upstream columns and they match perfectly. +// +func (t schemaTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + // Schemas are nullable + if isVoid(upstreamType) { + return true + } + + schemaType := upstreamType.GetSchema() + if schemaType == nil { + return false + } + + // If no columns are specified, this is a generic schema and it can accept any schema type. 
+ if len(t.literalType.GetSchema().Columns) == 0 { + return true + } + + nameToTypeMap := make(map[string]flyte.SchemaType_SchemaColumn_SchemaColumnType) + for _, column := range schemaType.Columns { + nameToTypeMap[column.Name] = column.Type + } + + // Check that the downstream schema is a strict sub-set of the upstream schema. + for _, column := range t.literalType.GetSchema().Columns { + upstreamType, ok := nameToTypeMap[column.Name] + if !ok { + return false + } + if upstreamType != column.Type { + return false + } + } + return true +} + +func isVoid(t *flyte.LiteralType) bool { + switch t.GetType().(type) { + case *flyte.LiteralType_Simple: + return t.GetSimple() == flyte.SimpleType_NONE + default: + return false + } +} + +func getTypeChecker(t *flyte.LiteralType) typeChecker { + switch t.GetType().(type) { + case *flyte.LiteralType_CollectionType: + return collectionTypeChecker{ + literalType: t, + } + case *flyte.LiteralType_MapValueType: + return mapTypeChecker{ + literalType: t, + } + case *flyte.LiteralType_Schema: + return schemaTypeChecker{ + literalType: t, + } + default: + if isVoid(t) { + return voidChecker{} + } + return trivialChecker{ + literalType: t, + } + } +} + +func AreTypesCastable(upstreamType, downstreamType *flyte.LiteralType) bool { + return getTypeChecker(downstreamType).CastsFrom(upstreamType) +} diff --git a/pkg/compiler/validators/typing_test.go b/pkg/compiler/validators/typing_test.go new file mode 100644 index 000000000..2e4654653 --- /dev/null +++ b/pkg/compiler/validators/typing_test.go @@ -0,0 +1,359 @@ +package validators + +import ( + "testing" + + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestSimpleLiteralCasting(t *testing.T) { + t.Run("BaseCase_Integer", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + 
&core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.True(t, castable, "Integers should be castable to other integers") + }) + + t.Run("IntegerToFloat", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, + }, + ) + assert.False(t, castable, "Integers should not be castable to floats") + }) + + t.Run("FloatToInteger", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.False(t, castable, "Floats should not be castable to integers") + }) + + t.Run("VoidToInteger", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.True(t, castable, "Floats are nullable") + }) + + t.Run("IgnoreMetadata", func(t *testing.T) { + s := structpb.Struct{ + Fields: map[string]*structpb.Value{ + "a": {}, + }, + } + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + Metadata: &s, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.True(t, castable, "Metadata should be ignored") + }) +} + +func TestCollectionCasting(t *testing.T) { + t.Run("BaseCase_SingleIntegerCollection", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: 
&core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + ) + assert.True(t, castable, "[Integer] should be castable to [Integer].") + }) + + t.Run("SingleIntegerCollectionToSingleFloatCollection", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, + }, + }, + }, + ) + assert.False(t, castable, "[Integer] should not be castable to [Float]") + }) + + t.Run("MismatchedNestLevels_Scalar", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.False(t, castable, "[Integer] should not be castable to Integer") + }) + + t.Run("MismatchedNestLevels_Collections", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + }, + }, + ) + assert.False(t, castable, "[Integer] should not be castable to [[Integer]]") + }) + + t.Run("Nullable_Collections", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + 
Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_NONE, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + }, + }, + ) + assert.True(t, castable, "Collections are nullable") + }) +} + +func TestMapCasting(t *testing.T) { + t.Run("BaseCase_SingleIntegerMap", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + ) + assert.True(t, castable, "{k: Integer} should be castable to {k: Integer}.") + }) + + t.Run("ScalarIntegerMapToScalarFloatMap", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, + }, + }, + }, + ) + assert.False(t, castable, "{k: Integer} should not be castable to {k: Float}") + }) + + t.Run("MismatchedMapNestLevels_Scalar", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.False(t, castable, "{k: Integer} should not be castable to Integer") + }) + + 
t.Run("MismatchedMapNestLevels_Maps", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + }, + }, + ) + assert.False(t, castable, "{k: Integer} should not be castable to {k: {k: Integer}}") + }) +} + +func TestSchemaCasting(t *testing.T) { + genericSchema := &core.LiteralType{ + Type: &core.LiteralType_Schema{ + Schema: &core.SchemaType{ + Columns: []*core.SchemaType_SchemaColumn{}, + }, + }, + } + subsetIntegerSchema := &core.LiteralType{ + Type: &core.LiteralType_Schema{ + Schema: &core.SchemaType{ + Columns: []*core.SchemaType_SchemaColumn{ + { + Name: "a", + Type: core.SchemaType_SchemaColumn_INTEGER, + }, + }, + }, + }, + } + supersetIntegerAndFloatSchema := &core.LiteralType{ + Type: &core.LiteralType_Schema{ + Schema: &core.SchemaType{ + Columns: []*core.SchemaType_SchemaColumn{ + { + Name: "a", + Type: core.SchemaType_SchemaColumn_INTEGER, + }, + { + Name: "b", + Type: core.SchemaType_SchemaColumn_FLOAT, + }, + }, + }, + }, + } + mismatchedSubsetSchema := &core.LiteralType{ + Type: &core.LiteralType_Schema{ + Schema: &core.SchemaType{ + Columns: []*core.SchemaType_SchemaColumn{ + { + Name: "a", + Type: core.SchemaType_SchemaColumn_FLOAT, + }, + }, + }, + }, + } + + t.Run("BaseCase_GenericSchema", func(t *testing.T) { + castable := AreTypesCastable(genericSchema, genericSchema) + assert.True(t, castable, "Schema() should be castable to Schema()") + }) + + t.Run("GenericSchemaToNonGeneric", func(t *testing.T) { + castable := AreTypesCastable(genericSchema, subsetIntegerSchema) + assert.False(t, castable, "Schema() should not be 
castable to Schema(a=Integer)") + }) + + t.Run("NonGenericSchemaToGeneric", func(t *testing.T) { + castable := AreTypesCastable(subsetIntegerSchema, genericSchema) + assert.True(t, castable, "Schema(a=Integer) should be castable to Schema()") + }) + + t.Run("SupersetToSubsetTypedSchema", func(t *testing.T) { + castable := AreTypesCastable(supersetIntegerAndFloatSchema, subsetIntegerSchema) + assert.True(t, castable, "Schema(a=Integer, b=Float) should be castable to Schema(a=Integer)") + }) + + t.Run("SubsetToSupersetSchema", func(t *testing.T) { + castable := AreTypesCastable(subsetIntegerSchema, supersetIntegerAndFloatSchema) + assert.False(t, castable, "Schema(a=Integer) should not be castable to Schema(a=Integer, b=Float)") + }) + + t.Run("MismatchedColumns", func(t *testing.T) { + castable := AreTypesCastable(subsetIntegerSchema, mismatchedSubsetSchema) + assert.False(t, castable, "Schema(a=Integer) should not be castable to Schema(a=Float)") + }) + + t.Run("MismatchedColumnsFlipped", func(t *testing.T) { + castable := AreTypesCastable(mismatchedSubsetSchema, subsetIntegerSchema) + assert.False(t, castable, "Schema(a=Float) should not be castable to Schema(a=Integer)") + }) + + t.Run("SchemasAreNullable", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_NONE, + }, + }, + subsetIntegerSchema) + assert.True(t, castable, "Schemas are nullable") + }) +} diff --git a/pkg/compiler/validators/utils.go b/pkg/compiler/validators/utils.go new file mode 100644 index 000000000..19d6ee835 --- /dev/null +++ b/pkg/compiler/validators/utils.go @@ -0,0 +1,199 @@ +package validators + +import ( + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "k8s.io/apimachinery/pkg/util/sets" +) + +func findBindingByVariableName(bindings []*core.Binding, name string) (binding *core.Binding, found bool) { + for _, b := range bindings { + if b.Var == name { + return b, 
true + } + } + + return nil, false +} + +func findVariableByName(vars *core.VariableMap, name string) (variable *core.Variable, found bool) { + if vars == nil || vars.Variables == nil { + return nil, false + } + + variable, found = vars.Variables[name] + return +} + +// Gets literal type for scalar value. This can be used to compare the underlying type of two scalars for compatibility. +func literalTypeForScalar(scalar *core.Scalar) *core.LiteralType { + // TODO: Should we just pass the type information with the value? That way we don't have to guess? + var literalType *core.LiteralType + switch scalar.GetValue().(type) { + case *core.Scalar_Primitive: + literalType = literalTypeForPrimitive(scalar.GetPrimitive()) + case *core.Scalar_Blob: + if scalar.GetBlob().GetMetadata() == nil { + return nil + } + + literalType = &core.LiteralType{Type: &core.LiteralType_Blob{Blob: scalar.GetBlob().GetMetadata().GetType()}} + case *core.Scalar_Binary: + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} + case *core.Scalar_Schema: + literalType = &core.LiteralType{ + Type: &core.LiteralType_Schema{ + Schema: scalar.GetSchema().Type, + }, + } + case *core.Scalar_NoneType: + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} + case *core.Scalar_Error: + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_ERROR}} + default: + return nil + } + + return literalType +} + +func literalTypeForPrimitive(primitive *core.Primitive) *core.LiteralType { + simpleType := core.SimpleType_NONE + switch primitive.GetValue().(type) { + case *core.Primitive_Integer: + simpleType = core.SimpleType_INTEGER + case *core.Primitive_FloatValue: + simpleType = core.SimpleType_FLOAT + case *core.Primitive_StringValue: + simpleType = core.SimpleType_STRING + case *core.Primitive_Boolean: + simpleType = core.SimpleType_BOOLEAN + case *core.Primitive_Datetime: + simpleType = 
core.SimpleType_DATETIME + case *core.Primitive_Duration: + simpleType = core.SimpleType_DURATION + } + + return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: simpleType}} +} + +func buildVariablesIndex(params *core.VariableMap) (map[string]*core.Variable, sets.String) { + paramMap := make(map[string]*core.Variable, len(params.Variables)) + paramSet := sets.NewString() + for paramName, param := range params.Variables { + paramMap[paramName] = param + paramSet.Insert(paramName) + } + + return paramMap, paramSet +} + +func filterVariables(vars *core.VariableMap, varNames sets.String) *core.VariableMap { + res := &core.VariableMap{ + Variables: make(map[string]*core.Variable, len(varNames)), + } + + for paramName, param := range vars.Variables { + if varNames.Has(paramName) { + res.Variables[paramName] = param + } + } + + return res +} + +func withVariableName(param *core.Variable) (newParam *core.Variable, ok bool) { + if raw, err := proto.Marshal(param); err == nil { + newParam = &core.Variable{} + if err = proto.Unmarshal(raw, newParam); err == nil { + ok = true + } + } + + return +} + +// Gets LiteralType for literal, nil if the value of literal is unknown, or type None if the literal is a non-homogeneous +// type. +func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType { + switch l.GetValue().(type) { + case *core.Literal_Scalar: + return literalTypeForScalar(l.GetScalar()) + case *core.Literal_Collection: + if len(l.GetCollection().Literals) == 0 { + return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} + } + + // Ensure literal collection types are homogeneous. 
+ var innerType *core.LiteralType + for _, x := range l.GetCollection().Literals { + otherType := LiteralTypeForLiteral(x) + if innerType != nil && !AreTypesCastable(otherType, innerType) { + return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} + } + + innerType = otherType + } + + return &core.LiteralType{Type: &core.LiteralType_CollectionType{CollectionType: innerType}} + case *core.Literal_Map: + if len(l.GetMap().Literals) == 0 { + return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} + } + + // Ensure literal map types are homogeneous. + var innerType *core.LiteralType + for _, x := range l.GetMap().Literals { + otherType := LiteralTypeForLiteral(x) + if innerType != nil && !AreTypesCastable(otherType, innerType) { + return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} + } + + innerType = otherType + } + + return &core.LiteralType{Type: &core.LiteralType_MapValueType{MapValueType: innerType}} + } + + return nil +} + +// Converts a literal to a non-promise binding data. 
+func LiteralToBinding(l *core.Literal) *core.BindingData { + switch l.GetValue().(type) { + case *core.Literal_Scalar: + return &core.BindingData{ + Value: &core.BindingData_Scalar{ + Scalar: l.GetScalar(), + }, + } + case *core.Literal_Collection: + x := make([]*core.BindingData, 0, len(l.GetCollection().Literals)) + for _, sub := range l.GetCollection().Literals { + x = append(x, LiteralToBinding(sub)) + } + + return &core.BindingData{ + Value: &core.BindingData_Collection{ + Collection: &core.BindingDataCollection{ + Bindings: x, + }, + }, + } + case *core.Literal_Map: + x := make(map[string]*core.BindingData, len(l.GetMap().Literals)) + for key, val := range l.GetMap().Literals { + x[key] = LiteralToBinding(val) + } + + return &core.BindingData{ + Value: &core.BindingData_Map{ + Map: &core.BindingDataMap{ + Bindings: x, + }, + }, + } + } + + return nil +} diff --git a/pkg/compiler/validators/vars.go b/pkg/compiler/validators/vars.go new file mode 100644 index 000000000..04058b841 --- /dev/null +++ b/pkg/compiler/validators/vars.go @@ -0,0 +1,81 @@ +package validators + +import ( + flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "k8s.io/apimachinery/pkg/util/sets" +) + +func validateOutputVar(n c.NodeBuilder, paramName string, errs errors.CompileErrors) ( + param *flyte.Variable, ok bool) { + if outputs, effectiveOk := validateEffectiveOutputParameters(n, errs.NewScope()); effectiveOk { + var paramFound bool + if param, paramFound = findVariableByName(outputs, paramName); !paramFound { + errs.Collect(errors.NewVariableNameNotFoundErr(n.GetId(), n.GetId(), paramName)) + } + } + + return param, !errs.HasErrors() +} + +func validateInputVar(n c.NodeBuilder, paramName string, errs errors.CompileErrors) (param *flyte.Variable, ok bool) { + if n.GetInterface() == nil { + return nil, false + } + + if param, ok = 
findVariableByName(n.GetInterface().GetInputs(), paramName); !ok { + errs.Collect(errors.NewVariableNameNotFoundErr(n.GetId(), n.GetId(), paramName)) + } + + return +} + +func validateVarType(nodeID c.NodeID, paramName string, param *flyte.Variable, + expectedType *flyte.LiteralType, errs errors.CompileErrors) (ok bool) { + if param.GetType().String() != expectedType.String() { + errs.Collect(errors.NewMismatchingTypesErr(nodeID, paramName, param.GetType().String(), expectedType.String())) + } + + return !errs.HasErrors() +} + +// Validates that two variable maps are identical: variables present in both must have compatible types, and any +// variable present in only one of the two maps is reported as a mismatching-interfaces error. +func validateVarsSetMatch(nodeID string, params1, params2 map[string]*flyte.Variable, + params1Set, params2Set sets.String, errs errors.CompileErrors) (match bool) { + // Validate that parameters that exist in both interfaces have compatible types. + inBoth := params1Set.Intersection(params2Set) + for paramName := range inBoth { + if validateVarType(nodeID, paramName, params1[paramName], params2[paramName].Type, errs.NewScope()) { + validateVarType(nodeID, paramName, params2[paramName], params1[paramName].Type, errs.NewScope()) + } + } + + // All remaining params on either sides indicate errors. Use Difference (not Intersection) so only the + // parameters missing from the other side are flagged; the common set was already validated above. + inLeftSide := params1Set.Difference(params2Set) + for range inLeftSide { + errs.Collect(errors.NewMismatchingInterfacesErr(nodeID, nodeID)) + } + + inRightSide := params2Set.Difference(params1Set) + for range inRightSide { + errs.Collect(errors.NewMismatchingInterfacesErr(nodeID, nodeID)) + } + + return !errs.HasErrors() +} + +// Validate parameters have their required attributes set +func validateVariables(nodeID c.NodeID, params *flyte.VariableMap, errs errors.CompileErrors) (ok bool) { + + for paramName, param := range params.Variables { + if len(paramName) == 0 { + errs.Collect(errors.NewValueRequiredErr(nodeID, "paramName")) + } + + if param.Type == nil { + errs.Collect(errors.NewValueRequiredErr(nodeID, "param.Type")) + } + } + + return !errs.HasErrors() +} diff --git a/pkg/compiler/workflow_compiler.go b/pkg/compiler/workflow_compiler.go
new file mode 100755 index 000000000..2ec6e39fc --- /dev/null +++ b/pkg/compiler/workflow_compiler.go @@ -0,0 +1,330 @@ +// This package provides compiler services for flyte workflows. It performs static analysis on the Workflow and produces +// CompilerErrors for any detected issue. A flyte workflow should only be considered valid for execution if it passed through +// the compiler first. The intended usage for the compiler is as follows: +// 1) Call GetRequirements(...) and load/retrieve all tasks/workflows referenced in the response. +// 2) Call CompileWorkflow(...) and make sure it reports no errors. +// 3) Use one of the transformer packages (e.g. transformer/k8s) to build the final executable workflow. +// +// +-------------------+ +// | start(StartNode) | +// +-------------------+ +// | +// | wf_input +// v +// +--------+ +-------------------+ +// | static | --> | node_1(TaskNode) | +// +--------+ +-------------------+ +// | | +// | | x +// | v +// | +-------------------+ +// +----------> | node_2(TaskNode) | +// +-------------------+ +// | +// | n2_output +// v +// +-------------------+ +// | end(EndNode) | +// +-------------------+ +// +-------------------+ +// | Workflow Id: repo | +// +-------------------+ +package compiler + +import ( + "strings" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + v "github.com/lyft/flytepropeller/pkg/compiler/validators" + "k8s.io/apimachinery/pkg/util/sets" +) + +// Updates workflows and tasks references to reflect the needed ones for this workflow (ignoring subworkflows) +func (w *workflowBuilder) updateRequiredReferences() { + reqs := getRequirements(w.CoreWorkflow.Template, w.allSubWorkflows, false, errors.NewCompileErrors()) + workflows := map[c.WorkflowIDKey]c.InterfaceProvider{} + tasks := c.TaskIndex{} + for _, workflowID := range reqs.launchPlanIds { + if wf, 
ok := w.allLaunchPlans[workflowID.String()]; ok { + workflows[workflowID.String()] = wf + } + } + + for _, taskID := range reqs.taskIds { + if task, ok := w.allTasks[taskID.String()]; ok { + tasks[taskID.String()] = task + } + } + + w.Tasks = tasks + w.LaunchPlans = workflows +} + +// Validates the coreWorkflow contains no cycles and that all nodes are reachable. +func (w workflowBuilder) validateReachable(errs errors.CompileErrors) (ok bool) { + neighbors := func(nodeId string) sets.String { + downNodes := w.downstreamNodes[nodeId] + if downNodes == nil { + return sets.String{} + } + + return downNodes + } + + // TODO: If a branch node can exist in a cycle and not actually be a cycle since it can branch off... + if cycle, visited, detected := detectCycle(c.StartNodeID, neighbors); detected { + errs.Collect(errors.NewCycleDetectedInWorkflowErr(c.StartNodeID, strings.Join(cycle, ">"))) + } else { + // If no cycles are detected, we expect all nodes to have been visited. Otherwise there are unreachable + // node(s).. + if visited.Len() != len(w.Nodes) { + // report unreachable nodes + allNodes := toNodeIdsSet(w.Nodes) + unreachableNodes := allNodes.Difference(visited).Difference(sets.NewString(c.EndNodeID)) + if len(unreachableNodes) > 0 { + errs.Collect(errors.NewUnreachableNodesErr(c.StartNodeID, strings.Join(toSlice(unreachableNodes), ","))) + } + } + } + + return !errs.HasErrors() +} + +// Adds unique nodes to the workflow. 
+func (w workflowBuilder) AddNode(n c.NodeBuilder, errs errors.CompileErrors) (node c.NodeBuilder, ok bool) { + if _, ok := w.Nodes[n.GetId()]; ok { + errs.Collect(errors.NewDuplicateIDFoundErr(n.GetId())) + } + + node = n + w.Nodes[n.GetId()] = node + ok = !errs.HasErrors() + w.CoreWorkflow.Template.Nodes = append(w.CoreWorkflow.Template.Nodes, node.GetCoreNode()) + return +} + +// Records a directed execution edge from nodeFrom to nodeTo in both the adjacency lists and the compiled +// workflow's ConnectionSet. An empty nodeFrom is treated as the workflow's start node. +func (w workflowBuilder) AddExecutionEdge(nodeFrom, nodeTo c.NodeID) { + if nodeFrom == "" { + nodeFrom = c.StartNodeID + } + + if _, found := w.downstreamNodes[nodeFrom]; !found { + w.downstreamNodes[nodeFrom] = sets.String{} + w.CoreWorkflow.Connections.Downstream[nodeFrom] = &core.ConnectionSet_IdList{} + } + + if _, found := w.upstreamNodes[nodeTo]; !found { + w.upstreamNodes[nodeTo] = sets.String{} + // Initialize with an empty Ids list (consistent with the Downstream branch above). Using + // make([]string, 1) would insert a phantom empty-string node id until Ids is rewritten below. + w.CoreWorkflow.Connections.Upstream[nodeTo] = &core.ConnectionSet_IdList{} + } + + w.downstreamNodes[nodeFrom].Insert(nodeTo) + w.upstreamNodes[nodeTo].Insert(nodeFrom) + w.CoreWorkflow.Connections.Downstream[nodeFrom].Ids = w.downstreamNodes[nodeFrom].List() + w.CoreWorkflow.Connections.Upstream[nodeTo].Ids = w.upstreamNodes[nodeTo].List() +} + +func (w workflowBuilder) AddEdges(n c.NodeBuilder, errs errors.CompileErrors) (ok bool) { + if n.GetInterface() == nil { + // If there were errors computing node's interface, don't add any edges and just bail. + return + } + + // Add explicitly declared edges + if n.GetUpstreamNodeIds() != nil { + for _, upNode := range n.GetUpstreamNodeIds() { + w.AddExecutionEdge(upNode, n.GetId()) + } + } + + // Add implicit Edges + return v.ValidateBindings(&w, n, n.GetInputs(), n.GetInterface().GetInputs(), errs.NewScope()) +} + +// Contains the main validation logic for the coreWorkflow. If successful, it'll build an executable Workflow. 
+func (w workflowBuilder) ValidateWorkflow(fg *flyteWorkflow, errs errors.CompileErrors) (c.Workflow, bool) { + // Initialize workflow + wf := w.newWorkflowBuilder(fg) + wf.updateRequiredReferences() + + // Start building out the workflow + // Create global sentinel nodeBuilder with the workflow as its interface. + startNode := &core.Node{ + Id: c.StartNodeID, + } + + var ok bool + if wf.CoreWorkflow.Template.Interface, ok = v.ValidateInterface(c.StartNodeID, wf.CoreWorkflow.Template.Interface, errs.NewScope()); !ok { + return nil, !errs.HasErrors() + } + + checkpoint := make([]*core.Node, 0, len(fg.Template.Nodes)) + checkpoint = append(checkpoint, fg.Template.Nodes...) + fg.Template.Nodes = make([]*core.Node, 0, len(fg.Template.Nodes)) + wf.GetCoreWorkflow().Connections = &core.ConnectionSet{ + Downstream: make(map[string]*core.ConnectionSet_IdList), + Upstream: make(map[string]*core.ConnectionSet_IdList), + } + + globalInputNode, _ := wf.AddNode(wf.NewNodeBuilder(startNode), errs) + globalInputNode.SetInterface(&core.TypedInterface{Outputs: wf.CoreWorkflow.Template.Interface.Inputs}) + + endNode := &core.Node{Id: c.EndNodeID} + globalOutputNode, _ := wf.AddNode(wf.NewNodeBuilder(endNode), errs) + globalOutputNode.SetInterface(&core.TypedInterface{Inputs: wf.CoreWorkflow.Template.Interface.Outputs}) + globalOutputNode.SetInputs(wf.CoreWorkflow.Template.Outputs) + + // Add and validate all other nodes + for _, n := range checkpoint { + if node, addOk := wf.AddNode(wf.NewNodeBuilder(n), errs.NewScope()); addOk { + v.ValidateNode(&wf, node, errs.NewScope()) + } + } + + // Add explicitly and implicitly declared edges + for nodeID, n := range wf.Nodes { + if nodeID == c.StartNodeID { + continue + } + + wf.AddEdges(n, errs.NewScope()) + } + + // Add execution edges for orphan nodes that don't have any inward/outward edges. 
+ for nodeID := range wf.Nodes { + if nodeID == c.StartNodeID || nodeID == c.EndNodeID { + continue + } + + if _, foundUpStream := wf.upstreamNodes[nodeID]; !foundUpStream { + wf.AddExecutionEdge(c.StartNodeID, nodeID) + } + + if _, foundDownStream := wf.downstreamNodes[nodeID]; !foundDownStream { + wf.AddExecutionEdge(nodeID, c.EndNodeID) + } + } + + // Validate workflow outputs are bound + if _, wfIfaceOk := v.ValidateInterface(globalOutputNode.GetId(), globalOutputNode.GetInterface(), errs.NewScope()); wfIfaceOk { + v.ValidateBindings(&wf, globalOutputNode, globalOutputNode.GetInputs(), + globalOutputNode.GetInterface().GetInputs(), errs.NewScope()) + } + + // Validate no cycles are detected. + wf.validateReachable(errs.NewScope()) + + return wf, !errs.HasErrors() +} + +// Validates that all requirements for the coreWorkflow and its subworkflows are present. +func (w workflowBuilder) validateAllRequirements(errs errors.CompileErrors) bool { + reqs := getRequirements(w.CoreWorkflow.Template, w.allSubWorkflows, true, errs) + + for _, lp := range reqs.launchPlanIds { + if _, ok := w.allLaunchPlans[lp.String()]; !ok { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr(c.StartNodeID, lp.String())) + } + } + + for _, taskID := range reqs.taskIds { + if _, ok := w.allTasks[taskID.String()]; !ok { + errs.Collect(errors.NewTaskReferenceNotFoundErr(c.StartNodeID, taskID.String())) + } + } + + return !errs.HasErrors() +} + +// Compiles a flyte workflow a and all of its dependencies into a single executable Workflow. Refer to GetRequirements() +// to obtain a list of launchplan and Task ids to load/compile first. +// Returns an executable Workflow (if no errors are found) or a list of errors that must be addressed before the Workflow +// can be executed. Cast the error to errors.CompileErrors to inspect individual errors. 
+func CompileWorkflow(primaryWf *core.WorkflowTemplate, subworkflows []*core.WorkflowTemplate, tasks []*core.CompiledTask, + launchPlans []c.InterfaceProvider) (*core.CompiledWorkflowClosure, error) { + + errs := errors.NewCompileErrors() + + if primaryWf == nil { + errs.Collect(errors.NewValueRequiredErr("root", "wf")) + return nil, errs + } + + wf := proto.Clone(primaryWf).(*core.WorkflowTemplate) + + if tasks == nil { + errs.Collect(errors.NewValueRequiredErr("root", "tasks")) + return nil, errs + } + + // Validate all tasks are valid... invalid tasks won't be passed on to the workflow validator + uniqueTasks := sets.NewString() + taskBuilders := make([]c.Task, 0, len(tasks)) + for _, task := range tasks { + if task.Template == nil || task.Template.Id == nil { + errs.Collect(errors.NewValueRequiredErr("task", "Template.Id")) + return nil, errs + } + + if uniqueTasks.Has(task.Template.Id.String()) { + continue + } + + taskBuilders = append(taskBuilders, &taskBuilder{flyteTask: task.Template}) + uniqueTasks.Insert(task.Template.Id.String()) + } + + // Validate overall requirements of the coreWorkflow. + wfIndex, ok := c.NewWorkflowIndex(toCompiledWorkflows(subworkflows...), errs.NewScope()) + if !ok { + return nil, errs + } + + compiledWf := &core.CompiledWorkflow{Template: wf} + + gb := newWorfklowBuilder(compiledWf, wfIndex, c.NewTaskIndex(taskBuilders...), toInterfaceProviderMap(launchPlans)) + // Terminate early if there are some required component not present. 
+ if !gb.validateAllRequirements(errs.NewScope()) { + return nil, errs + } + + validatedWf, ok := gb.ValidateWorkflow(compiledWf, errs.NewScope()) + if ok { + compiledTasks := make([]*core.CompiledTask, 0, len(taskBuilders)) + for _, t := range taskBuilders { + compiledTasks = append(compiledTasks, &core.CompiledTask{Template: t.GetCoreTask()}) + } + + return &core.CompiledWorkflowClosure{ + Primary: validatedWf.GetCoreWorkflow(), + Tasks: compiledTasks, + }, nil + } + + return nil, errs +} + +func (w workflowBuilder) newWorkflowBuilder(fg *flyteWorkflow) workflowBuilder { + return newWorfklowBuilder(fg, w.allSubWorkflows, w.allTasks, w.allLaunchPlans) +} + +func newWorfklowBuilder(fg *flyteWorkflow, wfIndex c.WorkflowIndex, tasks c.TaskIndex, + workflows map[string]c.InterfaceProvider) workflowBuilder { + + return workflowBuilder{ + CoreWorkflow: fg, + LaunchPlans: map[string]c.InterfaceProvider{}, + Nodes: c.NewNodeIndex(), + Tasks: c.NewTaskIndex(), + downstreamNodes: c.StringAdjacencyList{}, + upstreamNodes: c.StringAdjacencyList{}, + allSubWorkflows: wfIndex, + allLaunchPlans: workflows, + allTasks: tasks, + } +} diff --git a/pkg/compiler/workflow_compiler_test.go b/pkg/compiler/workflow_compiler_test.go new file mode 100755 index 000000000..eccdbc3b5 --- /dev/null +++ b/pkg/compiler/workflow_compiler_test.go @@ -0,0 +1,659 @@ +package compiler + +import ( + "fmt" + "strings" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + v "github.com/lyft/flytepropeller/pkg/compiler/validators" + "github.com/lyft/flytepropeller/pkg/visualize" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" +) + +func createEmptyVariableMap() *core.VariableMap { + res := &core.VariableMap{ + Variables: map[string]*core.Variable{}, + } + return res +} + +func 
createVariableMap(variableMap map[string]*core.Variable) *core.VariableMap { + res := &core.VariableMap{ + Variables: variableMap, + } + return res +} + +func dumpIdentifierNames(ids []common.Identifier) []string { + res := make([]string, 0, len(ids)) + + for _, id := range ids { + res = append(res, id.Name) + } + + return res +} + +func ExampleCompileWorkflow_basic() { + inputWorkflow := &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "repo"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + Outputs: createEmptyVariableMap(), + }, + Nodes: []*core.Node{ + { + Id: "FirstNode", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task_123"}, + }, + }, + }, + }, + }, + } + + // Detect what other workflows/tasks does this coreWorkflow reference + subWorkflows := make([]*core.WorkflowTemplate, 0) + reqs, err := GetRequirements(inputWorkflow, subWorkflows) + if err != nil { + fmt.Printf("failed to get requirements. Error: %v", err) + return + } + + fmt.Printf("Needed Tasks: [%v], Needed Workflows [%v]\n", + strings.Join(dumpIdentifierNames(reqs.GetRequiredTaskIds()), ","), + strings.Join(dumpIdentifierNames(reqs.GetRequiredLaunchPlanIds()), ",")) + + // Replace with logic to satisfy the requirements + workflows := make([]common.InterfaceProvider, 0) + tasks := []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task_123"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + Outputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "image://", + Command: []string{"cmd"}, + Args: []string{"args"}, + }, + }, + }, + } + + compiledTasks := make([]*core.CompiledTask, 0, len(tasks)) + for _, task := range tasks { + compiledTask, err := CompileTask(task) + if err != nil { + fmt.Printf("failed to compile task [%v]. 
Error: %v", task.Id, err) + return + } + + compiledTasks = append(compiledTasks, compiledTask) + } + + output, errs := CompileWorkflow(inputWorkflow, subWorkflows, compiledTasks, workflows) + fmt.Printf("Compiled Workflow in GraphViz: %v\n", visualize.ToGraphViz(output.Primary)) + fmt.Printf("Compile Errors: %v\n", errs) + + // Output: + // Needed Tasks: [task_123], Needed Workflows [] + // Compiled Workflow in GraphViz: digraph G {rankdir=TB;workflow[label="Workflow Id: name:"repo" "];node[style=filled];"start-node(start)" [shape=Msquare];"start-node(start)" -> "FirstNode()" [label="execution",style="dashed"];"FirstNode()" -> "end-node(end)" [label="execution",style="dashed"];} + // Compile Errors: +} + +func ExampleCompileWorkflow_inputsOutputsBinding() { + inputWorkflow := &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "repo"}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "wf_input": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "wf_output": { + Type: getIntegerLiteralType(), + }, + }), + }, + Nodes: []*core.Node{ + { + Id: "node_1", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{Reference: &core.TaskNode_ReferenceId{ReferenceId: &core.Identifier{Name: "task_123"}}}, + }, + Inputs: []*core.Binding{ + newVarBinding("", "wf_input", "x"), newIntegerBinding(124, "y"), + }, + }, + { + Id: "node_2", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{Reference: &core.TaskNode_ReferenceId{ReferenceId: &core.Identifier{Name: "task_123"}}}, + }, + Inputs: []*core.Binding{ + newIntegerBinding(124, "y"), newVarBinding("node_1", "x", "x"), + }, + OutputAliases: []*core.Alias{{Var: "x", Alias: "n2_output"}}, + }, + }, + Outputs: []*core.Binding{newVarBinding("node_2", "n2_output", "wf_output")}, + } + + // Detect what other graphs/tasks does this coreWorkflow reference + subWorkflows := make([]*core.WorkflowTemplate, 0) + reqs, err := 
GetRequirements(inputWorkflow, subWorkflows) + if err != nil { + fmt.Printf("Failed to get requirements. Error: %v", err) + return + } + + fmt.Printf("Needed Tasks: [%v], Needed Graphs [%v]\n", + strings.Join(dumpIdentifierNames(reqs.GetRequiredTaskIds()), ","), + strings.Join(dumpIdentifierNames(reqs.GetRequiredLaunchPlanIds()), ",")) + + // Replace with logic to satisfy the requirements + graphs := make([]common.InterfaceProvider, 0) + inputTasks := []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task_123"}, + Metadata: &core.TaskMetadata{}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + "y": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "image://", + Command: []string{"cmd"}, + Args: []string{"args"}, + }, + }, + }, + } + + // Compile all tasks before proceeding with Workflow + compiledTasks := make([]*core.CompiledTask, 0, len(inputTasks)) + for _, task := range inputTasks { + compiledTask, err := CompileTask(task) + if err != nil { + fmt.Printf("Failed to compile task [%v]. 
Error: %v", task.Id, err) + return + } + + compiledTasks = append(compiledTasks, compiledTask) + } + + output, errs := CompileWorkflow(inputWorkflow, subWorkflows, compiledTasks, graphs) + if errs != nil { + fmt.Printf("Compile Errors: %v\n", errs) + } else { + fmt.Printf("Compiled Workflow in GraphViz: %v\n", visualize.ToGraphViz(output.Primary)) + } + + // Output: + // Needed Tasks: [task_123], Needed Graphs [] + // Compiled Workflow in GraphViz: digraph G {rankdir=TB;workflow[label="Workflow Id: name:"repo" "];node[style=filled];"start-node(start)" [shape=Msquare];"start-node(start)" -> "node_1()" [label="wf_input",style="solid"];"node_1()" -> "node_2()" [label="x",style="solid"];"static" -> "node_1()" [label=""];"node_2()" -> "end-node(end)" [label="n2_output",style="solid"];"static" -> "node_2()" [label=""];} +} + +func ExampleCompileWorkflow_compileErrors() { + inputWorkflow := &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "repo"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + Outputs: createEmptyVariableMap(), + }, + Nodes: []*core.Node{ + { + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task_123"}, + }, + }, + }, + }, + }, + } + + // Detect what other workflows/tasks does this coreWorkflow reference + subWorkflows := make([]*core.WorkflowTemplate, 0) + reqs, err := GetRequirements(inputWorkflow, subWorkflows) + if err != nil { + fmt.Printf("Failed to get requirements. 
Error: %v", err) + return + } + + fmt.Printf("Needed Tasks: [%v], Needed Workflows [%v]\n", + strings.Join(dumpIdentifierNames(reqs.GetRequiredTaskIds()), ","), + strings.Join(dumpIdentifierNames(reqs.GetRequiredLaunchPlanIds()), ",")) + + // Replace with logic to satisfy the requirements + workflows := make([]common.InterfaceProvider, 0) + _, errs := CompileWorkflow(inputWorkflow, subWorkflows, []*core.CompiledTask{}, workflows) + fmt.Printf("Compile Errors: %v\n", errs) + + // Output: + // Needed Tasks: [task_123], Needed Workflows [] + // Compile Errors: Collected Errors: 1 + // Error 0: Code: TaskReferenceNotFound, Node Id: start-node, Description: Referenced Task [name:"task_123" ] not found. +} + +func newIntegerPrimitive(value int64) *core.Primitive { + return &core.Primitive{Value: &core.Primitive_Integer{Integer: value}} +} + +func newStringPrimitive(value string) *core.Primitive { + return &core.Primitive{Value: &core.Primitive_StringValue{StringValue: value}} +} + +func newScalarInteger(value int64) *core.Scalar { + return &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: newIntegerPrimitive(value), + }, + } +} + +func newIntegerLiteral(value int64) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: newScalarInteger(value), + }, + } +} + +func getIntegerLiteralType() *core.LiteralType { + return getSimpleLiteralType(core.SimpleType_INTEGER) +} + +func getSimpleLiteralType(simpleType core.SimpleType) *core.LiteralType { + return &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: simpleType, + }, + } +} + +func newIntegerBinding(value int64, toVar string) *core.Binding { + return &core.Binding{ + Binding: &core.BindingData{ + Value: &core.BindingData_Scalar{Scalar: newIntegerLiteral(value).GetScalar()}, + }, + Var: toVar, + } +} + +func newVarBinding(fromNodeID, fromVar, toVar string) *core.Binding { + return &core.Binding{ + Binding: &core.BindingData{ + Value: &core.BindingData_Promise{ + Promise: 
&core.OutputReference{ + NodeId: fromNodeID, + Var: fromVar, + }, + }, + }, + Var: toVar, + } +} + +func TestComparisonExpression_MissingLeftRight(t *testing.T) { + bExpr := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: &core.ComparisonExpression{ + Operator: core.ComparisonExpression_GT, + }, + }, + } + + errs := errors.NewCompileErrors() + v.ValidateBooleanExpression(&nodeBuilder{flyteNode: &flyteNode{}}, bExpr, errs) + assert.Error(t, errs) + assert.Equal(t, 2, errs.ErrorCount()) +} + +func TestComparisonExpression(t *testing.T) { + bExpr := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: &core.ComparisonExpression{ + Operator: core.ComparisonExpression_GT, + LeftValue: &core.Operand{Val: &core.Operand_Primitive{Primitive: newIntegerPrimitive(123)}}, + RightValue: &core.Operand{Val: &core.Operand_Primitive{Primitive: newStringPrimitive("hello")}}, + }, + }, + } + + errs := errors.NewCompileErrors() + v.ValidateBooleanExpression(&nodeBuilder{flyteNode: &flyteNode{}}, bExpr, errs) + assert.True(t, errs.HasErrors()) + assert.Equal(t, 1, errs.ErrorCount()) +} + +func TestBooleanExpression_BranchNodeHasNoCondition(t *testing.T) { + bExpr := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: &core.ConjunctionExpression{ + Operator: core.ConjunctionExpression_AND, + RightExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: &core.ComparisonExpression{ + Operator: core.ComparisonExpression_GT, + LeftValue: &core.Operand{Val: &core.Operand_Primitive{Primitive: newIntegerPrimitive(123)}}, + RightValue: &core.Operand{Val: &core.Operand_Primitive{Primitive: newIntegerPrimitive(345)}}, + }, + }, + }, + }, + }, + } + + errs := errors.NewCompileErrors() + v.ValidateBooleanExpression(&nodeBuilder{flyteNode: &flyteNode{}}, bExpr, errs) + assert.True(t, errs.HasErrors()) + assert.Equal(t, 1, errs.ErrorCount()) + for e := range 
*errs.Errors() { + assert.Equal(t, errors.BranchNodeHasNoCondition, e.Code()) + } +} + +func newNodeIDSet(nodeIDs ...common.NodeID) sets.String { + return sets.NewString(nodeIDs...) +} + +func TestValidateReachable(t *testing.T) { + graph := &workflowBuilder{} + graph.downstreamNodes = map[string]sets.String{ + v1alpha1.StartNodeID: newNodeIDSet("1"), + "1": newNodeIDSet("5", "2"), + "2": newNodeIDSet("3"), + "3": newNodeIDSet("4"), + "4": newNodeIDSet(v1alpha1.EndNodeID), + } + + for range graph.downstreamNodes { + graph.Nodes = common.NewNodeIndex(graph.NewNodeBuilder(nil)) + } + + errs := errors.NewCompileErrors() + assert.False(t, graph.validateReachable(errs)) + assert.True(t, errs.HasErrors()) +} + +func TestValidateUnderlyingInterface(parentT *testing.T) { + graphIface := &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + } + + inputWorkflow := &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "repo"}, + Interface: graphIface, + Nodes: []*core.Node{ + { + Id: "node_123", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{Reference: &core.TaskNode_ReferenceId{ReferenceId: &core.Identifier{Name: "task_123"}}}, + }, + }, + }, + } + + taskIface := &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + "y": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + } + + inputTasks := []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task_123"}, + Metadata: &core.TaskMetadata{}, + Interface: taskIface, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "Image://", + Command: []string{"blah"}, + Args: []string{"bloh"}, + }, + }, + }, + } + + errs := 
errors.NewCompileErrors() + compiledTasks := make([]common.Task, 0, len(inputTasks)) + for _, inputTask := range inputTasks { + t, _ := compileTaskInternal(inputTask, errs) + compiledTasks = append(compiledTasks, t) + assert.False(parentT, errs.HasErrors()) + if errs.HasErrors() { + assert.FailNow(parentT, errs.Error()) + } + } + + g := newWorfklowBuilder( + &core.CompiledWorkflow{Template: inputWorkflow}, + mustBuildWorkflowIndex(inputWorkflow), + common.NewTaskIndex(compiledTasks...), + map[string]common.InterfaceProvider{}) + (&g).Tasks = common.NewTaskIndex(compiledTasks...) + + parentT.Run("TaskNode", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ifaceOk := v.ValidateUnderlyingInterface(&g, &nodeBuilder{flyteNode: inputWorkflow.Nodes[0]}, errs) + assert.True(t, ifaceOk) + assert.False(t, errs.HasErrors()) + assert.Equal(t, taskIface, iface) + }) + + parentT.Run("GraphNode", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ifaceOk := v.ValidateUnderlyingInterface(&g, &nodeBuilder{flyteNode: &core.Node{ + Target: &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_SubWorkflowRef{ + SubWorkflowRef: inputWorkflow.Id, + }, + }, + }, + }}, errs) + assert.True(t, ifaceOk) + assert.False(t, errs.HasErrors()) + assert.Equal(t, graphIface, iface) + }) + + parentT.Run("BranchNode", func(branchT *testing.T) { + branchT.Run("OneCase", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ifaceOk := v.ValidateUnderlyingInterface(&g, &nodeBuilder{flyteNode: &core.Node{ + Target: &core.Node_BranchNode{ + BranchNode: &core.BranchNode{ + IfElse: &core.IfElseBlock{ + Case: &core.IfBlock{ + ThenNode: inputWorkflow.Nodes[0], + }, + }, + }, + }, + }}, errs) + assert.True(t, ifaceOk) + assert.False(t, errs.HasErrors()) + assert.Equal(t, taskIface, iface) + }) + + branchT.Run("TwoCases", func(t *testing.T) { + errs := errors.NewCompileErrors() + _, ifaceOk := v.ValidateUnderlyingInterface(&g, 
&nodeBuilder{flyteNode: &core.Node{ + Target: &core.Node_BranchNode{ + BranchNode: &core.BranchNode{ + IfElse: &core.IfElseBlock{ + Case: &core.IfBlock{ + ThenNode: inputWorkflow.Nodes[0], + }, + Other: []*core.IfBlock{ + { + ThenNode: &core.Node{ + Target: &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_SubWorkflowRef{ + SubWorkflowRef: inputWorkflow.Id, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }}, errs) + assert.False(t, ifaceOk) + assert.True(t, errs.HasErrors()) + }) + }) +} + +func TestCompileWorkflow(t *testing.T) { + inputWorkflow := &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "repo"}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + }, + Nodes: []*core.Node{ + { + Id: "node_123", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{Reference: &core.TaskNode_ReferenceId{ReferenceId: &core.Identifier{Name: "task_123"}}}, + }, + Inputs: []*core.Binding{ + newIntegerBinding(123, "x"), newIntegerBinding(123, "y"), + }, + }, + { + Id: "node_456", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{Reference: &core.TaskNode_ReferenceId{ReferenceId: &core.Identifier{Name: "task_123"}}}, + }, + Inputs: []*core.Binding{ + newIntegerBinding(123, "y"), newVarBinding("node_123", "x", "x"), + }, + UpstreamNodeIds: []string{"node_123"}, + }, + }, + Outputs: []*core.Binding{newVarBinding("node_456", "x", "x")}, + } + + inputTasks := []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task_123"}, Metadata: &core.TaskMetadata{}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + "y": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: 
getIntegerLiteralType(), + }, + }), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Command: []string{}, + Image: "image://123", + }, + }, + }, + } + + errors.SetConfig(errors.Config{PanicOnError: true}) + defer errors.SetConfig(errors.Config{}) + output, errs := CompileWorkflow(inputWorkflow, []*core.WorkflowTemplate{}, mustCompileTasks(inputTasks), []common.InterfaceProvider{}) + assert.NoError(t, errs) + assert.NotNil(t, output) + if output != nil { + t.Logf("Graph Repr: %v", visualize.ToGraphViz(output.Primary)) + + assert.Equal(t, []string{"node_123"}, output.Primary.Connections.Upstream["node_456"].Ids) + } +} + +func mustCompileTasks(tasks []*core.TaskTemplate) []*core.CompiledTask { + res := make([]*core.CompiledTask, 0, len(tasks)) + for _, t := range tasks { + compiledT, err := CompileTask(t) + if err != nil { + panic(err) + } + + res = append(res, compiledT) + } + return res +} + +func mustBuildWorkflowIndex(wfs ...*core.WorkflowTemplate) common.WorkflowIndex { + compiledWfs := make([]*core.CompiledWorkflow, 0, len(wfs)) + for _, wf := range wfs { + compiledWfs = append(compiledWfs, &core.CompiledWorkflow{Template: wf}) + } + + err := errors.NewCompileErrors() + if index, ok := common.NewWorkflowIndex(compiledWfs, err); !ok { + panic(err) + } else { + return index + } +} diff --git a/pkg/controller/catalog/catalog_client.go b/pkg/controller/catalog/catalog_client.go new file mode 100644 index 000000000..53a9b8aa5 --- /dev/null +++ b/pkg/controller/catalog/catalog_client.go @@ -0,0 +1,28 @@ +package catalog + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" +) + +type Client interface { + Get(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) + Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, 
outputPath storage.DataReference) error +} + +func NewCatalogClient(store storage.ProtobufStore) Client { + catalogConfig := GetConfig() + + var catalogClient Client + if catalogConfig.Type == LegacyDiscoveryType { + catalogClient = NewLegacyDiscovery(catalogConfig.Endpoint, store) + } else if catalogConfig.Type == NoOpDiscoveryType { + catalogClient = NewNoOpDiscovery() + } + + logger.Infof(context.Background(), "Created Catalog client, type: %v", catalogConfig.Type) + return catalogClient +} diff --git a/pkg/controller/catalog/config_flags.go b/pkg/controller/catalog/config_flags.go new file mode 100755 index 000000000..d67dd751a --- /dev/null +++ b/pkg/controller/catalog/config_flags.go @@ -0,0 +1,47 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package catalog + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. 
+func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), defaultConfig.Type, "Discovery Implementation to use") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "endpoint"), defaultConfig.Endpoint, " Endpoint for discovery service") + return cmdFlags +} diff --git a/pkg/controller/catalog/config_flags_test.go b/pkg/controller/catalog/config_flags_test.go new file mode 100755 index 000000000..a2538822b --- /dev/null +++ b/pkg/controller/catalog/config_flags_test.go @@ -0,0 +1,146 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package catalog + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("type"); err == nil { + assert.Equal(t, string(defaultConfig.Type), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("type", 
testValue) + if vString, err := cmdFlags.GetString("type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_endpoint", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("endpoint"); err == nil { + assert.Equal(t, string(defaultConfig.Endpoint), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("endpoint", testValue) + if vString, err := cmdFlags.GetString("endpoint"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Endpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/pkg/controller/catalog/discovery_config.go b/pkg/controller/catalog/discovery_config.go new file mode 100644 index 000000000..bbc1bab8f --- /dev/null +++ b/pkg/controller/catalog/discovery_config.go @@ -0,0 +1,34 @@ +package catalog + +import ( + "github.com/lyft/flytestdlib/config" +) + +//go:generate pflags Config --default-var defaultConfig + +const ConfigSectionKey = "catalog-cache" + +var ( + defaultConfig = &Config{ + Type: NoOpDiscoveryType, + } + + configSection = config.MustRegisterSection(ConfigSectionKey, defaultConfig) +) + +type DiscoveryType = string + +const ( + NoOpDiscoveryType DiscoveryType = "noop" + LegacyDiscoveryType DiscoveryType = "legacy" +) + +type Config struct { + Type DiscoveryType `json:"type" pflag:"\"noop\",Discovery Implementation to use"` + Endpoint string `json:"endpoint" pflag:"\"\", Endpoint for discovery service"` +} + +// Gets loaded config for Discovery +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} diff --git a/pkg/controller/catalog/legacy_discovery.go b/pkg/controller/catalog/legacy_discovery.go new file mode 100644 index 000000000..5383c466f --- /dev/null +++ 
b/pkg/controller/catalog/legacy_discovery.go @@ -0,0 +1,202 @@ +package catalog + +import ( + "context" + "encoding/base64" + "fmt" + "time" + + "github.com/golang/protobuf/proto" + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/datacatalog" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/pbhash" + "github.com/lyft/flytestdlib/storage" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +const maxGrpcMsgSizeBytes = 41943040 + +// LegacyDiscovery encapsulates interactions with the Discovery service using a protobuf provided gRPC client. +type LegacyDiscovery struct { + client datacatalog.ArtifactsClient + store storage.ProtobufStore +} + +// Hash each value in the map and return it as the parameter value to be used to generate the Provenance. +func TransformToInputParameters(ctx context.Context, m *core.LiteralMap) ([]*datacatalog.Parameter, error) { + var params = []*datacatalog.Parameter{} + + // Note: The Discovery service will ensure that the output parameters are sorted so the hash is consistent. 
+ // If the values of the literalmap are also a map, pbhash ensures that maps are deterministically hashed as well + for k, typedValue := range m.GetLiterals() { + inputHash, err := pbhash.ComputeHashString(ctx, typedValue) + if err != nil { + return nil, err + } + params = append(params, &datacatalog.Parameter{ + Name: k, + Value: inputHash, + }) + } + + return params, nil +} + +func TransformToOutputParameters(ctx context.Context, m *core.LiteralMap) ([]*datacatalog.Parameter, error) { + var params = []*datacatalog.Parameter{} + for k, typedValue := range m.GetLiterals() { + bytes, err := proto.Marshal(typedValue) + + if err != nil { + return nil, err + } + params = append(params, &datacatalog.Parameter{ + Name: k, + Value: base64.StdEncoding.EncodeToString(bytes), + }) + } + return params, nil +} + +func TransformFromParameters(m []*datacatalog.Parameter) (*core.LiteralMap, error) { + paramsMap := make(map[string]*core.Literal) + + for _, p := range m { + bytes, err := base64.StdEncoding.DecodeString(p.GetValue()) + if err != nil { + return nil, err + } + literal := &core.Literal{} + if err = proto.Unmarshal(bytes, literal); err != nil { + return nil, err + } + paramsMap[p.Name] = literal + } + return &core.LiteralMap{ + Literals: paramsMap, + }, nil +} + +func (d *LegacyDiscovery) Get(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + inputs := &core.LiteralMap{} + + taskInterface := task.Interface + // only download if there are inputs to the task + if taskInterface != nil && taskInterface.Inputs != nil && len(taskInterface.Inputs.Variables) > 0 { + if err := d.store.ReadProtobuf(ctx, inputPath, inputs); err != nil { + return nil, err + } + } + + inputParams, err := TransformToInputParameters(ctx, inputs) + if err != nil { + return nil, err + } + + artifactID := &datacatalog.ArtifactId{ + Name: fmt.Sprintf("%s:%s:%s", task.Id.Project, task.Id.Domain, task.Id.Name), + Version: 
task.Metadata.DiscoveryVersion, + Inputs: inputParams, + } + options := []grpc.CallOption{ + grpc.MaxCallRecvMsgSize(maxGrpcMsgSizeBytes), + grpc.MaxCallSendMsgSize(maxGrpcMsgSizeBytes), + } + + request := &datacatalog.GetRequest{ + Id: &datacatalog.GetRequest_ArtifactId{ + ArtifactId: artifactID, + }, + } + resp, err := d.client.Get(ctx, request, options...) + + logger.Infof(ctx, "Discovery Get response for artifact |%v|, resp: |%v|, error: %v", artifactID, resp, err) + if err != nil { + return nil, err + } + return TransformFromParameters(resp.Artifact.Outputs) +} + +func GetDefaultGrpcOptions() []grpc_retry.CallOption { + return []grpc_retry.CallOption{ + grpc_retry.WithBackoff(grpc_retry.BackoffLinear(100 * time.Millisecond)), + grpc_retry.WithCodes(codes.DeadlineExceeded, codes.Unavailable, codes.Canceled), + grpc_retry.WithMax(5), + } +} +func (d *LegacyDiscovery) Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + inputs := &core.LiteralMap{} + outputs := &core.LiteralMap{} + + taskInterface := task.Interface + // only download if there are inputs to the task + if taskInterface != nil && taskInterface.Inputs != nil && len(taskInterface.Inputs.Variables) > 0 { + if err := d.store.ReadProtobuf(ctx, inputPath, inputs); err != nil { + return err + } + } + + // only download if there are outputs to the task + if taskInterface != nil && taskInterface.Outputs != nil && len(taskInterface.Outputs.Variables) > 0 { + if err := d.store.ReadProtobuf(ctx, outputPath, outputs); err != nil { + return err + } + } + + outputParams, err := TransformToOutputParameters(ctx, outputs) + if err != nil { + return err + } + + inputParams, err := TransformToInputParameters(ctx, inputs) + if err != nil { + return err + } + + artifactID := &datacatalog.ArtifactId{ + Name: fmt.Sprintf("%s:%s:%s", task.Id.Project, task.Id.Domain, task.Id.Name), + Version: 
task.Metadata.DiscoveryVersion, + Inputs: inputParams, + } + executionID := fmt.Sprintf("%s:%s:%s", execID.GetNodeExecutionId().GetExecutionId().GetProject(), + execID.GetNodeExecutionId().GetExecutionId().GetDomain(), execID.GetNodeExecutionId().GetExecutionId().GetName()) + request := &datacatalog.CreateRequest{ + Ref: artifactID, + ReferenceId: executionID, + Revision: time.Now().Unix(), + Outputs: outputParams, + } + options := []grpc.CallOption{ + grpc.MaxCallRecvMsgSize(maxGrpcMsgSizeBytes), + grpc.MaxCallSendMsgSize(maxGrpcMsgSizeBytes), + } + + resp, err := d.client.Create(ctx, request, options...) + logger.Infof(ctx, "Discovery Put response for artifact |%v|, resp: |%v|, err: %v", artifactID, resp, err) + return err +} + +func NewLegacyDiscovery(discoveryEndpoint string, store storage.ProtobufStore) *LegacyDiscovery { + + // No discovery endpoint passed. Skip client creation. + if discoveryEndpoint == "" { + return nil + } + + opts := GetDefaultGrpcOptions() + retryInterceptor := grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...)) + conn, err := grpc.Dial(discoveryEndpoint, grpc.WithInsecure(), retryInterceptor) + + if err != nil { + return nil + } + client := datacatalog.NewArtifactsClient(conn) + + return &LegacyDiscovery{ + client: client, + store: store, + } +} diff --git a/pkg/controller/catalog/legacy_discovery_test.go b/pkg/controller/catalog/legacy_discovery_test.go new file mode 100644 index 000000000..4f37dabc5 --- /dev/null +++ b/pkg/controller/catalog/legacy_discovery_test.go @@ -0,0 +1,286 @@ +package catalog + +import ( + "context" + "testing" + + "github.com/lyft/flyteidl/clients/go/datacatalog/mocks" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/datacatalog" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" 
+ "github.com/stretchr/testify/mock" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func init() { + labeled.SetMetricKeys(contextutils.TaskIDKey) +} + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func newIntegerPrimitive(value int64) *core.Primitive { + return &core.Primitive{Value: &core.Primitive_Integer{Integer: value}} +} + +func newScalarInteger(value int64) *core.Scalar { + return &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: newIntegerPrimitive(value), + }, + } +} + +func newIntegerLiteral(value int64) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: newScalarInteger(value), + }, + } +} + +func TestTransformToInputParameters(t *testing.T) { + paramsMap := make(map[string]*core.Literal) + paramsMap["out1"] = newIntegerLiteral(200) + + params, err := TransformToInputParameters(context.Background(), &core.LiteralMap{ + Literals: paramsMap, + }) + assert.Nil(t, err) + assert.Equal(t, "out1", params[0].Name) + assert.Equal(t, "c6i2T7NODjwnlxmXKRCNDk/AN4pZpRGGFX49kT6DT/c=", params[0].Value) +} + +func TestTransformToOutputParameters(t *testing.T) { + paramsMap := make(map[string]*core.Literal) + paramsMap["out1"] = newIntegerLiteral(100) + + params, err := TransformToOutputParameters(context.Background(), &core.LiteralMap{ + Literals: paramsMap, + }) + assert.Nil(t, err) + assert.Equal(t, "out1", params[0].Name) + assert.Equal(t, "CgQKAghk", params[0].Value) +} + +func TestTransformFromParameters(t *testing.T) { + params := []*datacatalog.Parameter{ + {Name: "out1", Value: "CgQKAghk"}, + } + literalMap, err := TransformFromParameters(params) + assert.Nil(t, err) + + val, exists := literalMap.Literals["out1"] + assert.True(t, exists) + assert.Equal(t, int64(100), 
val.GetScalar().GetPrimitive().GetInteger()) +} + +func TestLegacyDiscovery_Get(t *testing.T) { + ctx := context.Background() + + paramMap := &core.LiteralMap{Literals: map[string]*core.Literal{ + "out1": newIntegerLiteral(100), + }} + task := &core.TaskTemplate{ + Id: &core.Identifier{Project: "project", Domain: "domain", Name: "name"}, + Metadata: &core.TaskMetadata{ + DiscoveryVersion: "0.0.1", + }, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + }, + } + + inputPath := storage.DataReference("test-data/inputs.pb") + + t.Run("notfound", func(t *testing.T) { + mockClient := &mocks.ArtifactsClient{} + store := createInmemoryDataStore(t, promutils.NewScope("get_test_notfound")) + + err := store.WriteProtobuf(ctx, inputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + + discovery := LegacyDiscovery{client: mockClient, store: store} + mockClient.On("Get", + ctx, + mock.MatchedBy(func(o *datacatalog.GetRequest) bool { + assert.Equal(t, o.GetArtifactId().Name, "project:domain:name") + params, err := TransformToInputParameters(context.Background(), paramMap) + assert.NoError(t, err) + assert.Equal(t, o.GetArtifactId().Inputs, params) + return true + }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + ).Return(nil, status.Error(codes.NotFound, "")) + resp, err := discovery.Get(ctx, task, inputPath) + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("found", func(t *testing.T) { + mockClient := &mocks.ArtifactsClient{} + store := createInmemoryDataStore(t, 
promutils.NewScope("get_test_found")) + + err := store.WriteProtobuf(ctx, inputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + + discovery := LegacyDiscovery{client: mockClient, store: store} + outputs, err := TransformToOutputParameters(context.Background(), paramMap) + assert.NoError(t, err) + response := &datacatalog.GetResponse{ + Artifact: &datacatalog.Artifact{ + Outputs: outputs, + }, + } + mockClient.On("Get", + ctx, + mock.MatchedBy(func(o *datacatalog.GetRequest) bool { + assert.Equal(t, o.GetArtifactId().Name, "project:domain:name") + assert.Equal(t, o.GetArtifactId().Version, "0.0.1") + params, err := TransformToInputParameters(context.Background(), paramMap) + assert.NoError(t, err) + assert.Equal(t, o.GetArtifactId().Inputs, params) + return true + }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + ).Return(response, nil) + resp, err := discovery.Get(ctx, task, inputPath) + assert.NoError(t, err) + assert.NotNil(t, resp) + val, exists := resp.Literals["out1"] + assert.True(t, exists) + assert.Equal(t, int64(100), val.GetScalar().GetPrimitive().GetInteger()) + }) +} + +func TestLegacyDiscovery_Put(t *testing.T) { + ctx := context.Background() + + inputPath := storage.DataReference("test-data/inputs.pb") + outputPath := storage.DataReference("test-data/ouputs.pb") + + paramMap := &core.LiteralMap{Literals: map[string]*core.Literal{ + "out1": newIntegerLiteral(100), + }} + task := &core.TaskTemplate{ + Id: &core.Identifier{Project: "project", Domain: "domain", Name: "name"}, + Metadata: &core.TaskMetadata{ + DiscoveryVersion: "0.0.1", + }, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": 
&core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + }, + } + + execID := &core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "runID", + }, + }, + } + + t.Run("failed", func(t *testing.T) { + mockClient := &mocks.ArtifactsClient{} + store := createInmemoryDataStore(t, promutils.NewScope("put_test_failed")) + discovery := LegacyDiscovery{client: mockClient, store: store} + mockClient.On("Create", + ctx, + mock.MatchedBy(func(o *datacatalog.CreateRequest) bool { + assert.Equal(t, o.GetRef().Name, "project:domain:name") + assert.Equal(t, o.GetReferenceId(), "project:domain:runID") + inputs, err := TransformToInputParameters(context.Background(), paramMap) + assert.NoError(t, err) + outputs, err := TransformToOutputParameters(context.Background(), paramMap) + assert.NoError(t, err) + assert.Equal(t, o.GetRef().Inputs, inputs) + assert.Equal(t, o.GetOutputs(), outputs) + + return true + }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + ).Return(nil, status.Error(codes.AlreadyExists, "")) + + err := store.WriteProtobuf(ctx, inputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + err = store.WriteProtobuf(ctx, outputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + + err = discovery.Put(ctx, task, execID, inputPath, outputPath) + assert.Error(t, err) + }) + + t.Run("success", func(t *testing.T) { + store := createInmemoryDataStore(t, promutils.NewScope("put_test_success")) + mockClient := &mocks.ArtifactsClient{} + discovery := LegacyDiscovery{client: mockClient, store: store} + mockClient.On("Create", + ctx, + mock.MatchedBy(func(o *datacatalog.CreateRequest) bool { + assert.Equal(t, o.GetRef().Name, "project:domain:name") + assert.Equal(t, o.GetRef().Version, 
"0.0.1") + inputs, err := TransformToInputParameters(context.Background(), paramMap) + assert.NoError(t, err) + outputs, err := TransformToOutputParameters(context.Background(), paramMap) + assert.NoError(t, err) + assert.Equal(t, o.GetRef().Inputs, inputs) + assert.Equal(t, o.GetOutputs(), outputs) + + return true + }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + ).Return(nil, nil) + + err := store.WriteProtobuf(ctx, inputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + err = store.WriteProtobuf(ctx, outputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + + err = discovery.Put(ctx, task, execID, inputPath, outputPath) + assert.NoError(t, err) + }) +} diff --git a/pkg/controller/catalog/mock_catalog.go b/pkg/controller/catalog/mock_catalog.go new file mode 100644 index 000000000..03885ff61 --- /dev/null +++ b/pkg/controller/catalog/mock_catalog.go @@ -0,0 +1,21 @@ +package catalog + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/storage" +) + +type MockCatalogClient struct { + GetFunc func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) + PutFunc func(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error +} + +func (m *MockCatalogClient) Get(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + return m.GetFunc(ctx, task, inputPath) +} + +func (m *MockCatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return m.PutFunc(ctx, task, execID, inputPath, outputPath) +} diff --git a/pkg/controller/catalog/no_op_discovery.go 
b/pkg/controller/catalog/no_op_discovery.go new file mode 100644 index 000000000..9d3fc9332 --- /dev/null +++ b/pkg/controller/catalog/no_op_discovery.go @@ -0,0 +1,29 @@ +package catalog + +import ( + "context" + + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// NoOpDiscovery +type NoOpDiscovery struct{} + +func (d *NoOpDiscovery) Get(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + logger.Infof(ctx, "No-op Discovery Get invoked. Returning NotFound") + return nil, status.Error(codes.NotFound, "No-op Discovery default behavior.") +} + +func (d *NoOpDiscovery) Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + logger.Infof(ctx, "No-op Discovery Put invoked. Doing nothing") + return nil +} + +func NewNoOpDiscovery() *NoOpDiscovery { + return &NoOpDiscovery{} +} diff --git a/pkg/controller/catalog/no_op_discovery_test.go b/pkg/controller/catalog/no_op_discovery_test.go new file mode 100644 index 000000000..ded727c87 --- /dev/null +++ b/pkg/controller/catalog/no_op_discovery_test.go @@ -0,0 +1,26 @@ +package catalog + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var noopDiscovery Client = &NoOpDiscovery{} + +func TestNoopDiscovery_Get(t *testing.T) { + ctx := context.Background() + resp, err := noopDiscovery.Get(ctx, nil, "") + assert.Nil(t, resp) + assert.Error(t, err) + assert.True(t, status.Code(err) == codes.NotFound) +} + +func TestNoopDiscovery_Put(t *testing.T) { + ctx := context.Background() + err := noopDiscovery.Put(ctx, nil, nil, "", "") + assert.Nil(t, err) +} diff --git a/pkg/controller/completed_workflows.go 
b/pkg/controller/completed_workflows.go new file mode 100644 index 000000000..43197c715 --- /dev/null +++ b/pkg/controller/completed_workflows.go @@ -0,0 +1,87 @@ +package controller + +import ( + "strconv" + "time" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const controllerAgentName = "flyteworkflow-controller" +const workflowTerminationStatusKey = "termination-status" +const workflowTerminatedValue = "terminated" +const hourOfDayCompletedKey = "hour-of-day" + +// This function creates a label selector, that will ignore all objects (in this case workflow) that DOES NOT have a +// label key=workflowTerminationStatusKey with a value=workflowTerminatedValue +func IgnoreCompletedWorkflowsLabelSelector() *v1.LabelSelector { + return &v1.LabelSelector{ + MatchExpressions: []v1.LabelSelectorRequirement{ + { + Key: workflowTerminationStatusKey, + Operator: v1.LabelSelectorOpNotIn, + Values: []string{workflowTerminatedValue}, + }, + }, + } +} + +// Creates a new LabelSelector that selects all workflows that have the completed Label +func CompletedWorkflowsLabelSelector() *v1.LabelSelector { + return &v1.LabelSelector{ + MatchLabels: map[string]string{ + workflowTerminationStatusKey: workflowTerminatedValue, + }, + } +} + +func SetCompletedLabel(w *v1alpha1.FlyteWorkflow, currentTime time.Time) { + if w.Labels == nil { + w.Labels = make(map[string]string) + } + w.Labels[workflowTerminationStatusKey] = workflowTerminatedValue + w.Labels[hourOfDayCompletedKey] = strconv.Itoa(currentTime.Hour()) +} + +func HasCompletedLabel(w *v1alpha1.FlyteWorkflow) bool { + if w.Labels != nil { + v, ok := w.Labels[workflowTerminationStatusKey] + if ok { + return v == workflowTerminatedValue + } + } + return false +} + +// Calculates a list of all the hours that should be deleted given the current hour of the day and the retentionperiod in hours +// Usually this is a list of all hours out of the 24 hours in the day - 
retention period - the current hour of the day +func CalculateHoursToDelete(retentionPeriodHours, currentHourOfDay int) []string { + numberOfHoursToDelete := 24 - retentionPeriodHours + hoursToDelete := make([]string, 0, numberOfHoursToDelete) + + for i := 0; i < currentHourOfDay-retentionPeriodHours; i++ { + hoursToDelete = append(hoursToDelete, strconv.Itoa(i)) + } + maxHourOfDay := 24 + if currentHourOfDay-retentionPeriodHours < 0 { + maxHourOfDay = 24 + (currentHourOfDay - retentionPeriodHours) + } + for i := currentHourOfDay + 1; i < maxHourOfDay; i++ { + hoursToDelete = append(hoursToDelete, strconv.Itoa(i)) + } + return hoursToDelete +} + +// Creates a new selector that selects all completed workflows and workflows with completed hour label outside of the +// retention window +func CompletedWorkflowsSelectorOutsideRetentionPeriod(retentionPeriodHours int, currentTime time.Time) *v1.LabelSelector { + hoursToDelete := CalculateHoursToDelete(retentionPeriodHours, currentTime.Hour()) + s := CompletedWorkflowsLabelSelector() + s.MatchExpressions = append(s.MatchExpressions, v1.LabelSelectorRequirement{ + Key: hourOfDayCompletedKey, + Operator: v1.LabelSelectorOpIn, + Values: hoursToDelete, + }) + return s +} diff --git a/pkg/controller/completed_workflows_test.go b/pkg/controller/completed_workflows_test.go new file mode 100644 index 000000000..7b57d938d --- /dev/null +++ b/pkg/controller/completed_workflows_test.go @@ -0,0 +1,160 @@ +package controller + +import ( + "testing" + "time" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestIgnoreCompletedWorkflowsLabelSelector(t *testing.T) { + s := IgnoreCompletedWorkflowsLabelSelector() + assert.NotNil(t, s) + assert.Empty(t, s.MatchLabels) + assert.NotEmpty(t, s.MatchExpressions) + r := s.MatchExpressions[0] + assert.Equal(t, workflowTerminationStatusKey, r.Key) + assert.Equal(t, 
v1.LabelSelectorOpNotIn, r.Operator) + assert.Equal(t, []string{workflowTerminatedValue}, r.Values) +} + +func TestCompletedWorkflowsLabelSelector(t *testing.T) { + s := CompletedWorkflowsLabelSelector() + assert.NotEmpty(t, s.MatchLabels) + v, ok := s.MatchLabels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) +} + +func TestHasCompletedLabel(t *testing.T) { + + n := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + t.Run("no-labels", func(t *testing.T) { + + w := &v1alpha1.FlyteWorkflow{} + assert.Empty(t, w.Labels) + assert.False(t, HasCompletedLabel(w)) + SetCompletedLabel(w, n) + assert.NotEmpty(t, w.Labels) + v, ok := w.Labels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) + assert.True(t, HasCompletedLabel(w)) + }) + + t.Run("existing-lables", func(t *testing.T) { + w := &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Labels: map[string]string{ + "x": "v", + }, + }, + } + assert.NotEmpty(t, w.Labels) + assert.False(t, HasCompletedLabel(w)) + SetCompletedLabel(w, n) + assert.NotEmpty(t, w.Labels) + v, ok := w.Labels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) + v, ok = w.Labels["x"] + assert.True(t, ok) + assert.Equal(t, "v", v) + assert.True(t, HasCompletedLabel(w)) + }) +} + +func TestSetCompletedLabel(t *testing.T) { + n := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + t.Run("no-labels", func(t *testing.T) { + + w := &v1alpha1.FlyteWorkflow{} + assert.Empty(t, w.Labels) + SetCompletedLabel(w, n) + assert.NotEmpty(t, w.Labels) + v, ok := w.Labels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) + }) + + t.Run("existing-lables", func(t *testing.T) { + w := &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Labels: map[string]string{ + "x": "v", + }, + }, + } + assert.NotEmpty(t, w.Labels) + SetCompletedLabel(w, n) + 
assert.NotEmpty(t, w.Labels) + v, ok := w.Labels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) + v, ok = w.Labels["x"] + assert.True(t, ok) + assert.Equal(t, "v", v) + }) + +} + +func TestCalculateHoursToDelete(t *testing.T) { + assert.Equal(t, []string{ + "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", + }, CalculateHoursToDelete(6, 5)) + + assert.Equal(t, []string{ + "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", + }, CalculateHoursToDelete(6, 6)) + + assert.Equal(t, []string{ + "0", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", + }, CalculateHoursToDelete(6, 7)) + + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "23", + }, CalculateHoursToDelete(6, 22)) + + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", + }, CalculateHoursToDelete(6, 23)) + + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", + }, CalculateHoursToDelete(0, 23)) + + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "21", "22", "23", + }, CalculateHoursToDelete(0, 20)) + + assert.Equal(t, []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", + }, CalculateHoursToDelete(0, 0)) + + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", + }, CalculateHoursToDelete(0, 12)) + + assert.Equal(t, []string{"13"}, CalculateHoursToDelete(22, 12)) + assert.Equal(t, []string{"1"}, 
CalculateHoursToDelete(22, 0)) + assert.Equal(t, []string{"0"}, CalculateHoursToDelete(22, 23)) + assert.Equal(t, []string{"23"}, CalculateHoursToDelete(22, 22)) +} + +func TestCompletedWorkflowsSelectorOutsideRetentionPeriod(t *testing.T) { + n := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + s := CompletedWorkflowsSelectorOutsideRetentionPeriod(2, n) + v, ok := s.MatchLabels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) + assert.NotEmpty(t, s.MatchExpressions) + r := s.MatchExpressions[0] + assert.Equal(t, hourOfDayCompletedKey, r.Key) + assert.Equal(t, v1.LabelSelectorOpIn, r.Operator) + assert.Equal(t, 21, len(r.Values)) + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, r.Values) +} diff --git a/pkg/controller/composite_workqueue.go b/pkg/controller/composite_workqueue.go new file mode 100644 index 000000000..a03e79e11 --- /dev/null +++ b/pkg/controller/composite_workqueue.go @@ -0,0 +1,172 @@ +package controller + +import ( + "context" + "time" + + "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/pkg/errors" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/util/workqueue" +) + +// A CompositeWorkQueue can be used in cases where the work is enqueued by two sources. It can be enqueued by either +// 1. Informer for the Primary Object itself. In case of FlytePropeller, this is the workflow object +// 2. Informer or any other process that enqueues the top-level object for re-evaluation in response to one of the +// sub-objects being ready. 
In the case of FlytePropeller this is the "Node/Task" updates, will re-enqueue the workflow +// to be re-evaluated +type CompositeWorkQueue interface { + workqueue.RateLimitingInterface + // Specialized interface that should be called to start the migration of work from SubQueue to primaryQueue + Start(ctx context.Context) + // Shutsdown all the queues that are in the context + ShutdownAll() + // Adds the item explicitly to the subqueue + AddToSubQueue(item interface{}) + // Adds the item explicitly to the subqueue, using a rate limiter + AddToSubQueueRateLimited(item interface{}) + // Adds the item explicitly to the subqueue after some duration + AddToSubQueueAfter(item interface{}, duration time.Duration) +} + +// SimpleWorkQueue provides a simple RateLimitingInterface, but ensures that the compositeQueue interface works +// with a default queue. +type SimpleWorkQueue struct { + // workqueue is a rate limited work queue. This is used to queue work to be + // processed instead of performing it as soon as a change happens. This + // means we can ensure we only process a fixed amount of resources at a + // time, and makes it easy to ensure we are never processing the same item + // simultaneously in two different workers. + workqueue.RateLimitingInterface +} + +func (s *SimpleWorkQueue) Start(ctx context.Context) { +} + +func (s *SimpleWorkQueue) ShutdownAll() { + s.ShutDown() +} + +func (s *SimpleWorkQueue) AddToSubQueue(item interface{}) { + s.Add(item) +} + +func (s *SimpleWorkQueue) AddToSubQueueAfter(item interface{}, duration time.Duration) { + s.AddAfter(item, duration) +} + +func (s *SimpleWorkQueue) AddToSubQueueRateLimited(item interface{}) { + s.AddRateLimited(item) +} + +// A BatchingWorkQueue consists of 2 queues and migrates items from sub-queue to parent queue as a batch at a specified +// interval +type BatchingWorkQueue struct { + // workqueue is a rate limited work queue. 
This is used to queue work to be + // processed instead of performing it as soon as a change happens. This + // means we can ensure we only process a fixed amount of resources at a + // time, and makes it easy to ensure we are never processing the same item + // simultaneously in two different workers. + workqueue.RateLimitingInterface + + subQueue workqueue.RateLimitingInterface + batchingInterval time.Duration + batchSize int +} + +func (b *BatchingWorkQueue) Start(ctx context.Context) { + logger.Infof(ctx, "Batching queue started") + go wait.Until(func() { + b.runSubQueueHandler(ctx) + }, b.batchingInterval, ctx.Done()) +} + +func (b *BatchingWorkQueue) runSubQueueHandler(ctx context.Context) { + logger.Debugf(ctx, "Subqueue handler batch round") + defer logger.Debugf(ctx, "Exiting SubQueue handler batch round") + if b.subQueue.ShuttingDown() { + return + } + numToRetrieve := b.batchSize + if b.batchSize == -1 || b.batchSize > b.subQueue.Len() { + numToRetrieve = b.subQueue.Len() + } + + logger.Debugf(ctx, "Dynamically configured batch size [%d]", b.batchSize) + // Run batches forever + objectsRetrieved := make([]interface{}, numToRetrieve) + for i := 0; i < numToRetrieve; i++ { + obj, shutdown := b.subQueue.Get() + if obj != nil { + // We expect strings to come off the workqueue. These are of the + // form namespace/name. We do this as the delayed nature of the + // workqueue means the items in the informer cache may actually be + // more up to date that when the item was initially put onto the + // workqueue. + if key, ok := obj.(string); ok { + objectsRetrieved[i] = key + } + } + if shutdown { + logger.Warningf(ctx, "NodeQ shutdown invoked. Shutting down poller.") + // We cannot add after shutdown, so just quit! + return + } + + } + + for _, obj := range objectsRetrieved { + b.Add(obj) + // Finally, if no error occurs we Forget this item so it does not + // get queued again until another change happens. 
+ b.subQueue.Forget(obj) + b.subQueue.Done(obj) + } + +} + +func (b *BatchingWorkQueue) ShutdownAll() { + b.subQueue.ShutDown() + b.ShutDown() +} + +func (b *BatchingWorkQueue) AddToSubQueue(item interface{}) { + b.subQueue.Add(item) +} + +func (b *BatchingWorkQueue) AddToSubQueueAfter(item interface{}, duration time.Duration) { + b.subQueue.AddAfter(item, duration) +} + +func (b *BatchingWorkQueue) AddToSubQueueRateLimited(item interface{}) { + b.subQueue.AddRateLimited(item) +} + +func NewCompositeWorkQueue(ctx context.Context, cfg config.CompositeQueueConfig, scope promutils.Scope) (CompositeWorkQueue, error) { + workQ, err := NewWorkQueue(ctx, cfg.Queue, scope.CurrentScope()) + if err != nil { + return nil, errors.Wrapf(err, "failed to create WorkQueue in CompositeQueue type Batch") + } + switch cfg.Type { + case config.CompositeQueueBatch: + subQ, err := NewWorkQueue(ctx, cfg.Sub, scope.NewSubScope("sub").CurrentScope()) + if err != nil { + return nil, errors.Wrapf(err, "failed to create SubQueue in CompositeQueue type Batch") + } + return &BatchingWorkQueue{ + RateLimitingInterface: workQ, + batchSize: cfg.BatchSize, + batchingInterval: cfg.BatchingInterval.Duration, + subQueue: subQ, + }, nil + case config.CompositeQueueSimple: + fallthrough + default: + } + return &SimpleWorkQueue{ + RateLimitingInterface: workQ, + }, nil +} diff --git a/pkg/controller/composite_workqueue_test.go b/pkg/controller/composite_workqueue_test.go new file mode 100644 index 000000000..4239102ad --- /dev/null +++ b/pkg/controller/composite_workqueue_test.go @@ -0,0 +1,146 @@ +package controller + +import ( + "context" + "testing" + "time" + + config2 "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +func TestNewCompositeWorkQueue(t *testing.T) { + ctx := context.TODO() + + t.Run("simple", func(t *testing.T) { + testScope := promutils.NewScope("test1") 
+ cfg := config2.CompositeQueueConfig{} + q, err := NewCompositeWorkQueue(ctx, cfg, testScope) + assert.NoError(t, err) + assert.NotNil(t, q) + switch q.(type) { + case *SimpleWorkQueue: + return + default: + assert.FailNow(t, "SimpleWorkQueue expected") + } + }) + + t.Run("batch", func(t *testing.T) { + testScope := promutils.NewScope("test2") + cfg := config2.CompositeQueueConfig{ + Type: config2.CompositeQueueBatch, + BatchSize: -1, + BatchingInterval: config.Duration{Duration: time.Second * 1}, + } + q, err := NewCompositeWorkQueue(ctx, cfg, testScope) + assert.NoError(t, err) + assert.NotNil(t, q) + switch q.(type) { + case *BatchingWorkQueue: + assert.Equal(t, -1, q.(*BatchingWorkQueue).batchSize) + assert.Equal(t, time.Second*1, q.(*BatchingWorkQueue).batchingInterval) + return + default: + assert.FailNow(t, "BatchWorkQueue expected") + } + }) +} + +func TestSimpleWorkQueue(t *testing.T) { + ctx := context.TODO() + testScope := promutils.NewScope("test") + cfg := config2.CompositeQueueConfig{} + q, err := NewCompositeWorkQueue(ctx, cfg, testScope) + assert.NoError(t, err) + assert.NotNil(t, q) + + t.Run("AddSubQueue", func(t *testing.T) { + q.AddToSubQueue("x") + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "x", i.(string)) + q.Done(i) + }) + + t.Run("AddAfterSubQueue", func(t *testing.T) { + q.AddToSubQueueAfter("y", time.Nanosecond*0) + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "y", i.(string)) + q.Done(i) + }) + + t.Run("AddRateLimitedSubQueue", func(t *testing.T) { + q.AddToSubQueueRateLimited("z") + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "z", i.(string)) + q.Done(i) + }) + + t.Run("shutdown", func(t *testing.T) { + q.ShutdownAll() + _, s := q.Get() + assert.True(t, s) + }) +} + +func TestBatchingQueue(t *testing.T) { + ctx := context.TODO() + testScope := promutils.NewScope("test_batch") + cfg := config2.CompositeQueueConfig{ + Type: config2.CompositeQueueBatch, + BatchSize: -1, + BatchingInterval: 
config.Duration{Duration: time.Nanosecond * 1}, + } + q, err := NewCompositeWorkQueue(ctx, cfg, testScope) + assert.NoError(t, err) + assert.NotNil(t, q) + + batchQueue := q.(*BatchingWorkQueue) + + t.Run("AddSubQueue", func(t *testing.T) { + q.AddToSubQueue("x") + assert.Equal(t, 0, q.Len()) + batchQueue.runSubQueueHandler(ctx) + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "x", i.(string)) + q.Done(i) + }) + + t.Run("AddAfterSubQueue", func(t *testing.T) { + q.AddToSubQueueAfter("y", time.Nanosecond*0) + assert.Equal(t, 0, q.Len()) + batchQueue.runSubQueueHandler(ctx) + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "y", i.(string)) + q.Done(i) + }) + + t.Run("AddRateLimitedSubQueue", func(t *testing.T) { + q.AddToSubQueueRateLimited("z") + assert.Equal(t, 0, q.Len()) + batchQueue.Start(ctx) + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "z", i.(string)) + q.Done(i) + }) + + t.Run("shutdown", func(t *testing.T) { + q.AddToSubQueue("g") + q.ShutdownAll() + assert.Equal(t, 0, q.Len()) + batchQueue.runSubQueueHandler(ctx) + i, s := q.Get() + assert.True(t, s) + assert.Nil(t, i) + q.Done(i) + }) +} diff --git a/pkg/controller/config/config.go b/pkg/controller/config/config.go new file mode 100644 index 000000000..306290634 --- /dev/null +++ b/pkg/controller/config/config.go @@ -0,0 +1,93 @@ +package config + +import ( + "github.com/lyft/flytestdlib/config" + "k8s.io/apimachinery/pkg/types" +) + +//go:generate pflags Config + +const configSectionKey = "propeller" + +var ConfigSection = config.MustRegisterSection(configSectionKey, &Config{}) + +// NOTE: when adding new fields, do not mark them as "omitempty" if it's desirable to read the value from env variables. +// Config that uses the flytestdlib Config module to generate commandline and load config files. 
This configuration is +// the base configuration to start propeller +type Config struct { + KubeConfigPath string `json:"kube-config" pflag:",Path to kubernetes client config file."` + MasterURL string `json:"master"` + Workers int `json:"workers" pflag:"2,Number of threads to process workflows"` + WorkflowReEval config.Duration `json:"workflow-reeval-duration" pflag:"\"30s\",Frequency of re-evaluating workflows"` + DownstreamEval config.Duration `json:"downstream-eval-duration" pflag:"\"60s\",Frequency of re-evaluating downstream tasks"` + LimitNamespace string `json:"limit-namespace" pflag:"\"all\",Namespaces to watch for this propeller"` + ProfilerPort config.Port `json:"prof-port" pflag:"\"10254\",Profiler port"` + MetadataPrefix string `json:"metadata-prefix,omitempty" pflag:",MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. If not specified, the data will be stored in the base container directly."` + Queue CompositeQueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."` + MetricsPrefix string `json:"metrics-prefix" pflag:"\"flyte:\",An optional prefix for all published metrics."` + EnableAdminLauncher bool `json:"enable-admin-launcher" pflag:"false, Enable remote Workflow launcher to Admin"` + MaxWorkflowRetries int `json:"max-workflow-retries" pflag:"50,Maximum number of retries per workflow"` + MaxTTLInHours int `json:"max-ttl-hours" pflag:"23,Maximum number of hours a completed workflow should be retained. 
Number between 1-23 hours"` + GCInterval config.Duration `json:"gc-interval" pflag:"\"30m\",Run periodic GC every 30 minutes"` + LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."` + PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."` +} + +type CompositeQueueType = string + +const ( + CompositeQueueSimple CompositeQueueType = "simple" + CompositeQueueBatch CompositeQueueType = "batch" +) + +type CompositeQueueConfig struct { + Type CompositeQueueType `json:"type" pflag:"\"simple\",Type of composite queue to use for the WorkQueue"` + Queue WorkqueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."` + Sub WorkqueueConfig `json:"sub-queue,omitempty" pflag:",SubQueue configuration, affects the way the nodes cause the top-level Work to be re-evaluated."` + BatchingInterval config.Duration `json:"batching-interval" pflag:"\"1s\",Duration for which downstream updates are buffered"` + BatchSize int `json:"batch-size" pflag:"-1,Number of downstream triggered top-level objects to re-enqueue every duration. -1 indicates all available."` +} + +type WorkqueueType = string + +const ( + WorkqueueTypeDefault WorkqueueType = "default" + WorkqueueTypeBucketRateLimiter WorkqueueType = "bucket" + WorkqueueTypeExponentialFailureRateLimiter WorkqueueType = "expfailure" + WorkqueueTypeMaxOfRateLimiter WorkqueueType = "maxof" +) + +// prototypical configuration to configure a workqueue. 
We may want to generalize this in a package like k8sutils +type WorkqueueConfig struct { + // Refer to https://github.com/kubernetes/client-go/tree/master/util/workqueue + Type WorkqueueType `json:"type" pflag:"\"default\",Type of RateLimiter to use for the WorkQueue"` + BaseDelay config.Duration `json:"base-delay" pflag:"\"10s\",base backoff delay for failure"` + MaxDelay config.Duration `json:"max-delay" pflag:"\"10s\",Max backoff delay for failure"` + Rate int64 `json:"rate" pflag:"int64(10),Bucket Refill rate per second"` + Capacity int `json:"capacity" pflag:"100,Bucket capacity as number of items"` +} + +// Contains leader election configuration. +type LeaderElectionConfig struct { + // Enable or disable leader election. + Enabled bool `json:"enabled" pflag:",Enables/Disables leader election."` + + // Determines the name of the configmap that leader election will use for holding the leader lock. + LockConfigMap types.NamespacedName `json:"lock-config-map" pflag:",ConfigMap namespace/name to use for resource lock."` + + // Duration that non-leader candidates will wait to force acquire leadership. This is measured against time of last + // observed ack + LeaseDuration config.Duration `json:"lease-duration" pflag:"\"15s\",Duration that non-leader candidates will wait to force acquire leadership. This is measured against time of last observed ack."` + + // RenewDeadline is the duration that the acting master will retry refreshing leadership before giving up. + RenewDeadline config.Duration `json:"renew-deadline" pflag:"\"10s\",Duration that the acting master will retry refreshing leadership before giving up."` + + // RetryPeriod is the duration the LeaderElector clients should wait between tries of actions. 
+ RetryPeriod config.Duration `json:"retry-period" pflag:"\"2s\",Duration the LeaderElector clients should wait between tries of actions."` +} + +// Extracts the Configuration from the global config module in flytestdlib and returns the corresponding type-casted object. +// TODO What if the type is incorrect? +func GetConfig() *Config { + return ConfigSection.GetConfig().(*Config) +} diff --git a/pkg/controller/config/config_flags.go b/pkg/controller/config/config_flags.go new file mode 100755 index 000000000..f49657203 --- /dev/null +++ b/pkg/controller/config/config_flags.go @@ -0,0 +1,78 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. 
+func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "kube-config"), *new(string), "Path to kubernetes client config file.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "master"), *new(string), "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "workers"), 2, "Number of threads to process workflows") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "workflow-reeval-duration"), "30s", "Frequency of re-evaluating workflows") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "downstream-eval-duration"), "60s", "Frequency of re-evaluating downstream tasks") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "limit-namespace"), "all", "Namespaces to watch for this propeller") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "prof-port"), "10254", "Profiler port") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "metadata-prefix"), *new(string), "MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. 
If not specified, the data will be stored in the base container directly.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.type"), "simple", "Type of composite queue to use for the WorkQueue") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.queue.type"), "default", "Type of RateLimiter to use for the WorkQueue") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.queue.base-delay"), "10s", "base backoff delay for failure") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.queue.max-delay"), "10s", "Max backoff delay for failure") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "queue.queue.rate"), int64(10), "Bucket Refill rate per second") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "queue.queue.capacity"), 100, "Bucket capacity as number of items") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.sub-queue.type"), "default", "Type of RateLimiter to use for the WorkQueue") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.sub-queue.base-delay"), "10s", "base backoff delay for failure") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.sub-queue.max-delay"), "10s", "Max backoff delay for failure") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "queue.sub-queue.rate"), int64(10), "Bucket Refill rate per second") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "queue.sub-queue.capacity"), 100, "Bucket capacity as number of items") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.batching-interval"), "1s", "Duration for which downstream updates are buffered") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "queue.batch-size"), -1, "Number of downstream triggered top-level objects to re-enqueue every duration. 
-1 indicates all available.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "metrics-prefix"), "flyte:", "An optional prefix for all published metrics.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "enable-admin-launcher"), false, " Enable remote Workflow launcher to Admin") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "max-workflow-retries"), 50, "Maximum number of retries per workflow") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "max-ttl-hours"), 23, "Maximum number of hours a completed workflow should be retained. Number between 1-23 hours") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "gc-interval"), "30m", "Run periodic GC every 30 minutes") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "leader-election.enabled"), *new(bool), "Enables/Disables leader election.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "leader-election.lock-config-map.Namespace"), *new(string), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "leader-election.lock-config-map.Name"), *new(string), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "leader-election.lease-duration"), "15s", "Duration that non-leader candidates will wait to force acquire leadership. This is measured against time of last observed ack.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "leader-election.renew-deadline"), "10s", "Duration that the acting master will retry refreshing leadership before giving up.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "leader-election.retry-period"), "2s", "Duration the LeaderElector clients should wait between tries of actions.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "publish-k8s-events"), *new(bool), "Enable events publishing to K8s events API.") + return cmdFlags +} diff --git a/pkg/controller/config/config_flags_test.go b/pkg/controller/config/config_flags_test.go new file mode 100755 index 000000000..b7f3ff4a8 --- /dev/null +++ b/pkg/controller/config/config_flags_test.go @@ -0,0 +1,828 @@ +// Code generated by go generate; DO NOT EDIT. 
+// This file was generated by robots. + +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_kube-config", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("kube-config"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("kube-config", testValue) + if vString, err := cmdFlags.GetString("kube-config"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.KubeConfigPath) + + } else { + assert.FailNow(t, 
err.Error()) + } + }) + }) + t.Run("Test_master", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("master"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("master", testValue) + if vString, err := cmdFlags.GetString("master"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.MasterURL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_workers", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("workers"); err == nil { + assert.Equal(t, int(2), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("workers", testValue) + if vInt, err := cmdFlags.GetInt("workers"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Workers) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_workflow-reeval-duration", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("workflow-reeval-duration"); err == nil { + assert.Equal(t, string("30s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "30s" + + cmdFlags.Set("workflow-reeval-duration", testValue) + if vString, err := cmdFlags.GetString("workflow-reeval-duration"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.WorkflowReEval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_downstream-eval-duration", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value 
is set properly + if vString, err := cmdFlags.GetString("downstream-eval-duration"); err == nil { + assert.Equal(t, string("60s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "60s" + + cmdFlags.Set("downstream-eval-duration", testValue) + if vString, err := cmdFlags.GetString("downstream-eval-duration"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DownstreamEval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_limit-namespace", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("limit-namespace"); err == nil { + assert.Equal(t, string("all"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("limit-namespace", testValue) + if vString, err := cmdFlags.GetString("limit-namespace"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LimitNamespace) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_prof-port", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("prof-port"); err == nil { + assert.Equal(t, string("10254"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "10254" + + cmdFlags.Set("prof-port", testValue) + if vString, err := cmdFlags.GetString("prof-port"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ProfilerPort) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_metadata-prefix", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("metadata-prefix"); err == nil { + 
assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("metadata-prefix", testValue) + if vString, err := cmdFlags.GetString("metadata-prefix"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.MetadataPrefix) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.type"); err == nil { + assert.Equal(t, string("simple"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.type", testValue) + if vString, err := cmdFlags.GetString("queue.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.queue.type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.queue.type"); err == nil { + assert.Equal(t, string("default"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.queue.type", testValue) + if vString, err := cmdFlags.GetString("queue.queue.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Queue.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.queue.base-delay", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.queue.base-delay"); err == nil { + assert.Equal(t, string("10s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) 
+ + t.Run("Override", func(t *testing.T) { + testValue := "10s" + + cmdFlags.Set("queue.queue.base-delay", testValue) + if vString, err := cmdFlags.GetString("queue.queue.base-delay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Queue.BaseDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.queue.max-delay", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.queue.max-delay"); err == nil { + assert.Equal(t, string("10s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "10s" + + cmdFlags.Set("queue.queue.max-delay", testValue) + if vString, err := cmdFlags.GetString("queue.queue.max-delay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Queue.MaxDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.queue.rate", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt64, err := cmdFlags.GetInt64("queue.queue.rate"); err == nil { + assert.Equal(t, int64(int64(10)), vInt64) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.queue.rate", testValue) + if vInt64, err := cmdFlags.GetInt64("queue.queue.rate"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.Queue.Queue.Rate) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.queue.capacity", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("queue.queue.capacity"); err == nil { + assert.Equal(t, int(100), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + 
testValue := "1" + + cmdFlags.Set("queue.queue.capacity", testValue) + if vInt, err := cmdFlags.GetInt("queue.queue.capacity"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Queue.Queue.Capacity) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.sub-queue.type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.sub-queue.type"); err == nil { + assert.Equal(t, string("default"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.sub-queue.type", testValue) + if vString, err := cmdFlags.GetString("queue.sub-queue.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Sub.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.sub-queue.base-delay", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.sub-queue.base-delay"); err == nil { + assert.Equal(t, string("10s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "10s" + + cmdFlags.Set("queue.sub-queue.base-delay", testValue) + if vString, err := cmdFlags.GetString("queue.sub-queue.base-delay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Sub.BaseDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.sub-queue.max-delay", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.sub-queue.max-delay"); err == nil { + assert.Equal(t, string("10s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + 
testValue := "10s" + + cmdFlags.Set("queue.sub-queue.max-delay", testValue) + if vString, err := cmdFlags.GetString("queue.sub-queue.max-delay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Sub.MaxDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.sub-queue.rate", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt64, err := cmdFlags.GetInt64("queue.sub-queue.rate"); err == nil { + assert.Equal(t, int64(int64(10)), vInt64) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.sub-queue.rate", testValue) + if vInt64, err := cmdFlags.GetInt64("queue.sub-queue.rate"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.Queue.Sub.Rate) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.sub-queue.capacity", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("queue.sub-queue.capacity"); err == nil { + assert.Equal(t, int(100), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.sub-queue.capacity", testValue) + if vInt, err := cmdFlags.GetInt("queue.sub-queue.capacity"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Queue.Sub.Capacity) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.batching-interval", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.batching-interval"); err == nil { + assert.Equal(t, string("1s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1s" + + 
cmdFlags.Set("queue.batching-interval", testValue) + if vString, err := cmdFlags.GetString("queue.batching-interval"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.BatchingInterval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.batch-size", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("queue.batch-size"); err == nil { + assert.Equal(t, int(-1), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.batch-size", testValue) + if vInt, err := cmdFlags.GetInt("queue.batch-size"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Queue.BatchSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_metrics-prefix", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("metrics-prefix"); err == nil { + assert.Equal(t, string("flyte:"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("metrics-prefix", testValue) + if vString, err := cmdFlags.GetString("metrics-prefix"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.MetricsPrefix) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_enable-admin-launcher", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("enable-admin-launcher"); err == nil { + assert.Equal(t, bool(false), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("enable-admin-launcher", testValue) + if vBool, err := 
cmdFlags.GetBool("enable-admin-launcher"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.EnableAdminLauncher) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_max-workflow-retries", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("max-workflow-retries"); err == nil { + assert.Equal(t, int(50), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("max-workflow-retries", testValue) + if vInt, err := cmdFlags.GetInt("max-workflow-retries"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.MaxWorkflowRetries) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_max-ttl-hours", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("max-ttl-hours"); err == nil { + assert.Equal(t, int(23), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("max-ttl-hours", testValue) + if vInt, err := cmdFlags.GetInt("max-ttl-hours"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.MaxTTLInHours) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_gc-interval", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("gc-interval"); err == nil { + assert.Equal(t, string("30m"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "30m" + + cmdFlags.Set("gc-interval", testValue) + if vString, err := cmdFlags.GetString("gc-interval"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.GCInterval) + + } else { + 
assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.enabled", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("leader-election.enabled"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("leader-election.enabled", testValue) + if vBool, err := cmdFlags.GetBool("leader-election.enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.LeaderElection.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.lock-config-map.Namespace", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("leader-election.lock-config-map.Namespace"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("leader-election.lock-config-map.Namespace", testValue) + if vString, err := cmdFlags.GetString("leader-election.lock-config-map.Namespace"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LeaderElection.LockConfigMap.Namespace) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.lock-config-map.Name", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("leader-election.lock-config-map.Name"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("leader-election.lock-config-map.Name", testValue) + if vString, err := 
cmdFlags.GetString("leader-election.lock-config-map.Name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LeaderElection.LockConfigMap.Name) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.lease-duration", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("leader-election.lease-duration"); err == nil { + assert.Equal(t, string("15s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "15s" + + cmdFlags.Set("leader-election.lease-duration", testValue) + if vString, err := cmdFlags.GetString("leader-election.lease-duration"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LeaderElection.LeaseDuration) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.renew-deadline", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("leader-election.renew-deadline"); err == nil { + assert.Equal(t, string("10s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "10s" + + cmdFlags.Set("leader-election.renew-deadline", testValue) + if vString, err := cmdFlags.GetString("leader-election.renew-deadline"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LeaderElection.RenewDeadline) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.retry-period", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("leader-election.retry-period"); err == nil { + assert.Equal(t, string("2s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + 
t.Run("Override", func(t *testing.T) { + testValue := "2s" + + cmdFlags.Set("leader-election.retry-period", testValue) + if vString, err := cmdFlags.GetString("leader-election.retry-period"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LeaderElection.RetryPeriod) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_publish-k8s-events", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("publish-k8s-events"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("publish-k8s-events", testValue) + if vBool, err := cmdFlags.GetBool("publish-k8s-events"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.PublishK8sEvents) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go new file mode 100644 index 000000000..25d8fb431 --- /dev/null +++ b/pkg/controller/controller.go @@ -0,0 +1,305 @@ +package controller + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/controller/executors" + + "github.com/lyft/flytepropeller/pkg/controller/config" + "github.com/lyft/flytepropeller/pkg/controller/workflowstore" + + "github.com/lyft/flyteidl/clients/go/admin" + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/clock" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/scheme" + typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/tools/leaderelection" + 
"k8s.io/client-go/tools/record" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + clientset "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + flyteScheme "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/scheme" + informers "github.com/lyft/flytepropeller/pkg/client/informers/externalversions" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/nodes" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/workflow" +) + +type metrics struct { + Scope promutils.Scope + EnqueueCountWf prometheus.Counter + EnqueueCountTask prometheus.Counter +} + +// Controller is the controller implementation for FlyteWorkflow resources +type Controller struct { + workerPool *WorkerPool + flyteworkflowSynced cache.InformerSynced + workQueue CompositeWorkQueue + gc *GarbageCollector + numWorkers int + workflowStore workflowstore.FlyteWorkflow + // recorder is an event recorder for recording Event resources to the + // Kubernetes API. + recorder record.EventRecorder + metrics *metrics + leaderElector *leaderelection.LeaderElector +} + +// Runs either as a leader -if configured- or as a standalone process. +func (c *Controller) Run(ctx context.Context) error { + if c.leaderElector == nil { + logger.Infof(ctx, "Running without leader election.") + return c.run(ctx) + } + + logger.Infof(ctx, "Attempting to acquire leader lease and act as leader.") + go c.leaderElector.Run(ctx) + <-ctx.Done() + return nil +} + +// Start the actual work of controller (e.g. GC, consume and process queue items... etc.) 
+func (c *Controller) run(ctx context.Context) error { + // Initializing WorkerPool + logger.Info(ctx, "Initializing controller") + if err := c.workerPool.Initialize(ctx); err != nil { + return err + } + + // Start the GC + if err := c.gc.StartGC(ctx); err != nil { + logger.Errorf(ctx, "failed to start background GC") + return err + } + + // Start the informer factories to begin populating the informer caches + logger.Info(ctx, "Starting FlyteWorkflow controller") + return c.workerPool.Run(ctx, c.numWorkers, c.flyteworkflowSynced) +} + +// Called from leader elector -if configured- to start running as the leader. +func (c *Controller) onStartedLeading(ctx context.Context) { + ctx, cancelNow := context.WithCancel(context.Background()) + logger.Infof(ctx, "Acquired leader lease.") + go func() { + if err := c.run(ctx); err != nil { + logger.Panic(ctx, err) + } + }() + + <-ctx.Done() + logger.Infof(ctx, "Lost leader lease.") + cancelNow() +} + +// enqueueFlyteWorkflow takes a FlyteWorkflow resource and converts it into a namespace/name +// string which is then put onto the work queue. This method should *not* be +// passed resources of any type other than FlyteWorkflow. 
+func (c *Controller) enqueueFlyteWorkflow(obj interface{}) { + ctx := context.TODO() + wf, ok := obj.(*v1alpha1.FlyteWorkflow) + if !ok { + logger.Errorf(ctx, "Received a non Workflow object") + return + } + key := wf.GetK8sWorkflowID() + logger.Infof(ctx, "==> Enqueueing workflow [%v]", key) + c.workQueue.Add(key.String()) +} + +func (c *Controller) enqueueWorkflowForNodeUpdates(wID v1alpha1.WorkflowID) { + if wID == "" { + return + } + namespace, name, err := cache.SplitMetaNamespaceKey(wID) + if err != nil { + if _, err2 := c.workflowStore.Get(context.TODO(), namespace, name); err2 != nil { + if workflowstore.IsNotFound(err) { + // Workflow is not found in storage, was probably deleted, but one of the sub-objects sent an event + return + } + } + c.metrics.EnqueueCountTask.Inc() + c.workQueue.AddToSubQueue(wID) + } +} + +func (c *Controller) getWorkflowUpdatesHandler() cache.ResourceEventHandler { + return cache.ResourceEventHandlerFuncs{ + AddFunc: c.enqueueFlyteWorkflow, + UpdateFunc: func(old, new interface{}) { + // TODO we might need to handle updates to the workflow itself. + // Initially maybe we should not support it at all + c.enqueueFlyteWorkflow(new) + }, + DeleteFunc: func(obj interface{}) { + // There is a corner case where the obj is not in fact a valid resource (it sends a DeletedFinalStateUnknown + // object instead) -it has to do with missing some event that leads to not knowing the final state of the + // resource. In which case, we can't use the regular metaAccessor to read obj name/namespace but should + // instead use cache.DeletionHandling* helper functions that know how to deal with that. + + key, err := cache.DeletionHandlingMetaNamespaceKeyFunc(obj) + if err != nil { + logger.Errorf(context.TODO(), "Unable to get key for deleted obj. Error[%v]", err) + return + } + + _, name, err := cache.SplitMetaNamespaceKey(key) + if err != nil { + logger.Errorf(context.TODO(), "Unable to split enqueued key into namespace/execId. 
Error[%v]", err) + return + } + + logger.Infof(context.TODO(), "Deletion triggered for %v", name) + }, + } +} + +func newControllerMetrics(scope promutils.Scope) *metrics { + c := scope.MustNewCounterVec("wf_enqueue", "workflow enqueue count.", "type") + return &metrics{ + Scope: scope, + EnqueueCountWf: c.WithLabelValues("wf"), + EnqueueCountTask: c.WithLabelValues("task"), + } +} + +func newK8sEventRecorder(ctx context.Context, kubeclientset kubernetes.Interface, publishK8sEvents bool) record.EventRecorder { + // Create event broadcaster + // Add FlyteWorkflow controller types to the default Kubernetes Scheme so Events can be + // logged for FlyteWorkflow Controller types. + err := flyteScheme.AddToScheme(scheme.Scheme) + if err != nil { + logger.Panicf(ctx, "failed to add flyte workflows scheme, %s", err.Error()) + } + logger.Info(ctx, "Creating event broadcaster") + eventBroadcaster := record.NewBroadcaster() + eventBroadcaster.StartLogging(logger.InfofNoCtx) + if publishK8sEvents { + eventBroadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: kubeclientset.CoreV1().Events("")}) + } + return eventBroadcaster.NewRecorder(scheme.Scheme, corev1.EventSource{Component: controllerAgentName}) +} + +// NewController returns a new FlyteWorkflow controller +func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Interface, flytepropellerClientset clientset.Interface, + flyteworkflowInformerFactory informers.SharedInformerFactory, kubeClient executors.Client, scope promutils.Scope) (*Controller, error) { + + var wfLauncher launchplan.Executor + if cfg.EnableAdminLauncher { + adminClient, err := admin.InitializeAdminClientFromConfig(ctx) + if err != nil { + logger.Errorf(ctx, "failed to initialize Admin client, err :%s", err.Error()) + return nil, err + } + wfLauncher, err = launchplan.NewAdminLaunchPlanExecutor(ctx, adminClient, cfg.DownstreamEval.Duration, + launchplan.GetAdminConfig(), scope.NewSubScope("admin_launcher")) + if err != 
nil { + logger.Errorf(ctx, "failed to create Admin workflow Launcher, err: %v", err.Error()) + return nil, err + } + + if err := wfLauncher.Initialize(ctx); err != nil { + logger.Errorf(ctx, "failed to initialize Admin workflow Launcher, err: %v", err.Error()) + return nil, err + } + } else { + wfLauncher = launchplan.NewFailFastLaunchPlanExecutor() + } + + logger.Info(ctx, "Setting up event sink and recorder") + eventSink, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + if err != nil { + return nil, errors.Wrapf(err, "Failed to create EventSink [%v], error %v", events.GetConfig(ctx).Type, err) + } + gc, err := NewGarbageCollector(cfg, scope, clock.RealClock{}, kubeclientset.CoreV1().Namespaces(), flytepropellerClientset.FlyteworkflowV1alpha1(), cfg.LimitNamespace) + if err != nil { + logger.Errorf(ctx, "failed to initialize GC for workflows") + return nil, errors.Wrapf(err, "failed to initialize WF GC") + } + + eventRecorder := newK8sEventRecorder(ctx, kubeclientset, cfg.PublishK8sEvents) + controller := &Controller{ + metrics: newControllerMetrics(scope), + recorder: eventRecorder, + gc: gc, + numWorkers: cfg.Workers, + } + + lock, err := newResourceLock(kubeclientset.CoreV1(), eventRecorder, cfg.LeaderElection) + if err != nil { + logger.Errorf(ctx, "failed to initialize resource lock.") + return nil, errors.Wrapf(err, "failed to initialize resource lock.") + } + + if lock != nil { + logger.Infof(ctx, "Creating leader elector for the controller.") + controller.leaderElector, err = newLeaderElector(lock, cfg.LeaderElection, controller.onStartedLeading, func() { + logger.Fatal(ctx, "Lost leader state. Shutting down.") + }) + + if err != nil { + logger.Errorf(ctx, "failed to initialize leader elector.") + return nil, errors.Wrapf(err, "failed to initialize leader elector.") + } + } + + // WE are disabling this as the metrics have high cardinality. 
Metrics seem to be emitted per pod and this has problems + // when we create new pods + // Set Client Metrics Provider + // setClientMetricsProvider(scope.NewSubScope("k8s_client")) + + // obtain references to shared index informers for FlyteWorkflow. + flyteworkflowInformer := flyteworkflowInformerFactory.Flyteworkflow().V1alpha1().FlyteWorkflows() + controller.flyteworkflowSynced = flyteworkflowInformer.Informer().HasSynced + + sCfg := storage.GetConfig() + if sCfg == nil { + logger.Errorf(ctx, "Storage configuration missing.") + } + + store, err := storage.NewDataStore(sCfg, scope.NewSubScope("metastore")) + if err != nil { + return nil, errors.Wrapf(err, "Failed to create Metadata storage") + } + + logger.Info(ctx, "Setting up Catalog client.") + catalogClient := catalog.NewCatalogClient(store) + + workQ, err := NewCompositeWorkQueue(ctx, cfg.Queue, scope) + if err != nil { + return nil, errors.Wrapf(err, "Failed to create WorkQueue [%v]", scope.CurrentScope()) + } + controller.workQueue = workQ + + controller.workflowStore = workflowstore.NewPassthroughWorkflowStore(ctx, scope, flytepropellerClientset.FlyteworkflowV1alpha1(), flyteworkflowInformer.Lister()) + + nodeExecutor, err := nodes.NewExecutor(ctx, store, controller.enqueueWorkflowForNodeUpdates, + cfg.DownstreamEval.Duration, eventSink, wfLauncher, catalogClient, kubeClient, scope) + if err != nil { + return nil, errors.Wrapf(err, "Failed to create Controller.") + } + + workflowExecutor, err := workflow.NewExecutor(ctx, store, controller.enqueueWorkflowForNodeUpdates, eventSink, controller.recorder, cfg.MetadataPrefix, nodeExecutor, scope) + if err != nil { + return nil, err + } + + handler := NewPropellerHandler(ctx, cfg, controller.workflowStore, workflowExecutor, scope) + controller.workerPool = NewWorkerPool(ctx, scope, workQ, handler) + + logger.Info(ctx, "Setting up event handlers") + // Set up an event handler for when FlyteWorkflow resources change + 
flyteworkflowInformer.Informer().AddEventHandler(controller.getWorkflowUpdatesHandler()) + return controller, nil +} diff --git a/pkg/controller/executors/contextual.go b/pkg/controller/executors/contextual.go new file mode 100644 index 000000000..7d02a6f9a --- /dev/null +++ b/pkg/controller/executors/contextual.go @@ -0,0 +1,30 @@ +package executors + +import ( + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type ContextualWorkflow struct { + v1alpha1.WorkflowMetaExtended + v1alpha1.ExecutableSubWorkflow + v1alpha1.NodeStatusGetter +} + +func NewBaseContextualWorkflow(baseWorkflow v1alpha1.ExecutableWorkflow) v1alpha1.ExecutableWorkflow { + return &ContextualWorkflow{ + ExecutableSubWorkflow: baseWorkflow, + WorkflowMetaExtended: baseWorkflow, + NodeStatusGetter: baseWorkflow.GetExecutionStatus(), + } +} + +// Creates a contextual workflow using the provided interface implementations. +func NewSubContextualWorkflow(baseWorkflow v1alpha1.ExecutableWorkflow, subWF v1alpha1.ExecutableSubWorkflow, + nodeStatus v1alpha1.ExecutableNodeStatus) v1alpha1.ExecutableWorkflow { + + return &ContextualWorkflow{ + ExecutableSubWorkflow: subWF, + WorkflowMetaExtended: baseWorkflow, + NodeStatusGetter: nodeStatus, + } +} diff --git a/pkg/controller/executors/kube.go b/pkg/controller/executors/kube.go new file mode 100644 index 000000000..4c790c397 --- /dev/null +++ b/pkg/controller/executors/kube.go @@ -0,0 +1,56 @@ +package executors + +import ( + "context" + + "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +//go:generate mockery -name Client + +// A friendly controller-runtime client that gets passed to executors +type Client interface { + // GetClient returns a client configured with the Config + GetClient() client.Client + + // GetCache returns a cache.Cache + GetCache() cache.Cache +} + +type fallbackClientReader struct { + orderedClients []client.Client +} + +func (c 
fallbackClientReader) Get(ctx context.Context, key client.ObjectKey, out runtime.Object) (err error) { + for _, k8sClient := range c.orderedClients { + if err = k8sClient.Get(ctx, key, out); err == nil { + return nil + } + } + + return +} + +func (c fallbackClientReader) List(ctx context.Context, opts *client.ListOptions, list runtime.Object) (err error) { + for _, k8sClient := range c.orderedClients { + if err = k8sClient.List(ctx, opts, list); err == nil { + return nil + } + } + + return +} + +// Creates a new k8s client that uses the cached client for reads and falls back to making API +// calls if it failed. Write calls will always go to raw client directly. +func NewFallbackClient(cachedClient, rawClient client.Client) client.Client { + return client.DelegatingClient{ + Reader: fallbackClientReader{ + orderedClients: []client.Client{cachedClient, rawClient}, + }, + StatusClient: rawClient, + Writer: rawClient, + } +} diff --git a/pkg/controller/executors/mocks/Client.go b/pkg/controller/executors/mocks/Client.go new file mode 100644 index 000000000..bc7af4670 --- /dev/null +++ b/pkg/controller/executors/mocks/Client.go @@ -0,0 +1,45 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import cache "sigs.k8s.io/controller-runtime/pkg/cache" +import client "sigs.k8s.io/controller-runtime/pkg/client" + +import mock "github.com/stretchr/testify/mock" + +// Client is an autogenerated mock type for the Client type +type Client struct { + mock.Mock +} + +// GetCache provides a mock function with given fields: +func (_m *Client) GetCache() cache.Cache { + ret := _m.Called() + + var r0 cache.Cache + if rf, ok := ret.Get(0).(func() cache.Cache); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(cache.Cache) + } + } + + return r0 +} + +// GetClient provides a mock function with given fields: +func (_m *Client) GetClient() client.Client { + ret := _m.Called() + + var r0 client.Client + if rf, ok := ret.Get(0).(func() client.Client); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.Client) + } + } + + return r0 +} diff --git a/pkg/controller/executors/mocks/fake.go b/pkg/controller/executors/mocks/fake.go new file mode 100644 index 000000000..27fb94060 --- /dev/null +++ b/pkg/controller/executors/mocks/fake.go @@ -0,0 +1,13 @@ +package mocks + +import ( + "sigs.k8s.io/controller-runtime/pkg/cache/informertest" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func NewFakeKubeClient() *Client { + c := Client{} + c.On("GetClient").Return(fake.NewFakeClient()) + c.On("GetCache").Return(&informertest.FakeInformers{}) + return &c +} diff --git a/pkg/controller/executors/node.go b/pkg/controller/executors/node.go new file mode 100644 index 000000000..3dd1cb270 --- /dev/null +++ b/pkg/controller/executors/node.go @@ -0,0 +1,100 @@ +package executors + +import ( + "context" + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// Phase of the node +type NodePhase int + +const ( + // Indicates that the node is not yet ready to be executed and is pending any previous nodes completion + NodePhasePending 
NodePhase = iota + // Indicates that the node was queued and will start running soon + NodePhaseQueued + // Indicates that the payload associated with this node is being executed and is not yet done + NodePhaseRunning + // Indicates that the nodes payload has been successfully completed, but any downstream nodes from this node may not yet have completed + // We could make Success = running, but this enables more granular control + NodePhaseSuccess + // Complete indicates successful completion of a node. For singular nodes (nodes that have only one execution) success = complete, but, the executor + // will always signal completion + NodePhaseComplete + // Node failed in execution, either this node or anything in the downstream chain + NodePhaseFailed + // Internal error observed. This state should always be accompanied with an `error`. if not the behavior is undefined + NodePhaseUndefined +) + +func (p NodePhase) String() string { + switch p { + case NodePhaseRunning: + return "Running" + case NodePhaseQueued: + return "Queued" + case NodePhasePending: + return "Pending" + case NodePhaseFailed: + return "Failed" + case NodePhaseSuccess: + return "Success" + case NodePhaseComplete: + return "Complete" + case NodePhaseUndefined: + return "Undefined" + } + return fmt.Sprintf("Unknown - %d", p) +} + +// Core Node Executor that is used to execute a node. This is a recursive node executor and understands node dependencies +type Node interface { + // This method is used specifically to set inputs for start node. This is because start node does not retrieve inputs + // from predecessors, but the inputs are inputs to the workflow or inputs to the parent container (workflow) node. + SetInputsForStartNode(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, inputs *core.LiteralMap) (NodeStatus, error) + + // This is the main entrypoint to execute a node. It recursively depth-first goes through all ready nodes and starts their execution + // This returns either + // - 1. 
It finds a blocking node (not ready, or running) + // - 2. A node fails and hence the workflow will fail + // - 3. The final/end node has completed and the workflow should be stopped + RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) + + // This aborts the given node. If the given node is complete then it recursively finds the running nodes and aborts them + AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error + + // This method should be used to initialize Node executor + Initialize(ctx context.Context) error +} + +// Helper struct to allow passing of status between functions +type NodeStatus struct { + NodePhase NodePhase + Err error +} + +func (n *NodeStatus) IsComplete() bool { + return n.NodePhase == NodePhaseComplete +} + +func (n *NodeStatus) HasFailed() bool { + return n.NodePhase == NodePhaseFailed +} + +func (n *NodeStatus) PartiallyComplete() bool { + return n.NodePhase == NodePhaseSuccess +} + +var NodeStatusPending = NodeStatus{NodePhase: NodePhasePending} +var NodeStatusQueued = NodeStatus{NodePhase: NodePhaseQueued} +var NodeStatusRunning = NodeStatus{NodePhase: NodePhaseRunning} +var NodeStatusSuccess = NodeStatus{NodePhase: NodePhaseSuccess} +var NodeStatusComplete = NodeStatus{NodePhase: NodePhaseComplete} +var NodeStatusUndefined = NodeStatus{NodePhase: NodePhaseUndefined} + +func NodeStatusFailed(err error) NodeStatus { + return NodeStatus{NodePhase: NodePhaseFailed, Err: err} +} diff --git a/pkg/controller/executors/workflow.go b/pkg/controller/executors/workflow.go new file mode 100644 index 000000000..31db1cab0 --- /dev/null +++ b/pkg/controller/executors/workflow.go @@ -0,0 +1,13 @@ +package executors + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type Workflow interface { + Initialize(ctx context.Context) error + HandleFlyteWorkflow(ctx context.Context, w 
*v1alpha1.FlyteWorkflow) error + HandleAbortedWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error +} diff --git a/pkg/controller/finalizer.go b/pkg/controller/finalizer.go new file mode 100644 index 000000000..f1a8ba8eb --- /dev/null +++ b/pkg/controller/finalizer.go @@ -0,0 +1,36 @@ +package controller + +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + +const FinalizerKey = "flyte-finalizer" + +// NOTE: Some of these APIs are exclusive and do not compare the actual values of the finalizers. +// the intention of this module is to set only one opaque finalizer at a time. If you want to set multiple (not common) +// finalizers, use this module carefully and at your own risk! + +// Sets a new finalizer in case the finalizer is empty +func SetFinalizerIfEmpty(meta v1.Object, finalizer string) { + if !HasFinalizer(meta) { + meta.SetFinalizers([]string{finalizer}) + } +} + +// Check if the deletion timestamp is set, this is set automatically when an object is deleted +func IsDeleted(meta v1.Object) bool { + return meta.GetDeletionTimestamp() != nil +} + +// Reset all the finalizers on the object +func ResetFinalizers(meta v1.Object) { + meta.SetFinalizers([]string{}) +} + +// Currently we only compare the lengths of finalizers. 
If you add finalizers directly these APIs will not work
+func FinalizersIdentical(o1 v1.Object, o2 v1.Object) bool {
+	return len(o1.GetFinalizers()) == len(o2.GetFinalizers())
+}
+
+// Check if any finalizer is set
+func HasFinalizer(meta v1.Object) bool {
+	return len(meta.GetFinalizers()) != 0
+}
diff --git a/pkg/controller/finalizer_test.go b/pkg/controller/finalizer_test.go
new file mode 100644
index 000000000..05401806d
--- /dev/null
+++ b/pkg/controller/finalizer_test.go
@@ -0,0 +1,70 @@
+package controller
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	v1 "k8s.io/api/batch/v1"
+	v12 "k8s.io/apimachinery/pkg/apis/meta/v1"
+)
+
+// Exercises the length-only comparison semantics of FinalizersIdentical,
+// including the documented quirk that different finalizer values of equal
+// count compare as identical.
+func TestFinalizersIdentical(t *testing.T) {
+	noFinalizer := &v1.Job{}
+	withFinalizer := &v1.Job{}
+	withFinalizer.SetFinalizers([]string{"t1"})
+
+	assert.True(t, FinalizersIdentical(noFinalizer, noFinalizer))
+	assert.True(t, FinalizersIdentical(withFinalizer, withFinalizer))
+	assert.False(t, FinalizersIdentical(noFinalizer, withFinalizer))
+	withMultipleFinalizers := &v1.Job{}
+	withMultipleFinalizers.SetFinalizers([]string{"f1", "f2"})
+	assert.False(t, FinalizersIdentical(withMultipleFinalizers, withFinalizer))
+
+	withDiffFinalizer := &v1.Job{}
+	withDiffFinalizer.SetFinalizers([]string{"f1"})
+	assert.True(t, FinalizersIdentical(withFinalizer, withDiffFinalizer))
+}
+
+// IsDeleted is driven purely by the deletion timestamp being set.
+func TestIsDeleted(t *testing.T) {
+	noTermTS := &v1.Job{}
+	termedTS := &v1.Job{}
+	n := v12.Now()
+	termedTS.SetDeletionTimestamp(&n)
+
+	assert.True(t, IsDeleted(termedTS))
+	assert.False(t, IsDeleted(noTermTS))
+}
+
+func TestHasFinalizer(t *testing.T) {
+	noFinalizer := &v1.Job{}
+	withFinalizer := &v1.Job{}
+	withFinalizer.SetFinalizers([]string{"t1"})
+
+	assert.False(t, HasFinalizer(noFinalizer))
+	assert.True(t, HasFinalizer(withFinalizer))
+}
+
+// SetFinalizerIfEmpty must not overwrite an already-present finalizer.
+func TestSetFinalizerIfEmpty(t *testing.T) {
+	noFinalizer := &v1.Job{}
+	withFinalizer := &v1.Job{}
+	withFinalizer.SetFinalizers([]string{"t1"})
+
+	assert.False(t, 
HasFinalizer(noFinalizer)) + SetFinalizerIfEmpty(noFinalizer, "f1") + assert.True(t, HasFinalizer(noFinalizer)) + assert.Equal(t, []string{"f1"}, noFinalizer.GetFinalizers()) + + SetFinalizerIfEmpty(withFinalizer, "f1") + assert.Equal(t, []string{"t1"}, withFinalizer.GetFinalizers()) +} + +func TestResetFinalizer(t *testing.T) { + noFinalizer := &v1.Job{} + ResetFinalizers(noFinalizer) + assert.Equal(t, []string{}, noFinalizer.GetFinalizers()) + + withFinalizer := &v1.Job{} + withFinalizer.SetFinalizers([]string{"t1"}) + ResetFinalizers(withFinalizer) + assert.Equal(t, []string{}, withFinalizer.GetFinalizers()) +} diff --git a/pkg/controller/garbage_collector.go b/pkg/controller/garbage_collector.go new file mode 100644 index 000000000..2361ff701 --- /dev/null +++ b/pkg/controller/garbage_collector.go @@ -0,0 +1,145 @@ +package controller + +import ( + "context" + "runtime/pprof" + "time" + + "github.com/lyft/flytepropeller/pkg/controller/config" + + "strings" + + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/clock" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" +) + +type gcMetrics struct { + gcRoundSuccess labeled.Counter + gcRoundFailure labeled.Counter + gcTime labeled.StopWatch +} + +// Garbage collector is an active background cleanup service, that deletes all workflows that are completed and older +// than the configured TTL +type GarbageCollector struct { + wfClient v1alpha1.FlyteworkflowV1alpha1Interface + namespaceClient corev1.NamespaceInterface + ttlHours int + interval time.Duration + clk clock.Clock + metrics *gcMetrics + namespace string +} + +// Issues a background deletion command with label selector for all completed workflows outside of the 
retention period +func (g *GarbageCollector) deleteWorkflows(ctx context.Context) error { + + s := CompletedWorkflowsSelectorOutsideRetentionPeriod(g.ttlHours-1, g.clk.Now()) + + // Delete doesn't support 'all' namespaces. Let's fetch namespaces and loop over each. + if g.namespace == "" || strings.ToLower(g.namespace) == "all" || strings.ToLower(g.namespace) == "all-namespaces" { + namespaceList, err := g.namespaceClient.List(v1.ListOptions{}) + if err != nil { + return err + } + for _, n := range namespaceList.Items { + namespaceCtx := contextutils.WithNamespace(ctx, n.GetName()) + logger.Infof(namespaceCtx, "Triggering Workflow delete for namespace: [%s]", n.GetName()) + + if err := g.deleteWorkflowsForNamespace(n.GetName(), s); err != nil { + g.metrics.gcRoundFailure.Inc(namespaceCtx) + logger.Errorf(namespaceCtx, "Garbage collection failed for for namespace: [%s]. Error : [%v]", n.GetName(), err) + } else { + g.metrics.gcRoundSuccess.Inc(namespaceCtx) + } + } + } else { + namespaceCtx := contextutils.WithNamespace(ctx, g.namespace) + logger.Infof(namespaceCtx, "Triggering Workflow delete for namespace: [%s]", g.namespace) + if err := g.deleteWorkflowsForNamespace(g.namespace, s); err != nil { + g.metrics.gcRoundFailure.Inc(namespaceCtx) + logger.Errorf(namespaceCtx, "Garbage collection failed for for namespace: [%s]. 
Error : [%v]", g.namespace, err) + } else { + g.metrics.gcRoundSuccess.Inc(namespaceCtx) + } + } + return nil +} + +func (g *GarbageCollector) deleteWorkflowsForNamespace(namespace string, labelSelector *v1.LabelSelector) error { + gracePeriodZero := int64(0) + propagation := v1.DeletePropagationBackground + + return g.wfClient.FlyteWorkflows(namespace).DeleteCollection( + &v1.DeleteOptions{ + GracePeriodSeconds: &gracePeriodZero, + PropagationPolicy: &propagation, + }, + v1.ListOptions{ + LabelSelector: v1.FormatLabelSelector(labelSelector), + }, + ) +} + +// A periodic GC running +func (g *GarbageCollector) runGC(ctx context.Context, ticker clock.Ticker) { + logger.Infof(ctx, "Background workflow garbage collection started, with duration [%s], TTL [%d] hours", g.interval.String(), g.ttlHours) + + ctx = contextutils.WithGoroutineLabel(ctx, "gc-worker") + pprof.SetGoroutineLabels(ctx) + defer ticker.Stop() + for { + select { + case <-ticker.C(): + logger.Infof(ctx, "Garbage collector running...") + t := g.metrics.gcTime.Start(ctx) + if err := g.deleteWorkflows(ctx); err != nil { + logger.Errorf(ctx, "Garbage collection failed in this round.Error : [%v]", err) + } + t.Stop() + case <-ctx.Done(): + logger.Infof(ctx, "Garbage collector stopping") + return + + } + } +} + +// Use this method to start a background garbage collection routine. 
Use the context to signal an exit signal +func (g *GarbageCollector) StartGC(ctx context.Context) error { + if g.ttlHours <= 0 { + logger.Warningf(ctx, "Garbage collector is disabled, as ttl [%d] is <=0", g.ttlHours) + return nil + } + ticker := g.clk.NewTicker(g.interval) + go g.runGC(ctx, ticker) + return nil +} + +func NewGarbageCollector(cfg *config.Config, scope promutils.Scope, clk clock.Clock, namespaceClient corev1.NamespaceInterface, wfClient v1alpha1.FlyteworkflowV1alpha1Interface, namespace string) (*GarbageCollector, error) { + ttl := 23 + if cfg.MaxTTLInHours < 23 { + ttl = cfg.MaxTTLInHours + } else { + logger.Warningf(context.TODO(), "defaulting max ttl for workflows to 23 hours, since configured duration is larger than 23 [%d]", cfg.MaxTTLInHours) + } + return &GarbageCollector{ + wfClient: wfClient, + ttlHours: ttl, + interval: cfg.GCInterval.Duration, + namespaceClient: namespaceClient, + metrics: &gcMetrics{ + gcTime: labeled.NewStopWatch("gc_latency", "time taken to issue a delete for TTL'ed workflows", time.Millisecond, scope), + gcRoundSuccess: labeled.NewCounter("gc_success", "successful executions of delete request", scope), + gcRoundFailure: labeled.NewCounter("gc_failure", "failure to delete workflows", scope), + }, + clk: clk, + namespace: namespace, + }, nil +} diff --git a/pkg/controller/garbage_collector_test.go b/pkg/controller/garbage_collector_test.go new file mode 100644 index 000000000..8874af55c --- /dev/null +++ b/pkg/controller/garbage_collector_test.go @@ -0,0 +1,166 @@ +package controller + +import ( + "context" + "sync" + "testing" + "time" + + config2 "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + corev1Types "k8s.io/api/core/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + 
"k8s.io/apimachinery/pkg/util/clock" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" +) + +func TestNewGarbageCollector(t *testing.T) { + t.Run("enabled", func(t *testing.T) { + cfg := &config2.Config{ + GCInterval: config.Duration{Duration: time.Minute * 30}, + MaxTTLInHours: 2, + } + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), clock.NewFakeClock(time.Now()), nil, nil, "flyte") + assert.NoError(t, err) + assert.Equal(t, 2, gc.ttlHours) + }) + + t.Run("enabledBeyond23Hours", func(t *testing.T) { + cfg := &config2.Config{ + GCInterval: config.Duration{Duration: time.Minute * 30}, + MaxTTLInHours: 24, + } + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), clock.NewFakeClock(time.Now()), nil, nil, "flyte") + assert.NoError(t, err) + assert.Equal(t, 23, gc.ttlHours) + }) + + t.Run("ttl0", func(t *testing.T) { + cfg := &config2.Config{ + GCInterval: config.Duration{Duration: time.Minute * 30}, + MaxTTLInHours: 0, + } + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), nil, nil, nil, "flyte") + assert.NoError(t, err) + assert.Equal(t, 0, gc.ttlHours) + assert.NoError(t, gc.StartGC(context.TODO())) + + }) + + t.Run("ttl-1", func(t *testing.T) { + cfg := &config2.Config{ + GCInterval: config.Duration{Duration: time.Minute * 30}, + MaxTTLInHours: -1, + } + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), nil, nil, nil, "flyte") + assert.NoError(t, err) + assert.Equal(t, -1, gc.ttlHours) + assert.NoError(t, gc.StartGC(context.TODO())) + }) +} + +type mockWfClient struct { + v1alpha1.FlyteWorkflowInterface + DeleteCollectionCb func(options *v1.DeleteOptions, listOptions v1.ListOptions) error +} + +func (m *mockWfClient) DeleteCollection(options *v1.DeleteOptions, listOptions v1.ListOptions) error { + return m.DeleteCollectionCb(options, listOptions) +} + +type mockClient struct { + v1alpha1.FlyteworkflowV1alpha1Client + FlyteWorkflowsCb func(namespace string) v1alpha1.FlyteWorkflowInterface +} + +func (m 
*mockClient) FlyteWorkflows(namespace string) v1alpha1.FlyteWorkflowInterface { + return m.FlyteWorkflowsCb(namespace) +} + +type mockNamespaceClient struct { + corev1.NamespaceInterface + ListCb func(opts v1.ListOptions) (*corev1Types.NamespaceList, error) +} + +func (m *mockNamespaceClient) List(opts v1.ListOptions) (*corev1Types.NamespaceList, error) { + return m.ListCb(opts) +} + +func TestGarbageCollector_StartGC(t *testing.T) { + wg := sync.WaitGroup{} + b := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + mockWfClient := &mockWfClient{ + DeleteCollectionCb: func(options *v1.DeleteOptions, listOptions v1.ListOptions) error { + assert.NotNil(t, options) + assert.NotNil(t, listOptions) + assert.Equal(t, "hour-of-day in (0,1,10,11,12,13,14,15,16,17,18,19,2,20,21,3,4,5,6,7,8,9),termination-status=terminated", listOptions.LabelSelector) + wg.Done() + return nil + }, + } + + mockClient := &mockClient{ + FlyteWorkflowsCb: func(namespace string) v1alpha1.FlyteWorkflowInterface { + return mockWfClient + }, + } + + mockNamespaceInvoked := false + mockNamespaceClient := &mockNamespaceClient{ + ListCb: func(opts v1.ListOptions) (*corev1Types.NamespaceList, error) { + mockNamespaceInvoked = true + return &corev1Types.NamespaceList{ + Items: []corev1Types.Namespace{ + { + ObjectMeta: v1.ObjectMeta{ + Name: "ns1", + }, + }, + { + ObjectMeta: v1.ObjectMeta{ + Name: "ns2", + }, + }, + }, + }, nil + }, + } + cfg := &config2.Config{ + GCInterval: config.Duration{Duration: time.Minute * 30}, + MaxTTLInHours: 2, + } + + t.Run("one-namespace", func(t *testing.T) { + fakeClock := clock.NewFakeClock(b) + mockNamespaceInvoked = false + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), fakeClock, mockNamespaceClient, mockClient, "flyte") + assert.NoError(t, err) + wg.Add(1) + ctx := context.TODO() + ctx, cancel := context.WithCancel(ctx) + assert.NoError(t, gc.StartGC(ctx)) + fakeClock.Step(time.Minute * 30) + wg.Wait() + cancel() + assert.False(t, 
mockNamespaceInvoked) + }) + + t.Run("all-namespace", func(t *testing.T) { + fakeClock := clock.NewFakeClock(b) + mockNamespaceInvoked = false + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), fakeClock, mockNamespaceClient, mockClient, "all") + assert.NoError(t, err) + wg.Add(2) + ctx := context.TODO() + ctx, cancel := context.WithCancel(ctx) + assert.NoError(t, gc.StartGC(ctx)) + fakeClock.Step(time.Minute * 30) + wg.Wait() + cancel() + assert.True(t, mockNamespaceInvoked) + }) +} diff --git a/pkg/controller/handler.go b/pkg/controller/handler.go new file mode 100644 index 000000000..798d7b6c0 --- /dev/null +++ b/pkg/controller/handler.go @@ -0,0 +1,188 @@ +package controller + +import ( + "context" + "fmt" + "runtime/debug" + "time" + + "github.com/lyft/flytepropeller/pkg/controller/config" + "github.com/lyft/flytepropeller/pkg/controller/workflowstore" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" + + "github.com/lyft/flytepropeller/pkg/controller/executors" +) + +// TODO Lets move everything to use controller runtime + +type propellerMetrics struct { + Scope promutils.Scope + DeepCopyTime promutils.StopWatch + RawWorkflowTraversalTime promutils.StopWatch + SystemError prometheus.Counter + AbortError prometheus.Counter + PanicObserved prometheus.Counter + RoundSkipped prometheus.Counter + WorkflowNotFound prometheus.Counter +} + +func newPropellerMetrics(scope promutils.Scope) *propellerMetrics { + roundScope := scope.NewSubScope("round") + return &propellerMetrics{ + Scope: scope, + DeepCopyTime: roundScope.MustNewStopWatch("deepcopy", "Total time to deep copy wf object", time.Millisecond), + RawWorkflowTraversalTime: roundScope.MustNewStopWatch("raw", "Total time to traverse the workflow", time.Millisecond), + SystemError: roundScope.MustNewCounter("system_error", "Failure to reconcile a workflow, system error"), + 
AbortError: roundScope.MustNewCounter("abort_error", "Failure to abort a workflow, system error"), + PanicObserved: roundScope.MustNewCounter("panic", "Panic during handling or aborting workflow"), + RoundSkipped: roundScope.MustNewCounter("skipped", "Round Skipped because of stale workflow"), + WorkflowNotFound: roundScope.MustNewCounter("not_found", "workflow not found in the cache"), + } +} + +type Propeller struct { + wfStore workflowstore.FlyteWorkflow + workflowExecutor executors.Workflow + metrics *propellerMetrics + cfg *config.Config +} + +func (p *Propeller) Initialize(ctx context.Context) error { + return p.workflowExecutor.Initialize(ctx) +} + +// reconciler compares the actual state with the desired, and attempts to +// converge the two. It then updates the GetExecutionStatus block of the FlyteWorkflow resource +// with the current status of the resource. +// Every FlyteWorkflow transitions through the following +// +// The Workflow to be worked on is identified for the given namespace and executionID (which is the name of the workflow) +// The return value should be an error, in the case, we wish to retry this workflow +//
+//
+//     +--------+        +--------+        +--------+     +--------+
+//     |        |        |        |        |        |     |        |
+//     | Ready  +--------> Running+--------> Succeeding---> Success|
+//     |        |        |        |        |        |     |        |
+//     +--------+        +--------+        +---------     +--------+
+//         |                  |
+//         |                  |
+//         |             +----v---+        +--------+
+//         |             |        |        |        |
+//         +-------------> Failing+--------> Failed |
+//                       |        |        |        |
+//                       +--------+        +--------+
+// 
+func (p *Propeller) Handle(ctx context.Context, namespace, name string) error { + logger.Infof(ctx, "Processing Workflow.") + defer logger.Infof(ctx, "Completed processing workflow.") + + // Get the FlyteWorkflow resource with this namespace/name + w, err := p.wfStore.Get(ctx, namespace, name) + if err != nil { + if workflowstore.IsNotFound(err) { + p.metrics.WorkflowNotFound.Inc() + logger.Warningf(ctx, "Workflow namespace[%v]/name[%v] not found, may be deleted.", namespace, name) + return nil + } + if workflowstore.IsWorkflowStale(err) { + p.metrics.RoundSkipped.Inc() + logger.Warningf(ctx, "Workflow namespace[%v]/name[%v] Stale.", namespace, name) + return nil + } + logger.Warningf(ctx, "Failed to GetWorkflow, retrying with back-off", err) + return err + } + + t := p.metrics.DeepCopyTime.Start() + wfDeepCopy := w.DeepCopy() + t.Stop() + ctx = contextutils.WithWorkflowID(ctx, wfDeepCopy.GetID()) + + maxRetries := uint32(p.cfg.MaxWorkflowRetries) + if IsDeleted(wfDeepCopy) || (wfDeepCopy.Status.FailedAttempts > maxRetries) { + var err error + func() { + defer func() { + if r := recover(); r != nil { + stack := debug.Stack() + err = fmt.Errorf("panic when aborting workflow, Stack: [%s]", string(stack)) + p.metrics.PanicObserved.Inc() + } + }() + err = p.workflowExecutor.HandleAbortedWorkflow(ctx, wfDeepCopy, maxRetries) + }() + if err != nil { + p.metrics.AbortError.Inc() + return err + } + } else { + if wfDeepCopy.GetExecutionStatus().IsTerminated() { + if HasCompletedLabel(wfDeepCopy) && !HasFinalizer(wfDeepCopy) { + logger.Debugf(ctx, "Workflow is terminated.") + return nil + } + // NOTE: This should never really happen, but in case we externally mark the workflow as terminated + // We should allow cleanup + logger.Warn(ctx, "Workflow is marked as terminated but doesn't have the completed label, marking it as completed.") + } else { + SetFinalizerIfEmpty(wfDeepCopy, FinalizerKey) + + func() { + t := p.metrics.RawWorkflowTraversalTime.Start() + defer func() { + 
t.Stop() + if r := recover(); r != nil { + stack := debug.Stack() + err = fmt.Errorf("panic when aborting workflow, Stack: [%s]", string(stack)) + p.metrics.PanicObserved.Inc() + } + }() + err = p.workflowExecutor.HandleFlyteWorkflow(ctx, wfDeepCopy) + }() + + if err != nil { + logger.Errorf(ctx, "Error when trying to reconcile workflow. Error [%v]", err) + // Let's mark these as system errors. + // We only want to increase failed attempts and discard any other partial changes to the CRD. + wfDeepCopy = w.DeepCopy() + wfDeepCopy.GetExecutionStatus().IncFailedAttempts() + wfDeepCopy.GetExecutionStatus().SetMessage(err.Error()) + p.metrics.SystemError.Inc() + } else { + // No updates in the status we detected, we will skip writing to KubeAPI + if wfDeepCopy.Status.Equals(&w.Status) { + logger.Info(ctx, "WF hasn't been updated in this round.") + return nil + } + } + } + } + // If the end result is a terminated workflow, we remove the labels + if wfDeepCopy.GetExecutionStatus().IsTerminated() { + // We add a completed label so that we can avoid polling for this workflow + SetCompletedLabel(wfDeepCopy, time.Now()) + ResetFinalizers(wfDeepCopy) + } + // TODO we will need to call updatestatus when it is supported. But to preserve metadata like (label/finalizer) we will need to use update + + // update the GetExecutionStatus block of the FlyteWorkflow resource. UpdateStatus will not + // allow changes to the Spec of the resource, which is ideal for ensuring + // nothing other than resource status has been updated. 
+ return p.wfStore.Update(ctx, wfDeepCopy, workflowstore.PriorityClassCritical) +} + +func NewPropellerHandler(_ context.Context, cfg *config.Config, wfStore workflowstore.FlyteWorkflow, executor executors.Workflow, scope promutils.Scope) *Propeller { + + metrics := newPropellerMetrics(scope) + return &Propeller{ + metrics: metrics, + wfStore: wfStore, + workflowExecutor: executor, + cfg: cfg, + } +} diff --git a/pkg/controller/handler_test.go b/pkg/controller/handler_test.go new file mode 100644 index 000000000..3c8dd7dcf --- /dev/null +++ b/pkg/controller/handler_test.go @@ -0,0 +1,408 @@ +package controller + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flytepropeller/pkg/controller/config" + "github.com/lyft/flytepropeller/pkg/controller/workflowstore" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +type mockExecutor struct { + HandleCb func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error + HandleAbortedCb func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error +} + +func (m *mockExecutor) Initialize(ctx context.Context) error { + return nil +} + +func (m *mockExecutor) HandleAbortedWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + return m.HandleAbortedCb(ctx, w, maxRetries) +} + +func (m *mockExecutor) HandleFlyteWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + return m.HandleCb(ctx, w) +} + +func TestPropeller_Handle(t *testing.T) { + scope := promutils.NewTestScope() + ctx := context.TODO() + s := workflowstore.NewInMemoryWorkflowStore() + exec := &mockExecutor{} + cfg := &config.Config{ + MaxWorkflowRetries: 0, + } + + p := NewPropellerHandler(ctx, cfg, s, exec, scope) + + const namespace = "test" + const name = "123" + t.Run("notPresent", func(t *testing.T) { + assert.NoError(t, p.Handle(ctx, namespace, 
name)) + }) + + t.Run("terminated", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseFailed, + }, + })) + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + }) + + t.Run("happy", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseSucceeding, "done") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 1, len(r.Finalizers)) + assert.False(t, HasCompletedLabel(r)) + }) + + t.Run("error", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + return fmt.Errorf("failed") + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseReady, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.Equal(t, uint32(1), r.Status.FailedAttempts) + assert.False(t, HasCompletedLabel(r)) + 
}) + + t.Run("abort", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + FailedAttempts: 1, + }, + })) + exec.HandleAbortedCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseFailed, "done") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + assert.Equal(t, uint32(1), r.Status.FailedAttempts) + }) + + t.Run("abort_panics", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"x"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + FailedAttempts: 1, + Phase: v1alpha1.WorkflowPhaseRunning, + }, + })) + exec.HandleAbortedCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + panic("error") + } + assert.Error(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseRunning, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 1, len(r.Finalizers)) + assert.False(t, HasCompletedLabel(r)) + assert.Equal(t, uint32(1), r.Status.FailedAttempts) + }) + + t.Run("noUpdate", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: 
v1alpha1.WorkflowPhaseSucceeding, + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseSucceeding, "") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, r.GetExecutionStatus().GetPhase()) + assert.False(t, HasCompletedLabel(r)) + assert.Equal(t, 1, len(r.Finalizers)) + }) + + t.Run("handlingPanics", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseSucceeding, + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + panic("error") + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, r.GetExecutionStatus().GetPhase()) + assert.False(t, HasCompletedLabel(r)) + assert.Equal(t, 1, len(r.Finalizers)) + assert.Equal(t, uint32(1), r.Status.FailedAttempts) + }) + + t.Run("noUpdate", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseSucceeding, + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseSucceeding, "") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, 
r.GetExecutionStatus().GetPhase()) + assert.False(t, HasCompletedLabel(r)) + assert.Equal(t, 1, len(r.Finalizers)) + }) + + t.Run("retriesExhaustedFinalize", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseRunning, + FailedAttempts: 1, + }, + })) + abortCalled := false + exec.HandleAbortedCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + w.Status.UpdatePhase(v1alpha1.WorkflowPhaseFailed, "Aborted") + abortCalled = true + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + assert.True(t, abortCalled) + }) + + t.Run("deletedShouldBeFinalized", func(t *testing.T) { + n := v1.Now() + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + DeletionTimestamp: &n, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseSucceeding, + }, + })) + exec.HandleAbortedCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + w.Status.UpdatePhase(v1alpha1.WorkflowPhaseAborted, "Aborted") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseAborted, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + }) + + t.Run("deletedButAbortFailed", func(t *testing.T) { + n := 
v1.Now() + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + DeletionTimestamp: &n, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseSucceeding, + }, + })) + + exec.HandleAbortedCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + return fmt.Errorf("failed") + } + + assert.Error(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, []string{"f1"}, r.Finalizers) + assert.False(t, HasCompletedLabel(r)) + }) + + t.Run("removefinalizerOnTerminateSuccess", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseSuccess, "done") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseSuccess, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + }) + + t.Run("removefinalizerOnTerminateFailure", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseFailed, "done") + return nil + } + 
assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + }) +} + +func TestPropellerHandler_Initialize(t *testing.T) { + scope := promutils.NewTestScope() + ctx := context.TODO() + s := workflowstore.NewInMemoryWorkflowStore() + exec := &mockExecutor{} + cfg := &config.Config{ + MaxWorkflowRetries: 0, + } + + p := NewPropellerHandler(ctx, cfg, s, exec, scope) + + assert.NoError(t, p.Initialize(ctx)) +} diff --git a/pkg/controller/leaderelection.go b/pkg/controller/leaderelection.go new file mode 100644 index 000000000..8c6cd1132 --- /dev/null +++ b/pkg/controller/leaderelection.go @@ -0,0 +1,78 @@ +package controller + +import ( + "context" + "fmt" + "os" + + "github.com/lyft/flytepropeller/pkg/controller/config" + + "k8s.io/apimachinery/pkg/util/rand" + + v1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/client-go/tools/leaderelection" + "k8s.io/client-go/tools/leaderelection/resourcelock" + "k8s.io/client-go/tools/record" +) + +const ( + // Env var to lookup pod name in. 
const (
	// podNameEnvVar is the environment variable propeller reads to identify
	// this replica during leader election. In the pod spec it must be
	// populated via the downward API:
	//   env:
	//     - name: POD_NAME
	//       valueFrom:
	//         fieldRef:
	//           fieldPath: metadata.name
	podNameEnvVar = "POD_NAME"
)

// newResourceLock creates a new config map resource lock for use in a leader
// election loop. When leader election is disabled it returns (nil, nil) —
// callers must treat a nil lock as "run without election". When enabled, the
// lock config map (namespace + name) is mandatory.
func newResourceLock(corev1 v1.CoreV1Interface, eventRecorder record.EventRecorder, options config.LeaderElectionConfig) (
	resourcelock.Interface, error) {

	if !options.Enabled {
		return nil, nil
	}

	// Default the LeaderElectionID
	if len(options.LockConfigMap.String()) == 0 {
		return nil, fmt.Errorf("to enable leader election, a config map must be provided")
	}

	// Leader id, needs to be unique
	return resourcelock.New(resourcelock.ConfigMapsResourceLock,
		options.LockConfigMap.Namespace,
		options.LockConfigMap.Name,
		corev1,
		resourcelock.ResourceLockConfig{
			Identity:      getUniqueLeaderID(),
			EventRecorder: eventRecorder,
		})
}

// getUniqueLeaderID returns an identity for this replica to use in the
// election: the pod name when POD_NAME is set, otherwise the hostname plus a
// random 10-character suffix (presumably to guard against hostname collisions
// when several replicas share a host — confirm with deployment layout).
func getUniqueLeaderID() string {
	val, found := os.LookupEnv(podNameEnvVar)
	if found {
		return val
	}

	id, err := os.Hostname()
	if err != nil {
		// Fall back to suffix-only identity; the random part still keeps it unique.
		id = ""
	}

	return fmt.Sprintf("%v_%v", id, rand.String(10))
}

// newLeaderElector builds a LeaderElector that invokes leaderFn when this
// replica acquires leadership and leaderStoppedFn when leadership is lost.
// Lease timings are taken verbatim from the supplied configuration.
func newLeaderElector(lock resourcelock.Interface, cfg config.LeaderElectionConfig,
	leaderFn func(ctx context.Context), leaderStoppedFn func()) (*leaderelection.LeaderElector, error) {
	return leaderelection.NewLeaderElector(leaderelection.LeaderElectionConfig{
		Lock:          lock,
		LeaseDuration: cfg.LeaseDuration.Duration,
		RenewDeadline: cfg.RenewDeadline.Duration,
		RetryPeriod:   cfg.RetryPeriod.Duration,
		Callbacks: leaderelection.LeaderCallbacks{
			OnStartedLeading: leaderFn,
			OnStoppedLeading: leaderStoppedFn,
		},
	})
}
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/pkg/errors" +) + +type comparator func(lValue *core.Primitive, rValue *core.Primitive) bool +type comparators struct { + gt comparator + eq comparator +} + +var primitiveBooleanType = reflect.TypeOf(&core.Primitive_Boolean{}).String() + +var perTypeComparators = map[string]comparators{ + reflect.TypeOf(&core.Primitive_FloatValue{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetFloatValue() > rValue.GetFloatValue() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetFloatValue() == rValue.GetFloatValue() + }, + }, + reflect.TypeOf(&core.Primitive_Integer{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetInteger() > rValue.GetInteger() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetInteger() == rValue.GetInteger() + }, + }, + reflect.TypeOf(&core.Primitive_Boolean{}).String(): { + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetBoolean() == rValue.GetBoolean() + }, + }, + reflect.TypeOf(&core.Primitive_StringValue{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetStringValue() > rValue.GetStringValue() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetStringValue() == rValue.GetStringValue() + }, + }, + reflect.TypeOf(&core.Primitive_StringValue{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetStringValue() > rValue.GetStringValue() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetStringValue() == rValue.GetStringValue() + }, + }, + reflect.TypeOf(&core.Primitive_Datetime{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetDatetime().GetSeconds() > 
rValue.GetDatetime().GetSeconds() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetDatetime().GetSeconds() == rValue.GetDatetime().GetSeconds() + }, + }, + reflect.TypeOf(&core.Primitive_Duration{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetDuration().GetSeconds() > rValue.GetDuration().GetSeconds() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetDuration().GetSeconds() == rValue.GetDuration().GetSeconds() + }, + }, +} + +func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) { + lValueType := reflect.TypeOf(lValue.Value) + rValueType := reflect.TypeOf(rValue.Value) + if lValueType != rValueType { + return false, errors.Errorf("Comparison between different primitives types. lVal[%v]:rVal[%v]", lValueType, rValueType) + } + comps, ok := perTypeComparators[lValueType.String()] + if !ok { + return false, errors.Errorf("Comparator not defined for type: [%v]", lValueType.String()) + } + isBoolean := false + if lValueType.String() == primitiveBooleanType { + isBoolean = true + } + switch op { + case core.ComparisonExpression_GT: + if isBoolean { + return false, errors.Errorf("[GT] not defined for boolean operands.") + } + return comps.gt(lValue, rValue), nil + case core.ComparisonExpression_GTE: + if isBoolean { + return false, errors.Errorf("[GTE] not defined for boolean operands.") + } + return comps.eq(lValue, rValue) || comps.gt(lValue, rValue), nil + case core.ComparisonExpression_LT: + if isBoolean { + return false, errors.Errorf("[LT] not defined for boolean operands.") + } + return !(comps.gt(lValue, rValue) || comps.eq(lValue, rValue)), nil + case core.ComparisonExpression_LTE: + if isBoolean { + return false, errors.Errorf("[LTE] not defined for boolean operands.") + } + return !comps.gt(lValue, rValue), nil + case core.ComparisonExpression_EQ: + return 
comps.eq(lValue, rValue), nil + case core.ComparisonExpression_NEQ: + return !comps.eq(lValue, rValue), nil + } + return false, errors.Errorf("Unsupported operator type in Propeller. System error.") +} + +func Evaluate1(lValue *core.Primitive, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { + if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil { + return false, errors.Errorf("Only primitives can be compared. RHS Variable is non primitive.") + } + return Evaluate(lValue, rValue.GetScalar().GetPrimitive(), op) +} + +func Evaluate2(lValue *core.Literal, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) { + if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil { + return false, errors.Errorf("Only primitives can be compared. LHS Variable is non primitive.") + } + return Evaluate(lValue.GetScalar().GetPrimitive(), rValue, op) +} + +func EvaluateLiterals(lValue *core.Literal, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { + if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil { + return false, errors.Errorf("Only primitives can be compared. LHS Variable is non primitive.") + } + if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil { + return false, errors.Errorf("Only primitives can be compared. 
RHS Variable is non primitive") + } + return Evaluate(lValue.GetScalar().GetPrimitive(), rValue.GetScalar().GetPrimitive(), op) +} diff --git a/pkg/controller/nodes/branch/comparator_test.go b/pkg/controller/nodes/branch/comparator_test.go new file mode 100644 index 000000000..c34f28b0d --- /dev/null +++ b/pkg/controller/nodes/branch/comparator_test.go @@ -0,0 +1,403 @@ +package branch + +import ( + "fmt" + "testing" + "time" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/stretchr/testify/assert" +) + +func TestEvaluate_int(t *testing.T) { + p1 := utils.MustMakePrimitive(1) + p2 := utils.MustMakePrimitive(2) + { + // p1 > p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 >= p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 < p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + // p1 <= p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, 
core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + } +} + +func TestEvaluate_float(t *testing.T) { + p1 := utils.MustMakePrimitive(1.0) + p2 := utils.MustMakePrimitive(2.0) + { + // p1 > p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 >= p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 < p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + // p1 <= p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + 
assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + } +} + +func TestEvaluate_string(t *testing.T) { + p1 := utils.MustMakePrimitive("a") + p2 := utils.MustMakePrimitive("b") + { + // p1 > p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 >= p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 < p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + // p1 <= p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, 
core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + } +} + +func TestEvaluate_datetime(t *testing.T) { + p1 := utils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 00, 00, time.UTC)) + p2 := utils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 01, 00, time.UTC)) + { + // p1 > p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 >= p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 < p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + // p1 <= p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, 
err := Evaluate(p1, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + } +} + +func TestEvaluate_duration(t *testing.T) { + p1 := utils.MustMakePrimitive(10 * time.Second) + p2 := utils.MustMakePrimitive(11 * time.Second) + { + // p1 > p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 >= p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 < p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + // p1 <= p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, 
core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + } +} + +func TestEvaluate_boolean(t *testing.T) { + p1 := utils.MustMakePrimitive(true) + p2 := utils.MustMakePrimitive(false) + f := func(op core.ComparisonExpression_Operator) { + // GT/LT = false + msg := fmt.Sprintf("Evaluating: [%s]", op.String()) + b, err := Evaluate(p1, p2, op) + assert.Error(t, err, msg) + assert.False(t, b, msg) + b, err = Evaluate(p2, p1, op) + assert.Error(t, err, msg) + assert.False(t, b, msg) + b, err = Evaluate(p1, p1, op) + assert.Error(t, err, msg) + assert.False(t, b, msg) + } + f(core.ComparisonExpression_GT) + f(core.ComparisonExpression_LT) + f(core.ComparisonExpression_GTE) + f(core.ComparisonExpression_LTE) + + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } +} diff --git a/pkg/controller/nodes/branch/evaluator.go b/pkg/controller/nodes/branch/evaluator.go new file mode 100644 index 000000000..24377a11c --- /dev/null +++ b/pkg/controller/nodes/branch/evaluator.go @@ -0,0 +1,139 @@ +package branch + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + 
"github.com/lyft/flytestdlib/logger" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + regErrors "github.com/pkg/errors" +) + +func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *handler.Data) (bool, error) { + var lValue *core.Literal + var rValue *core.Literal + var lPrim *core.Primitive + var rPrim *core.Primitive + + if expr.GetLeftValue().GetPrimitive() == nil { + if nodeInputs == nil { + return false, regErrors.Errorf("Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + } + lValue = nodeInputs.Literals[expr.GetLeftValue().GetVar()] + if lValue == nil { + return false, regErrors.Errorf("Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + } + } else { + lPrim = expr.GetLeftValue().GetPrimitive() + } + + if expr.GetRightValue().GetPrimitive() == nil { + if nodeInputs == nil { + return false, regErrors.Errorf("Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + } + rValue = nodeInputs.Literals[expr.GetRightValue().GetVar()] + if rValue == nil { + return false, regErrors.Errorf("Failed to find Value for Variable [%v]", expr.GetRightValue().GetVar()) + } + } else { + rPrim = expr.GetRightValue().GetPrimitive() + } + + if lValue != nil && rValue != nil { + return EvaluateLiterals(lValue, rValue, expr.GetOperator()) + } + if lValue != nil && rPrim != nil { + return Evaluate2(lValue, rPrim, expr.GetOperator()) + } + if lPrim != nil && rValue != nil { + return Evaluate1(lPrim, rValue, expr.GetOperator()) + } + return Evaluate(lPrim, rPrim, expr.GetOperator()) +} + +func EvaluateBooleanExpression(expr *core.BooleanExpression, nodeInputs *handler.Data) (bool, error) { + if expr.GetComparison() != nil { + return EvaluateComparison(expr.GetComparison(), nodeInputs) + } + if expr.GetConjunction() == nil { + return false, regErrors.Errorf("No Comparison or Conjunction found in Branch node expression.") + } + lvalue, err := EvaluateBooleanExpression(expr.GetConjunction().GetLeftExpression(), 
nodeInputs) + if err != nil { + return false, err + } + rvalue, err := EvaluateBooleanExpression(expr.GetConjunction().GetRightExpression(), nodeInputs) + if err != nil { + return false, err + } + if expr.GetConjunction().GetOperator() == core.ConjunctionExpression_OR { + return lvalue || rvalue, nil + } + return lvalue && rvalue, nil +} + +func EvaluateIfBlock(block v1alpha1.ExecutableIfBlock, nodeInputs *handler.Data, skippedNodeIds []*v1alpha1.NodeID) (*v1alpha1.NodeID, []*v1alpha1.NodeID, error) { + if ok, err := EvaluateBooleanExpression(block.GetCondition(), nodeInputs); err != nil { + return nil, skippedNodeIds, err + } else if ok { + // Set status to running + return block.GetThenNode(), skippedNodeIds, err + } + // This branch is not taken + return nil, append(skippedNodeIds, block.GetThenNode()), nil +} + +// Decides the branch to be taken, returns the nodeId of the selected node or an error +// The branchnode is marked as success. This is used by downstream node to determine if it can be executed +// All downstream nodes are marked as skipped +func DecideBranch(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, nodeID v1alpha1.NodeID, node v1alpha1.ExecutableBranchNode, nodeInputs *handler.Data) (*v1alpha1.NodeID, error) { + var selectedNodeID *v1alpha1.NodeID + var skippedNodeIds []*v1alpha1.NodeID + var err error + + selectedNodeID, skippedNodeIds, err = EvaluateIfBlock(node.GetIf(), nodeInputs, skippedNodeIds) + if err != nil { + return nil, err + } + + for _, block := range node.GetElseIf() { + if selectedNodeID != nil { + skippedNodeIds = append(skippedNodeIds, block.GetThenNode()) + } else { + selectedNodeID, skippedNodeIds, err = EvaluateIfBlock(block, nodeInputs, skippedNodeIds) + if err != nil { + return nil, err + } + } + } + if node.GetElse() != nil { + if selectedNodeID == nil { + selectedNodeID = node.GetElse() + } else { + skippedNodeIds = append(skippedNodeIds, node.GetElse()) + } + } + for _, nodeIDPtr := range skippedNodeIds { + 
skippedNodeID := *nodeIDPtr + n, ok := w.GetNode(skippedNodeID) + if !ok { + return nil, errors.Errorf(errors.DownstreamNodeNotFoundError, nodeID, "Downstream node [%v] not found", skippedNodeID) + } + nStatus := w.GetNodeExecutionStatus(n.GetID()) + logger.Infof(ctx, "Branch Setting Node[%v] status to Skipped!", skippedNodeID) + nStatus.UpdatePhase(v1alpha1.NodePhaseSkipped, v1.Now(), "Branch evaluated to false") + } + + if selectedNodeID == nil { + if node.GetElseFail() != nil { + return nil, errors.Errorf(errors.UserProvidedError, nodeID, node.GetElseFail().Message) + } + return nil, errors.Errorf(errors.NoBranchTakenError, nodeID, "No branch satisfied") + } + logger.Infof(ctx, "Branch Node[%v] selected!", *selectedNodeID) + return selectedNodeID, nil +} diff --git a/pkg/controller/nodes/branch/evaluator_test.go b/pkg/controller/nodes/branch/evaluator_test.go new file mode 100644 index 000000000..cab3672cb --- /dev/null +++ b/pkg/controller/nodes/branch/evaluator_test.go @@ -0,0 +1,667 @@ +package branch + +import ( + "context" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/stretchr/testify/assert" +) + +// Creates a ComparisonExpression, comparing 2 literals +func getComparisonExpression(lV interface{}, op core.ComparisonExpression_Operator, rV interface{}) (*core.ComparisonExpression, *handler.Data) { + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "x", + }, + }, + Operator: op, + RightValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "y", + }, + }, + } + inputs := &handler.Data{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral(lV), + "y": utils.MustMakePrimitiveLiteral(rV), + }, + } + return exp, inputs +} + 
+func createUnaryConjunction(l *core.ComparisonExpression, op core.ConjunctionExpression_LogicalOperator, r *core.ComparisonExpression) *core.ConjunctionExpression { + return &core.ConjunctionExpression{ + LeftExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: l, + }, + }, + Operator: op, + RightExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: r, + }, + }, + } +} + +func TestEvaluateComparison(t *testing.T) { + t.Run("ComparePrimitives", func(t *testing.T) { + // Compare primitives + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(1), + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(2), + }, + }, + } + v, err := EvaluateComparison(exp, nil) + assert.NoError(t, err) + assert.False(t, v) + }) + t.Run("ComparePrimitiveAndLiteral", func(t *testing.T) { + // Compare lVal -> primitive and rVal -> literal + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(1), + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "y", + }, + }, + } + inputs := &handler.Data{ + Literals: map[string]*core.Literal{ + "y": utils.MustMakePrimitiveLiteral(2), + }, + } + v, err := EvaluateComparison(exp, inputs) + assert.NoError(t, err) + assert.False(t, v) + }) + t.Run("CompareLiteralAndPrimitive", func(t *testing.T) { + + // Compare lVal -> literal and rVal -> primitive + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "x", + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(2), + }, + }, + } + inputs := &handler.Data{ + Literals: 
map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral(1), + "y": utils.MustMakePrimitiveLiteral(3), + }, + } + v, err := EvaluateComparison(exp, inputs) + assert.NoError(t, err) + assert.False(t, v) + }) + + t.Run("CompareLiterals", func(t *testing.T) { + // Compare lVal -> literal and rVal -> literal + exp, inputs := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + v, err := EvaluateComparison(exp, inputs) + assert.NoError(t, err) + assert.True(t, v) + }) + + t.Run("CompareLiterals2", func(t *testing.T) { + // Compare lVal -> literal and rVal -> literal + exp, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + v, err := EvaluateComparison(exp, inputs) + assert.NoError(t, err) + assert.False(t, v) + }) + t.Run("ComparePrimitiveAndLiteralNotFound", func(t *testing.T) { + // Compare lVal -> primitive and rVal -> literal + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(1), + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "y", + }, + }, + } + inputs := &handler.Data{ + Literals: map[string]*core.Literal{}, + } + _, err := EvaluateComparison(exp, inputs) + assert.Error(t, err) + + _, err = EvaluateComparison(exp, nil) + assert.Error(t, err) + }) + + t.Run("CompareLiteralNotFoundAndPrimitive", func(t *testing.T) { + // Compare lVal -> primitive and rVal -> literal + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "y", + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(1), + }, + }, + } + inputs := &handler.Data{ + Literals: map[string]*core.Literal{}, + } + _, err := EvaluateComparison(exp, inputs) + assert.Error(t, err) + + _, err = EvaluateComparison(exp, nil) + assert.Error(t, err) + }) + +} + +func TestEvaluateBooleanExpression(t 
*testing.T) { + { + // Simple comparison only + ce, inputs := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + exp := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: ce, + }, + } + v, err := EvaluateBooleanExpression(exp, inputs) + assert.NoError(t, err) + assert.True(t, v) + } + { + // AND of 2 comparisons. Inputs are the same for both. + l, lInputs := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + r, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + exp := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: createUnaryConjunction(l, core.ConjunctionExpression_AND, r), + }, + } + v, err := EvaluateBooleanExpression(exp, lInputs) + assert.NoError(t, err) + assert.False(t, v) + } + { + // OR of 2 comparisons + l, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + r, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + exp := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: createUnaryConjunction(l, core.ConjunctionExpression_OR, r), + }, + } + v, err := EvaluateBooleanExpression(exp, inputs) + assert.NoError(t, err) + assert.True(t, v) + } + { + // Conjunction of comparison and a conjunction, AND + l, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + r, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + innerExp := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: createUnaryConjunction(l, core.ConjunctionExpression_OR, r), + }, + } + + outerComparison := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "a", + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "b", + }, + }, + } + outerInputs := &handler.Data{ + Literals: map[string]*core.Literal{ + "a": utils.MustMakePrimitiveLiteral(5), + "b": 
utils.MustMakePrimitiveLiteral(4), + }, + } + + outerExp := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: &core.ConjunctionExpression{ + LeftExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: outerComparison, + }, + }, + Operator: core.ConjunctionExpression_AND, + RightExpression: innerExp, + }, + }, + } + + for k, v := range inputs.Literals { + outerInputs.Literals[k] = v + } + + v, err := EvaluateBooleanExpression(outerExp, outerInputs) + assert.NoError(t, err) + assert.True(t, v) + } +} + +func TestEvaluateIfBlock(t *testing.T) { + { + // AND of 2 comparisons + l, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + r, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + thenNode := "test" + block := &v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: createUnaryConjunction(l, core.ConjunctionExpression_AND, r), + }, + }, + }, + ThenNode: &thenNode, + } + + skippedNodeIds := make([]*v1alpha1.NodeID, 0) + accp, skippedNodeIds, err := EvaluateIfBlock(block, inputs, skippedNodeIds) + assert.NoError(t, err) + assert.Nil(t, accp) + assert.Equal(t, 1, len(skippedNodeIds)) + assert.Equal(t, "test", *skippedNodeIds[0]) + } + { + // OR of 2 comparisons + l, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + r, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + thenNode := "test" + block := &v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: createUnaryConjunction(l, core.ConjunctionExpression_OR, r), + }, + }, + }, + ThenNode: &thenNode, + } + + skippedNodeIds := make([]*v1alpha1.NodeID, 0) + accp, skippedNodeIds, err := EvaluateIfBlock(block, inputs, skippedNodeIds) + assert.NoError(t, err) 
+ assert.NotNil(t, accp) + assert.Equal(t, "test", *accp) + assert.Equal(t, 0, len(skippedNodeIds)) + } +} + +func TestDecideBranch(t *testing.T) { + ctx := context.Background() + + t.Run("EmptyIfBlock", func(t *testing.T) { + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{}, + }, + } + branchNode := &v1alpha1.BranchNodeSpec{} + b, err := DecideBranch(ctx, w, "n1", branchNode, nil) + assert.Error(t, err) + assert.Nil(t, b) + }) + + t.Run("MissingThenNode", func(t *testing.T) { + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{}, + }, + } + exp, inputs := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp, + }, + }, + }, + ThenNode: nil, + }, + } + b, err := DecideBranch(ctx, w, "n1", branchNode, inputs) + assert.Error(t, err) + assert.Nil(t, b) + assert.Equal(t, errors.NoBranchTakenError, err.(*errors.NodeError).Code) + }) + + t.Run("WithThenNode", func(t *testing.T) { + n1 := "n1" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + }, + }, + } + exp, inputs := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp, + }, + }, + }, + ThenNode: &n1, + }, + } + b, err := DecideBranch(ctx, w, "n1", branchNode, inputs) + assert.NoError(t, err) + assert.NotNil(t, b) + assert.Equal(t, n1, *b) + }) + + t.Run("RepeatedCondition", func(t *testing.T) { + n1 := "n1" + 
n2 := "n2" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + n2: { + ID: n2, + }, + }, + }, + } + exp, inputs := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp, + }, + }, + }, + ThenNode: &n2, + }, + }, + } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) + assert.NoError(t, err) + assert.NotNil(t, b) + assert.Equal(t, n1, *b) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n2].GetPhase()) + assert.Nil(t, w.Status.NodeStatus[n1]) + }) + + t.Run("SecondCondition", func(t *testing.T) { + n1 := "n1" + n2 := "n2" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + n2: { + ID: n2, + }, + }, + }, + } + exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + exp2, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp1, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp2, + }, + }, + }, + ThenNode: &n2, + }, + }, + } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) + 
assert.NoError(t, err) + assert.NotNil(t, b) + assert.Equal(t, n2, *b) + assert.Nil(t, w.Status.NodeStatus[n2]) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n1].GetPhase()) + }) + + t.Run("ElseCase", func(t *testing.T) { + n1 := "n1" + n2 := "n2" + n3 := "n3" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + n2: { + ID: n2, + }, + }, + }, + } + exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + exp2, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp1, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp2, + }, + }, + }, + ThenNode: &n2, + }, + }, + Else: &n3, + } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) + assert.NoError(t, err) + assert.NotNil(t, b) + assert.Equal(t, n3, *b) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n1].GetPhase()) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n2].GetPhase()) + }) + + t.Run("MissingNode", func(t *testing.T) { + n1 := "n1" + n2 := "n2" + n3 := "n3" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + }, + }, + } + exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + exp2, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: 
&core.BooleanExpression_Comparison{ + Comparison: exp1, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp2, + }, + }, + }, + ThenNode: &n2, + }, + }, + Else: &n3, + } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) + assert.Error(t, err) + assert.Nil(t, b) + assert.Equal(t, errors.DownstreamNodeNotFoundError, err.(*errors.NodeError).Code) + }) + + t.Run("ElseFailCase", func(t *testing.T) { + n1 := "n1" + n2 := "n2" + userError := "User error" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + n2: { + ID: n2, + }, + }, + }, + } + exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + exp2, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp1, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp2, + }, + }, + }, + ThenNode: &n2, + }, + }, + ElseFail: &v1alpha1.Error{ + Error: &core.Error{ + Message: userError, + }, + }, + } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) + assert.Error(t, err) + assert.Nil(t, b) + assert.Equal(t, errors.UserProvidedError, err.(*errors.NodeError).Code) + assert.Equal(t, userError, err.(*errors.NodeError).Message) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n1].GetPhase()) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n2].GetPhase()) + }) +} diff --git 
a/pkg/controller/nodes/branch/handler.go b/pkg/controller/nodes/branch/handler.go new file mode 100644 index 000000000..f3531771e --- /dev/null +++ b/pkg/controller/nodes/branch/handler.go @@ -0,0 +1,135 @@ +package branch + +import ( + "context" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" +) + +type branchHandler struct { + nodeExecutor executors.Node + recorder events.NodeEventRecorder +} + +func (b *branchHandler) recurseDownstream(ctx context.Context, w v1alpha1.ExecutableWorkflow, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Status, error) { + downstreamStatus, err := b.nodeExecutor.RecursiveNodeHandler(ctx, w, branchTakenNode) + if err != nil { + return handler.StatusUndefined, err + } + + if downstreamStatus.IsComplete() { + // For branch node we set the output node to be the same as the child nodes output + childNodeStatus := w.GetNodeExecutionStatus(branchTakenNode.GetID()) + nodeStatus.SetDataDir(childNodeStatus.GetDataDir()) + return handler.StatusSuccess, nil + } + + if downstreamStatus.HasFailed() { + return handler.StatusFailed(downstreamStatus.Err), nil + } + + return handler.StatusRunning, nil +} + +func (b *branchHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + logger.Debugf(ctx, "Starting Branch Node") + branch := node.GetBranchNode() + if branch == nil { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "Invoked branch handler, for a non branch node.")), nil + } + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + 
branchStatus := nodeStatus.GetOrCreateBranchStatus() + finalNode, err := DecideBranch(ctx, w, node.GetID(), branch, nodeInputs) + if err != nil { + branchStatus.SetBranchNodeError() + logger.Debugf(ctx, "Branch evaluation failed. Error [%s]", err) + return handler.StatusFailed(err), nil + } + branchStatus.SetBranchNodeSuccess(*finalNode) + var ok bool + childNode, ok := w.GetNode(*finalNode) + if !ok { + logger.Debugf(ctx, "Branch downstream finalized node not found. FinalizedNode [%s]", *finalNode) + return handler.StatusFailed(errors.Errorf(errors.DownstreamNodeNotFoundError, w.GetID(), node.GetID(), "Downstream node [%v] not found", *finalNode)), nil + } + i := node.GetID() + childNodeStatus := w.GetNodeExecutionStatus(childNode.GetID()) + childNodeStatus.SetParentNodeID(&i) + + logger.Debugf(ctx, "Recursing down branch node") + return b.recurseDownstream(ctx, w, nodeStatus, childNode) +} + +func (b *branchHandler) CheckNodeStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + branch := node.GetBranchNode() + if branch == nil { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "Invoked branch handler, for a non branch node.")), nil + } + // If the branch was already evaluated i.e, Node is in Running status + branchStatus := nodeStatus.GetOrCreateBranchStatus() + userError := branch.GetElseFail() + finalNodeID := branchStatus.GetFinalizedNode() + if finalNodeID == nil { + if userError != nil { + // We should never reach here, but for safety and completeness + return handler.StatusFailed(errors.Errorf(errors.UserProvidedError, w.GetID(), node.GetID(), userError.Message)), nil + } + return handler.StatusRunning, errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "No node finalized through previous branch evaluation.") + } + var ok bool + branchTakenNode, ok := w.GetNode(*finalNodeID) + if !ok { + return 
handler.StatusFailed(errors.Errorf(errors.DownstreamNodeNotFoundError, w.GetID(), node.GetID(), "Downstream node [%v] not found", *finalNodeID)), nil + } + // Recurse downstream + return b.recurseDownstream(ctx, w, nodeStatus, branchTakenNode) +} + +func (b *branchHandler) HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "A Branch node cannot enter a failing state")), nil +} + +func (b *branchHandler) Initialize(ctx context.Context) error { + return nil +} + +func (b *branchHandler) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + branch := node.GetBranchNode() + if branch == nil { + return errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "Invoked branch handler, for a non branch node.") + } + // If the branch was already evaluated i.e, Node is in Running status + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + branchStatus := nodeStatus.GetOrCreateBranchStatus() + userError := branch.GetElseFail() + finalNodeID := branchStatus.GetFinalizedNode() + if finalNodeID == nil { + if userError != nil { + // We should never reach here, but for safety and completeness + return errors.Errorf(errors.UserProvidedError, w.GetID(), node.GetID(), userError.Message) + } + return errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "No node finalized through previous branch evaluation.") + } + var ok bool + branchTakenNode, ok := w.GetNode(*finalNodeID) + + if !ok { + return errors.Errorf(errors.DownstreamNodeNotFoundError, w.GetID(), node.GetID(), "Downstream node [%v] not found", *finalNodeID) + } + // Recurse downstream + return b.nodeExecutor.AbortHandler(ctx, w, branchTakenNode) +} + +func New(executor executors.Node, eventSink events.EventSink, scope promutils.Scope) handler.IFace { + branchScope := 
scope.NewSubScope("branch") + return &branchHandler{ + nodeExecutor: executor, + recorder: events.NewNodeEventRecorder(eventSink, branchScope), + } +} diff --git a/pkg/controller/nodes/branch/handler_test.go b/pkg/controller/nodes/branch/handler_test.go new file mode 100644 index 000000000..c64cf96a3 --- /dev/null +++ b/pkg/controller/nodes/branch/handler_test.go @@ -0,0 +1,236 @@ +package branch + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/stretchr/testify/assert" +) + +type recursiveNodeHandlerFn func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) +type abortNodeHandlerCbFn func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error + +type mockNodeExecutor struct { + executors.Node + RecursiveNodeHandlerCB recursiveNodeHandlerFn + AbortNodeHandlerCB abortNodeHandlerCbFn +} + +func (m *mockNodeExecutor) RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { + return m.RecursiveNodeHandlerCB(ctx, w, currentNode) +} + +func (m *mockNodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error { + return m.AbortNodeHandlerCB(ctx, w, currentNode) +} + +func TestBranchHandler_RecurseDownstream(t *testing.T) { + ctx := context.TODO() + m := &mockNodeExecutor{} + branch := New(m, events.NewMockEventSink(), 
	promutils.NewTestScope()).(*branchHandler)
	childNodeID := "child"
	childDatadir := v1alpha1.DataReference("test")
	// Workflow fixture with a single child node status carrying a known data dir.
	w := &v1alpha1.FlyteWorkflow{
		Status: v1alpha1.WorkflowStatus{
			NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{
				childNodeID: {
					DataDir: childDatadir,
				},
			},
		},
	}
	expectedError := fmt.Errorf("error")

	// Builds a recursiveNodeHandlerFn that ignores its inputs and always reports the
	// given (status, err) pair — one archetype per table case below.
	recursiveNodeHandlerFnArchetype := func(status executors.NodeStatus, err error) recursiveNodeHandlerFn {
		return func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) {
			return status, err
		}
	}

	// Table: child-node execution outcome -> expected branch-handler status.
	tests := []struct {
		name                   string
		recursiveNodeHandlerFn recursiveNodeHandlerFn
		nodeStatus             v1alpha1.ExecutableNodeStatus
		branchTakenNode        v1alpha1.ExecutableNode
		isErr                  bool
		expectedStatus         handler.Status
	}{
		{"childNodeError", recursiveNodeHandlerFnArchetype(executors.NodeStatusUndefined, expectedError),
			nil, &v1alpha1.NodeSpec{}, true, handler.StatusUndefined},
		{"childPending", recursiveNodeHandlerFnArchetype(executors.NodeStatusPending, nil),
			nil, &v1alpha1.NodeSpec{}, false, handler.StatusRunning},
		{"childStillRunning", recursiveNodeHandlerFnArchetype(executors.NodeStatusRunning, nil),
			nil, &v1alpha1.NodeSpec{}, false, handler.StatusRunning},
		{"childFailure", recursiveNodeHandlerFnArchetype(executors.NodeStatusFailed(expectedError), nil),
			nil, &v1alpha1.NodeSpec{}, false, handler.StatusFailed(expectedError)},
		{"childComplete", recursiveNodeHandlerFnArchetype(executors.NodeStatusComplete, nil),
			&v1alpha1.NodeStatus{}, &v1alpha1.NodeSpec{ID: childNodeID}, false, handler.StatusSuccess},
		{"childCompleteNoStatus", recursiveNodeHandlerFnArchetype(executors.NodeStatusComplete, nil),
			&v1alpha1.NodeStatus{}, &v1alpha1.NodeSpec{ID: "deadbeef"}, false, handler.StatusSuccess},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			m.RecursiveNodeHandlerCB = test.recursiveNodeHandlerFn
			h, err := branch.recurseDownstream(ctx, w, test.nodeStatus, test.branchTakenNode)
			if test.isErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
			assert.Equal(t, test.expectedStatus, h)
			// When a node status is supplied, recurseDownstream is expected to have
			// propagated the branch-taken node's data dir onto it.
			if test.nodeStatus != nil {
				assert.Equal(t, w.GetNodeExecutionStatus(test.branchTakenNode.GetID()).GetDataDir(), test.nodeStatus.GetDataDir())
			}
		})
	}
}

// TestBranchHandler_AbortNode exercises AbortNode for: a node with no branch spec,
// a branch node that was never evaluated, and a branch whose finalized node (n1)
// should be the one forwarded to the underlying executor's abort handler.
func TestBranchHandler_AbortNode(t *testing.T) {
	ctx := context.TODO()
	m := &mockNodeExecutor{}
	branch := New(m, events.NewMockEventSink(), promutils.NewTestScope())
	b1 := "b1"
	n1 := "n1"
	n2 := "n2"

	// Workflow with two candidate nodes; branch node b1 is running and has
	// already finalized its decision to n1.
	w := &v1alpha1.FlyteWorkflow{
		WorkflowSpec: &v1alpha1.WorkflowSpec{
			ID: "test",
			Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{
				n1: {
					ID: n1,
				},
				n2: {
					ID: n2,
				},
			},
		},
		Status: v1alpha1.WorkflowStatus{
			NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{
				b1: {
					Phase: v1alpha1.NodePhaseRunning,
					BranchStatus: &v1alpha1.BranchNodeStatus{
						FinalizedNodeID: &n1,
					},
				},
			},
		},
	}
	exp, _ := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0)

	// Branch spec: if (1.0 == 1.0) -> n1, elif (same expr) -> n2.
	branchNode := &v1alpha1.BranchNodeSpec{

		If: v1alpha1.IfBlock{
			Condition: v1alpha1.BooleanExpression{
				BooleanExpression: &core.BooleanExpression{
					Expr: &core.BooleanExpression_Comparison{
						Comparison: exp,
					},
				},
			},
			ThenNode: &n1,
		},
		ElseIf: []*v1alpha1.IfBlock{
			{
				Condition: v1alpha1.BooleanExpression{
					BooleanExpression: &core.BooleanExpression{
						Expr: &core.BooleanExpression_Comparison{
							Comparison: exp,
						},
					},
				},
				ThenNode: &n2,
			},
		},
	}

	t.Run("NoBranchNode", func(t *testing.T) {

		err := branch.AbortNode(ctx, w, &v1alpha1.NodeSpec{})
		assert.Error(t, err)
		assert.True(t, errors.Matches(err, errors.IllegalStateError))
	})

	t.Run("BranchNodeNoEval", func(t *testing.T) {

		// Node has a branch spec but no recorded branch status (never evaluated).
		err := branch.AbortNode(ctx, w, &v1alpha1.NodeSpec{
			BranchNode: branchNode})
		assert.Error(t, err)
		assert.True(t, errors.Matches(err, errors.IllegalStateError))
	})

	t.Run("BranchNodeSuccess", func(t *testing.T) {
		// The abort must be delegated for the finalized node n1 specifically.
		m.AbortNodeHandlerCB = func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error {
			assert.Equal(t, n1, currentNode.GetID())
			return nil
		}
		err := branch.AbortNode(ctx, w, &v1alpha1.NodeSpec{
			ID:         b1,
			BranchNode: branchNode})
		assert.NoError(t, err)
	})
}

// TestBranchHandler_Initialize verifies Initialize is a no-op that returns no error.
func TestBranchHandler_Initialize(t *testing.T) {
	ctx := context.TODO()
	m := &mockNodeExecutor{}
	branch := New(m, events.NewMockEventSink(), promutils.NewTestScope())
	assert.NoError(t, branch.Initialize(ctx))
}

// TODO incomplete test suite, add more
func TestBranchHandler_StartNode(t *testing.T) {
	ctx := context.TODO()
	m := &mockNodeExecutor{}
	branch := New(m, events.NewMockEventSink(), promutils.NewTestScope())
	childNodeID := "child"
	childDatadir := v1alpha1.DataReference("test")
	w := &v1alpha1.FlyteWorkflow{
		WorkflowSpec: &v1alpha1.WorkflowSpec{
			ID: "test",
		},
		Status: v1alpha1.WorkflowStatus{
			NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{
				childNodeID: {
					DataDir: childDatadir,
				},
			},
		},
	}
	_, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1)

	// Starting a node without a branch spec is expected to fail the node (not error out).
	tests := []struct {
		name           string
		node           v1alpha1.ExecutableNode
		isErr          bool
		expectedStatus handler.Status
	}{
		{"NoBranchNode", &v1alpha1.NodeSpec{}, false, handler.StatusFailed(nil)},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			s, err := branch.StartNode(ctx, w, test.node, inputs)
			if test.isErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
			assert.Equal(t, test.expectedStatus.Phase, s.Phase)

		})
	}
}

// init registers the metric label keys used by the labeled stopwatches in this package's tests.
func init() {
	labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey)
}
package common

import (
	"context"

	"github.com/lyft/flytepropeller/pkg/controller/nodes/errors"
	"github.com/lyft/flytestdlib/logger"
	"github.com/lyft/flytestdlib/storage"

	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
	"github.com/lyft/flytepropeller/pkg/controller/nodes/handler"
)

// CreateAliasMap builds a reverse lookup (alias -> variable name) from the node's
// declared output aliases.
func CreateAliasMap(aliases []v1alpha1.Alias) map[string]string {
	aliasToVarMap := make(map[string]string, len(aliases))
	for _, alias := range aliases {
		aliasToVarMap[alias.GetAlias()] = alias.GetVar()
	}
	return aliasToVarMap
}

// A simple output resolver that expects an outputs.pb at the data directory of the node.
type SimpleOutputsResolver struct {
	store storage.ProtobufStore
}

// ExtractOutput reads the node's outputs.pb from its data directory and returns the
// literal bound to bindToVar. If bindToVar names an output alias, it is first
// translated to the underlying variable. Returns a CausedByError when the protobuf
// cannot be read and an OutputsNotFoundError when the file holds no literals or the
// requested variable is missing.
func (r SimpleOutputsResolver) ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode,
	bindToVar handler.VarName) (values *core.Literal, err error) {
	d := &handler.Data{}
	nodeStatus := w.GetNodeExecutionStatus(n.GetID())
	outputsFileRef := v1alpha1.GetOutputsFile(nodeStatus.GetDataDir())
	if err := r.store.ReadProtobuf(ctx, outputsFileRef, d); err != nil {
		return nil, errors.Wrapf(errors.CausedByError, n.GetID(), err, "Failed to GetPrevious data from dataDir [%v]", nodeStatus.GetDataDir())
	}

	if d.Literals == nil {
		return nil, errors.Errorf(errors.OutputsNotFoundError, n.GetID(),
			"Outputs not found at [%v]", outputsFileRef)
	}

	// Translate an aliased name to the variable actually present in the outputs map.
	aliasMap := CreateAliasMap(n.GetOutputAlias())
	if variable, ok := aliasMap[bindToVar]; ok {
		logger.Debugf(ctx, "Mapping [%v].[%v] -> [%v].[%v]", n.GetID(), variable, n.GetID(), bindToVar)
		bindToVar = variable
	}

	l, ok := d.Literals[bindToVar]
	if !ok {
		return nil, errors.Errorf(errors.OutputsNotFoundError, n.GetID(),
			"Failed to find [%v].[%v]", n.GetID(), bindToVar)
	}

	return l, nil
}

// Creates a simple output resolver that expects an outputs.pb at the data directory of the node.
func NewSimpleOutputsResolver(store storage.ProtobufStore) SimpleOutputsResolver {
	return SimpleOutputsResolver{
		store: store,
	}
}

package common

import (
	"testing"

	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
	"github.com/stretchr/testify/assert"
)

// TestCreateAliasMap covers a populated alias list, a nil-valued declared slice,
// and a literal nil argument — the latter two must yield an empty (non-nil) map.
func TestCreateAliasMap(t *testing.T) {
	{
		aliases := []v1alpha1.Alias{
			{Alias: core.Alias{Var: "x", Alias: "y"}},
		}
		m := CreateAliasMap(aliases)
		assert.Equal(t, map[string]string{
			"y": "x",
		}, m)
	}
	{
		var aliases []v1alpha1.Alias
		m := CreateAliasMap(aliases)
		assert.Equal(t, map[string]string{}, m)
	}
	{
		m := CreateAliasMap(nil)
		assert.Equal(t, map[string]string{}, m)
	}
}

package dynamic

import (
	"context"
	"time"

	"github.com/lyft/flytestdlib/promutils/labeled"

	"github.com/lyft/flytepropeller/pkg/compiler"
	common2 "github.com/lyft/flytepropeller/pkg/compiler/common"
	"github.com/lyft/flytepropeller/pkg/controller/executors"

	"github.com/lyft/flytepropeller/pkg/controller/nodes/common"

	"github.com/lyft/flytestdlib/promutils"
	"k8s.io/apimachinery/pkg/util/rand"

	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/lyft/flytestdlib/logger"
	"github.com/lyft/flytestdlib/storage"

	"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
	"github.com/lyft/flytepropeller/pkg/compiler/transformers/k8s"
	"github.com/lyft/flytepropeller/pkg/controller/nodes/errors"
	"github.com/lyft/flytepropeller/pkg/controller/nodes/handler"
)

// dynamicNodeHandler decorates an underlying node handler with dynamic-workflow
// support: after the wrapped node succeeds it looks for a futures file and, when
// present, compiles and executes the generated sub-workflow.
type dynamicNodeHandler struct {
	handler.IFace
	metrics        metrics
	simpleResolver common.SimpleOutputsResolver
	store          *storage.DataStore
	nodeExecutor   executors.Node
	enQWorkflow    v1alpha1.EnqueueWorkflow
}

// metrics groups the stopwatches tracking dynamic-workflow overhead.
type metrics struct {
	buildDynamicWorkflow   labeled.StopWatch
	retrieveDynamicJobSpec labeled.StopWatch
}

// newMetrics registers this handler's stopwatches under the given scope.
func newMetrics(scope promutils.Scope) metrics {
	return metrics{
		buildDynamicWorkflow:   labeled.NewStopWatch("build_dynamic_workflow", "Overhead for building a dynamic workflow in memory.", time.Microsecond, scope),
		retrieveDynamicJobSpec: labeled.NewStopWatch("retrieve_dynamic_spec", "Overhead of downloading and unmarshaling dynamic job spec", time.Microsecond, scope),
	}
}

// ExtractOutput delegates to the underlying handler when it implements
// handler.OutputResolver, otherwise falls back to the simple outputs.pb resolver.
func (e dynamicNodeHandler) ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode,
	bindToVar handler.VarName) (values *core.Literal, err error) {
	outputResolver, casted := e.IFace.(handler.OutputResolver)
	if !casted {
		return e.simpleResolver.ExtractOutput(ctx, w, n, bindToVar)
	}

	return outputResolver.ExtractOutput(ctx, w, n, bindToVar)
}

// getDynamicJobSpec loads the DynamicJobSpec from the node's futures file.
// Returns (nil, nil) when no futures file exists — i.e. the node is not dynamic.
func (e dynamicNodeHandler) getDynamicJobSpec(ctx context.Context, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (*core.DynamicJobSpec, error) {
	t := e.metrics.retrieveDynamicJobSpec.Start(ctx)
	defer t.Stop()

	futuresFilePath, err := e.store.ConstructReference(ctx, nodeStatus.GetDataDir(), v1alpha1.GetFutureFile())
	if err != nil {
		logger.Warnf(ctx, "Failed to construct data path for futures file. Error: %v", err)
		return nil, err
	}

	// Probe for the futures file; absence means the node produced no dynamic spec
	// and (nil, nil) is returned so the caller treats the node as non-dynamic.
	if metadata, err := e.store.Head(ctx, futuresFilePath); err != nil {
		logger.Warnf(ctx, "Failed to read futures file. Error: %v", err)
		return nil, errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to do HEAD on futures file.")
	} else if !metadata.Exists() {
		return nil, nil
	}

	djSpec := &core.DynamicJobSpec{}
	if err := e.store.ReadProtobuf(ctx, futuresFilePath, djSpec); err != nil {
		logger.Warnf(ctx, "Failed to read futures file. Error: %v", err)
		return nil, errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to read futures protobuf file.")
	}

	return djSpec, nil
}

// buildDynamicWorkflowTemplate converts a DynamicJobSpec into a standalone
// WorkflowTemplate: node IDs are rewritten to embed the parent node's lineage,
// sub-node statuses are pointed at the original (pre-rename) data directories,
// container config is inherited from the parent task, and output bindings are
// remapped to the renamed node IDs.
func (e dynamicNodeHandler) buildDynamicWorkflowTemplate(ctx context.Context, djSpec *core.DynamicJobSpec,
	w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (
	*core.WorkflowTemplate, error) {

	iface, err := underlyingInterface(w, node)
	if err != nil {
		return nil, err
	}

	// Modify node IDs to include lineage, the entire system assumes node IDs are unique per parent WF.
	// We keep track of the original node ids because that's where inputs are written to.
	parentNodeID := node.GetID()
	for _, n := range djSpec.Nodes {
		newID, err := hierarchicalNodeID(parentNodeID, n.Id)
		if err != nil {
			return nil, err
		}

		// Instantiate a nodeStatus using the modified name but set its data directory using the original name.
		subNodeStatus := nodeStatus.GetNodeExecutionStatus(newID)
		originalNodePath, err := e.store.ConstructReference(ctx, nodeStatus.GetDataDir(), n.Id)
		if err != nil {
			return nil, err
		}

		subNodeStatus.SetDataDir(originalNodePath)
		subNodeStatus.ResetDirty()

		n.Id = newID
	}

	if node.GetTaskID() != nil {
		// If the parent is a task, pass down data children nodes should inherit.
		parentTask, err := w.GetTask(*node.GetTaskID())
		if err != nil {
			return nil, errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to find task [%v].", node.GetTaskID())
		}

		// Append the parent container task's Config entries onto each generated container task.
		for _, t := range djSpec.Tasks {
			if t.GetContainer() != nil && parentTask.CoreTask().GetContainer() != nil {
				t.GetContainer().Config = append(t.GetContainer().Config, parentTask.CoreTask().GetContainer().Config...)
			}
		}
	}

	// Rewrite output bindings so promises reference the lineage-qualified node IDs.
	for _, o := range djSpec.Outputs {
		err = updateBindingNodeIDsWithLineage(parentNodeID, o.Binding)
		if err != nil {
			return nil, err
		}
	}

	return &core.WorkflowTemplate{
		Id: &core.Identifier{
			Project: w.GetExecutionID().Project,
			Domain:  w.GetExecutionID().Domain,
			// Random identifiers: the generated workflow is ephemeral and never registered.
			Version:      rand.String(10),
			Name:         rand.String(10),
			ResourceType: core.ResourceType_WORKFLOW,
		},
		Nodes:     djSpec.Nodes,
		Outputs:   djSpec.Outputs,
		Interface: iface,
	}, nil
}

// For any node that is not in a NEW/READY state in the recording, CheckNodeStatus will be invoked. The implementation should handle
// idempotency and return the current observed state of the node
func (e dynamicNodeHandler) CheckNodeStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode,
	previousNodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) {

	var status handler.Status
	var err error
	switch previousNodeStatus.GetOrCreateDynamicNodeStatus().GetDynamicNodePhase() {
	case v1alpha1.DynamicNodePhaseExecuting:
		// The underlying node already succeeded and produced a futures file on a
		// previous round; rebuild the dynamic sub-workflow and keep progressing it.
		dynamicWF, nStatus, _, err := e.buildContextualDynamicWorkflow(ctx, w, node, previousNodeStatus)
		if err != nil {
			return handler.StatusFailed(err), nil
		}

		s, err := e.progressDynamicWorkflow(ctx, nStatus, dynamicWF)
		if err == nil && s == handler.StatusSuccess {
			// After the dynamic node completes we need to copy the outputs from its end nodes to the parent nodes status
			endNode := dynamicWF.GetNodeExecutionStatus(v1alpha1.EndNodeID)
			outputPath := v1alpha1.GetOutputsFile(endNode.GetDataDir())
			destinationPath := v1alpha1.GetOutputsFile(previousNodeStatus.GetDataDir())
			logger.Infof(ctx, "Dynamic workflow completed, copying outputs from the end-node [%s] to the parent node data dir [%s]", outputPath, destinationPath)
			if err := e.store.CopyRaw(ctx, outputPath, destinationPath, storage.Options{}); err != nil {
				logger.Errorf(ctx, "Failed to copy outputs from dynamic sub-wf [%s] to [%s]. Error: %s", outputPath, destinationPath, err.Error())
				return handler.StatusUndefined, errors.Wrapf(errors.StorageError, node.GetID(), err, "Failed to copy outputs from dynamic sub-wf [%s] to [%s]. Error: %s", outputPath, destinationPath, err.Error())
			}
			if successHandler, ok := e.IFace.(handler.PostNodeSuccessHandler); ok {
				return successHandler.HandleNodeSuccess(ctx, w, node)
			}
			logger.Warnf(ctx, "Bad configuration for dynamic node, no post node success handler found!")
		}
		return s, err
	default:
		// Invoke the underlying check node status.
		status, err = e.IFace.CheckNodeStatus(ctx, w, node, previousNodeStatus)

		if err != nil {
			return status, err
		}

		if status.Phase != handler.PhaseSuccess {
			return status, err
		}

		// If the node succeeded, check if it generated a futures.pb file to execute.
+ _, _, isDynamic, err := e.buildContextualDynamicWorkflow(ctx, w, node, previousNodeStatus) + if err != nil { + return handler.StatusFailed(err), nil + } + + if !isDynamic { + if successHandler, ok := e.IFace.(handler.PostNodeSuccessHandler); ok { + return successHandler.HandleNodeSuccess(ctx, w, node) + } + logger.Warnf(ctx, "Bad configuration for dynamic node, no post node success handler found!") + return status, err + } + + // Mark the node as a dynamic node executing its child nodes. Next time check node status is called, it'll go + // directly to progress the dynamically generated workflow. + previousNodeStatus.GetOrCreateDynamicNodeStatus().SetDynamicNodePhase(v1alpha1.DynamicNodePhaseExecuting) + + return handler.StatusRunning, nil + } +} + +func (e dynamicNodeHandler) buildContextualDynamicWorkflow(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, + previousNodeStatus v1alpha1.ExecutableNodeStatus) (dynamicWf v1alpha1.ExecutableWorkflow, status v1alpha1.ExecutableNodeStatus, isDynamic bool, err error) { + + t := e.metrics.buildDynamicWorkflow.Start(ctx) + defer t.Stop() + + var nStatus v1alpha1.ExecutableNodeStatus + // We will only get here if the Phase is success. The downside is that this is an overhead for all nodes that are + // not dynamic. But given that we will only check once, it should be ok. + // TODO: Check for node.is_dynamic once the IDL changes are in and SDK migration has happened. + djSpec, err := e.getDynamicJobSpec(ctx, node, previousNodeStatus) + if err != nil { + return nil, nil, false, err + } + + if djSpec == nil { + return nil, status, false, nil + } + + rootNodeStatus := w.GetNodeExecutionStatus(node.GetID()) + if node.GetTaskID() != nil { + // TODO: This is a hack to set parent task execution id, we should move to node-node relationship. 
+ execID, err := e.getTaskExecutionIdentifier(ctx, w, node) + if err != nil { + return nil, nil, false, err + } + + dynamicNode := &v1alpha1.NodeSpec{ + ID: "dynamic-node", + } + + nStatus = rootNodeStatus.GetNodeExecutionStatus(dynamicNode.GetID()) + nStatus.SetDataDir(rootNodeStatus.GetDataDir()) + nStatus.SetParentTaskID(execID) + } else { + nStatus = w.GetNodeExecutionStatus(node.GetID()) + } + + var closure *core.CompiledWorkflowClosure + wf, err := e.buildDynamicWorkflowTemplate(ctx, djSpec, w, node, nStatus) + if err != nil { + return nil, nil, true, err + } + + compiledTasks, err := compileTasks(ctx, djSpec.Tasks) + if err != nil { + return nil, nil, true, err + } + + // TODO: This will currently fail if the WF references any launch plans + closure, err = compiler.CompileWorkflow(wf, djSpec.Subworkflows, compiledTasks, []common2.InterfaceProvider{}) + if err != nil { + return nil, nil, true, err + } + + subwf, err := k8s.BuildFlyteWorkflow(closure, nil, nil, "") + if err != nil { + return nil, nil, true, err + } + + return newContextualWorkflow(w, subwf, nStatus, subwf.Tasks, subwf.SubWorkflows), nStatus, true, nil +} + +func (e dynamicNodeHandler) progressDynamicWorkflow(ctx context.Context, parentNodeStatus v1alpha1.ExecutableNodeStatus, + w v1alpha1.ExecutableWorkflow) (handler.Status, error) { + + state, err := e.nodeExecutor.RecursiveNodeHandler(ctx, w, w.StartNode()) + if err != nil { + return handler.StatusUndefined, err + } + + if state.HasFailed() { + if w.GetOnFailureNode() != nil { + return handler.StatusFailing(state.Err), nil + } + return handler.StatusFailed(state.Err), nil + } + + if state.IsComplete() { + nodeID := "" + if parentNodeStatus.GetParentNodeID() != nil { + nodeID = *parentNodeStatus.GetParentNodeID() + } + + // If the WF interface has outputs, validate that the outputs file was written. 
+ if outputBindings := w.GetOutputBindings(); len(outputBindings) > 0 { + endNodeStatus := w.GetNodeExecutionStatus(v1alpha1.EndNodeID) + if endNodeStatus == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, nodeID, + "No end node found in subworkflow.")), nil + } + + sourcePath := v1alpha1.GetOutputsFile(endNodeStatus.GetDataDir()) + if metadata, err := e.store.Head(ctx, sourcePath); err == nil { + if !metadata.Exists() { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, nodeID, + "Subworkflow is expected to produce outputs but no outputs file was written to %v.", + sourcePath)), nil + } + } else { + return handler.StatusUndefined, err + } + + destinationPath := v1alpha1.GetOutputsFile(parentNodeStatus.GetDataDir()) + if err := e.store.CopyRaw(ctx, sourcePath, destinationPath, storage.Options{}); err != nil { + return handler.StatusFailed(errors.Wrapf(errors.OutputsNotFoundError, nodeID, + err, "Failed to copy subworkflow outputs from [%v] to [%v]", + sourcePath, destinationPath)), nil + } + } + + return handler.StatusSuccess, nil + } + + if state.PartiallyComplete() { + // Re-enqueue the workflow + e.enQWorkflow(w.GetK8sWorkflowID().String()) + } + + return handler.StatusRunning, nil +} + +func (e dynamicNodeHandler) getTaskExecutionIdentifier(_ context.Context, w v1alpha1.ExecutableWorkflow, + node v1alpha1.ExecutableNode) (*core.TaskExecutionIdentifier, error) { + + taskID := node.GetTaskID() + task, err := w.GetTask(*taskID) + if err != nil { + return nil, errors.Wrapf(errors.BadSpecificationError, node.GetID(), err, "Unable to find task for taskId: [%v]", *taskID) + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + return &core.TaskExecutionIdentifier{ + TaskId: task.CoreTask().Id, + RetryAttempt: nodeStatus.GetAttempts(), + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: node.GetID(), + ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier, + }, + }, nil +} + +func (e 
dynamicNodeHandler) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + + previousNodeStatus := w.GetNodeExecutionStatus(node.GetID()) + switch previousNodeStatus.GetOrCreateDynamicNodeStatus().GetDynamicNodePhase() { + case v1alpha1.DynamicNodePhaseExecuting: + dynamicWF, _, isDynamic, err := e.buildContextualDynamicWorkflow(ctx, w, node, previousNodeStatus) + if err != nil { + return err + } + + if !isDynamic { + return nil + } + + return e.nodeExecutor.AbortHandler(ctx, dynamicWF, dynamicWF.StartNode()) + default: + // Invoke the underlying abort node. + return e.IFace.AbortNode(ctx, w, node) + } +} + +func New(underlying handler.IFace, nodeExecutor executors.Node, enQWorkflow v1alpha1.EnqueueWorkflow, store *storage.DataStore, + scope promutils.Scope) handler.IFace { + + return dynamicNodeHandler{ + IFace: underlying, + metrics: newMetrics(scope), + nodeExecutor: nodeExecutor, + enQWorkflow: enQWorkflow, + store: store, + } +} diff --git a/pkg/controller/nodes/dynamic/handler_test.go b/pkg/controller/nodes/dynamic/handler_test.go new file mode 100644 index 000000000..f96e157b9 --- /dev/null +++ b/pkg/controller/nodes/dynamic/handler_test.go @@ -0,0 +1,261 @@ +package dynamic + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "k8s.io/apimachinery/pkg/api/resource" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + pluginsV1 "github.com/lyft/flyteplugins/go/tasks/v1/types" + typesV1 "k8s.io/api/core/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + mocks2 
	"github.com/lyft/flytepropeller/pkg/controller/executors/mocks"
	"github.com/lyft/flytepropeller/pkg/controller/nodes/handler"
	"github.com/lyft/flytepropeller/pkg/controller/nodes/task"
)

// Shared fixtures: every workflow in this file stores node data under DataDir
// and uses a single node named NodeID.
const DataDir = storage.DataReference("test-data")
const NodeID = "n1"

var (
	enqueueWfFunc  = func(id string) {}
	fakeKubeClient = mocks2.NewFakeKubeClient()
)

// newIntegerPrimitive wraps an int64 in a core.Primitive.
func newIntegerPrimitive(value int64) *core.Primitive {
	return &core.Primitive{Value: &core.Primitive_Integer{Integer: value}}
}

// newScalarInteger wraps an int64 in a core.Scalar.
func newScalarInteger(value int64) *core.Scalar {
	return &core.Scalar{
		Value: &core.Scalar_Primitive{
			Primitive: newIntegerPrimitive(value),
		},
	}
}

// newIntegerLiteral wraps an int64 in a core.Literal.
func newIntegerLiteral(value int64) *core.Literal {
	return &core.Literal{
		Value: &core.Literal_Scalar{
			Scalar: newScalarInteger(value),
		},
	}
}

// createTask builds a TaskSpec with no inputs and a single integer output "out1".
func createTask(id string, ttype string, discoverable bool) *v1alpha1.TaskSpec {
	return &v1alpha1.TaskSpec{
		TaskTemplate: &core.TaskTemplate{
			Id:       &core.Identifier{Name: id},
			Type:     ttype,
			Metadata: &core.TaskMetadata{Discoverable: discoverable},
			Interface: &core.TypedInterface{
				Inputs: &core.VariableMap{},
				Outputs: &core.VariableMap{
					Variables: map[string]*core.Variable{
						"out1": &core.Variable{
							Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
						},
					},
				},
			},
		},
	}
}

// mockCatalogClient returns a catalog client whose Get/Put both succeed trivially.
func mockCatalogClient() catalog.Client {
	return &catalog.MockCatalogClient{
		GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) {
			return nil, nil
		},
		PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error {
			return nil
		},
	}
}

// createWf builds a workflow fixture whose single node status points at DataDir.
func createWf(id string, execID string, project string, domain string, name string) *v1alpha1.FlyteWorkflow {
	return &v1alpha1.FlyteWorkflow{
		ExecutionID: v1alpha1.WorkflowExecutionIdentifier{
			WorkflowExecutionIdentifier: &core.WorkflowExecutionIdentifier{
				Project: project,
				Domain:  domain,
				Name:    execID,
			},
		},
		Status: v1alpha1.WorkflowStatus{
			NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{
				NodeID: {
					DataDir: DataDir,
				},
			},
		},
		ObjectMeta: v1.ObjectMeta{
			Name: name,
		},
		WorkflowSpec: &v1alpha1.WorkflowSpec{
			ID: id,
		},
	}
}

// createStartNode builds a start-kind NodeSpec requesting 1 CPU.
func createStartNode() *v1alpha1.NodeSpec {
	return &v1alpha1.NodeSpec{
		ID:   NodeID,
		Kind: v1alpha1.NodeKindStart,
		Resources: &typesV1.ResourceRequirements{
			Requests: typesV1.ResourceList{
				typesV1.ResourceCPU: resource.MustParse("1"),
			},
		},
	}
}

// createInmemoryDataStore returns a memory-backed DataStore, failing the test on error.
func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore {
	cfg := storage.Config{
		Type: storage.TypeMemory,
	}
	d, err := storage.NewDataStore(&cfg, scope)
	assert.NoError(t, err)
	return d
}

// TestTaskHandler_CheckNodeStatusDiscovery drives CheckNodeStatus through the
// dynamic wrapper with a mocked task executor that always reports success, and
// varies the catalog/store contents to cover: a failing catalog write (ignored —
// node still succeeds), missing outputs (retryable failure), and the happy path.
func TestTaskHandler_CheckNodeStatusDiscovery(t *testing.T) {
	ctx := context.Background()

	taskID := "t1"
	tk := createTask(taskID, "container", true)
	tk.Id.Project = "flytekit"
	w := createWf("w1", "w2-exec", "projTest", "domainTest", "checkNodeTestName")
	w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{
		taskID: tk,
	}
	n := createStartNode()
	n.TaskRef = &taskID

	t.Run("TaskExecDoneDiscoveryWriteFail", func(t *testing.T) {
		taskExec := &mocks.Executor{}
		taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{})
		taskExec.On("CheckTaskStatus",
			ctx,
			mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }),
			mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }),
		).Return(pluginsV1.TaskStatusSucceeded, nil)
		d := &task.FactoryFuncs{
			GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) {
				if taskType == tk.Type {
					return taskExec, nil
				}
				return nil, fmt.Errorf("no match")
			},
		}
		// Catalog Put fails with DeadlineExceeded; discovery write errors must not
		// fail the node.
		mockCatalog := catalog.MockCatalogClient{
			GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) {
				return nil, nil
			},
			PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error {
				return status.Errorf(codes.DeadlineExceeded, "")
			},
		}
		store := createInmemoryDataStore(t, promutils.NewTestScope())
		paramsMap := make(map[string]*core.Literal)
		paramsMap["out1"] = newIntegerLiteral(1)
		err1 := store.WriteProtobuf(ctx, "test-data/inputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap})
		err2 := store.WriteProtobuf(ctx, "test-data/outputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap})
		assert.NoError(t, err1)
		assert.NoError(t, err2)

		th := New(
			task.NewTaskHandlerForFactory(events.NewMockEventSink(), store, enqueueWfFunc, d, &mockCatalog, fakeKubeClient, promutils.NewTestScope()),
			nil,
			enqueueWfFunc,
			store,
			promutils.NewTestScope(),
		)

		prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}
		s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus)
		assert.NoError(t, err)
		assert.Equal(t, handler.StatusSuccess, s)
	})

	t.Run("TaskExecDoneDiscoveryMissingOutputs", func(t *testing.T) {
		taskExec := &mocks.Executor{}
		taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{})
		taskExec.On("CheckTaskStatus",
			ctx,
			mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }),
			mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }),
		).Return(pluginsV1.TaskStatusSucceeded, nil)
		d := &task.FactoryFuncs{
			GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) {
				if taskType == tk.Type {
					return taskExec, nil
				}
				return nil, fmt.Errorf("no match")
			},
		}
		// No outputs.pb is written: the declared output "out1" is missing, which is
		// surfaced as a retryable failure.
		store := createInmemoryDataStore(t, promutils.NewTestScope())
		th := New(
			task.NewTaskHandlerForFactory(events.NewMockEventSink(), store, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()),
			nil,
			enqueueWfFunc,
			store,
			promutils.NewTestScope(),
		)

		prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}
		s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus)
		assert.NoError(t, err)
		assert.Equal(t, handler.PhaseRetryableFailure, s.Phase, "received: %s", s.Phase.String())
	})

	t.Run("TaskExecDoneDiscoveryWriteSuccess", func(t *testing.T) {
		taskExec := &mocks.Executor{}
		taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{})
		taskExec.On("CheckTaskStatus",
			ctx,
			mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }),
			mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }),
		).Return(pluginsV1.TaskStatusSucceeded, nil)
		d := &task.FactoryFuncs{
			GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) {
				if taskType == tk.Type {
					return taskExec, nil
				}
				return nil, fmt.Errorf("no match")
			},
		}
		store := createInmemoryDataStore(t, promutils.NewTestScope())
		paramsMap := make(map[string]*core.Literal)
		paramsMap["out1"] = newIntegerLiteral(100)
		err1 := store.WriteProtobuf(ctx, "test-data/inputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap})
		err2 := store.WriteProtobuf(ctx, "test-data/outputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap})
		assert.NoError(t, err1)
		assert.NoError(t, err2)
		th := New(
			task.NewTaskHandlerForFactory(events.NewMockEventSink(), store, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()),
			nil,
			enqueueWfFunc,
			store,
			promutils.NewTestScope(),
		)

		prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}
		s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus)
		assert.NoError(t, err)
		assert.Equal(t, handler.StatusSuccess, s)
	})
}
package dynamic

import (
	"context"

	"github.com/lyft/flytepropeller/pkg/controller/executors"

	"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
	"github.com/lyft/flytestdlib/storage"
)

// Defines a sub-contextual workflow that is built in-memory to represent a dynamic job execution plan.
type contextualWorkflow struct {
	v1alpha1.ExecutableWorkflow

	// extraTasks/extraWorkflows hold definitions generated by the dynamic job spec;
	// lookups fall through to the embedded base workflow when a key is absent.
	extraTasks     map[v1alpha1.TaskID]*v1alpha1.TaskSpec
	extraWorkflows map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec
	status         *ContextualWorkflowStatus
}

// newContextualWorkflow layers the generated sub-workflow (and its extra
// task/workflow definitions) on top of baseWorkflow, rooting the execution status
// at the given node status.
func newContextualWorkflow(baseWorkflow v1alpha1.ExecutableWorkflow,
	subwf v1alpha1.ExecutableSubWorkflow,
	status v1alpha1.ExecutableNodeStatus,
	tasks map[v1alpha1.TaskID]*v1alpha1.TaskSpec,
	workflows map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec) v1alpha1.ExecutableWorkflow {

	return &contextualWorkflow{
		ExecutableWorkflow: executors.NewSubContextualWorkflow(baseWorkflow, subwf, status),
		extraTasks:         tasks,
		extraWorkflows:     workflows,
		status:             newContextualWorkflowStatus(baseWorkflow.GetExecutionStatus(), status),
	}
}

// GetExecutionStatus returns the contextual (overridden) workflow status.
func (w contextualWorkflow) GetExecutionStatus() v1alpha1.ExecutableWorkflowStatus {
	return w.status
}

// GetTask prefers tasks generated by the dynamic spec, falling back to the base workflow.
func (w contextualWorkflow) GetTask(id v1alpha1.TaskID) (v1alpha1.ExecutableTask, error) {
	if task, found := w.extraTasks[id]; found {
		return task, nil
	}

	return w.ExecutableWorkflow.GetTask(id)
}

// FindSubWorkflow prefers sub-workflows generated by the dynamic spec, falling back
// to the base workflow.
func (w contextualWorkflow) FindSubWorkflow(id v1alpha1.WorkflowID) v1alpha1.ExecutableSubWorkflow {
	if wf, found := w.extraWorkflows[id]; found {
		return wf
	}

	return w.ExecutableWorkflow.FindSubWorkflow(id)
}

// A contextual workflow status to override some of the implementations.
type ContextualWorkflowStatus struct {
	v1alpha1.ExecutableWorkflowStatus
	// baseStatus is the parent node's status; its data dir masks the workflow's own.
	baseStatus v1alpha1.ExecutableNodeStatus
}

// GetDataDir returns the parent node's data dir rather than the workflow's.
func (w ContextualWorkflowStatus) GetDataDir() v1alpha1.DataReference {
	return w.baseStatus.GetDataDir()
}

// Overrides default node data dir to work around the contractual assumption between Propeller and Futures to write all
// sub-node inputs into current node data directory.
// E.g.
// if current node data dir is /wf_exec/node-1/data/
// and the task ran and yielded 2 nodes, the structure will look like this:
// /wf_exec/node-1/data/
//
//	|_ inputs.pb
//	|_ futures.pb
//	|_ sub-node1/inputs.pb
//	|_ sub-node2/inputs.pb
//
// TODO: This is just a stop-gap until we transition the DynamicJobSpec to be a full-fledged workflow spec.
// TODO: this will allow us to have proper data bindings between nodes then we can stop making assumptions about data refs.
func (w ContextualWorkflowStatus) ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor,
	name v1alpha1.NodeID) (storage.DataReference, error) {
	return constructor.ConstructReference(ctx, w.GetDataDir(), name)
}

// newContextualWorkflowStatus wraps the base workflow status so that data-dir
// resolution is rooted at baseStatus.
func newContextualWorkflowStatus(baseWfStatus v1alpha1.ExecutableWorkflowStatus,
	baseStatus v1alpha1.ExecutableNodeStatus) *ContextualWorkflowStatus {

	return &ContextualWorkflowStatus{
		ExecutableWorkflowStatus: baseWfStatus,
		baseStatus:               baseStatus,
	}
}

package dynamic

import (
	"context"
	"testing"

	"github.com/lyft/flytestdlib/promutils"
	"github.com/lyft/flytestdlib/storage"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"

	"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks"
)

func TestNewContextualWorkflow(t
*testing.T) { + wf := &mocks.ExecutableWorkflow{} + calledBase := false + wf.On("GetAnnotations").Return(map[string]string{}).Run(func(_ mock.Arguments) { + calledBase = true + }) + + wf.On("GetExecutionStatus").Return(&mocks.ExecutableWorkflowStatus{}) + + subwf := &mocks.ExecutableSubWorkflow{} + cWF := newContextualWorkflow(wf, subwf, nil, nil, nil) + cWF.GetAnnotations() + + assert.True(t, calledBase) +} + +func TestConstructNodeDataDir(t *testing.T) { + wf := &mocks.ExecutableWorkflow{} + wf.On("GetExecutionStatus").Return(&mocks.ExecutableWorkflowStatus{}) + + wfStatus := &mocks.ExecutableWorkflowStatus{} + wfStatus.On("GetDataDir").Return(storage.DataReference("fk://wrong/")).Run(func(_ mock.Arguments) { + assert.FailNow(t, "Should call the override") + }) + + nodeStatus := &mocks.ExecutableNodeStatus{} + nodeStatus.On("GetDataDir").Return(storage.DataReference("fk://right/")) + + ds, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + cWF := newContextualWorkflowStatus(wfStatus, nodeStatus) + + dataDir, err := cWF.ConstructNodeDataDir(context.TODO(), ds, "my_node") + assert.NoError(t, err) + assert.NotNil(t, dataDir) + assert.Equal(t, "fk://right/my_node", dataDir.String()) +} diff --git a/pkg/controller/nodes/dynamic/utils.go b/pkg/controller/nodes/dynamic/utils.go new file mode 100644 index 000000000..42c9e081c --- /dev/null +++ b/pkg/controller/nodes/dynamic/utils.go @@ -0,0 +1,105 @@ +package dynamic + +import ( + "context" + + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/compiler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/utils" +) + +// Constructs the expected interface of a given node. 
+func underlyingInterface(w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (*core.TypedInterface, error) { + iface := &core.TypedInterface{} + if node.GetTaskID() != nil { + t, err := w.GetTask(*node.GetTaskID()) + if err != nil { + // Should never happen + return nil, err + } + + iface.Outputs = t.CoreTask().GetInterface().Outputs + } else if wfNode := node.GetWorkflowNode(); wfNode != nil { + if wfRef := wfNode.GetSubWorkflowRef(); wfRef != nil { + t := w.FindSubWorkflow(*wfRef) + if t == nil { + // Should never happen + return nil, errors.Errorf(errors.IllegalStateError, node.GetID(), "Couldn't find subworkflow [%v].", wfRef) + } + + iface.Outputs = t.GetOutputs().VariableMap + } else { + return nil, errors.Errorf(errors.IllegalStateError, node.GetID(), "Unknown interface") + } + } else if node.GetBranchNode() != nil { + if ifBlock := node.GetBranchNode().GetIf(); ifBlock != nil && ifBlock.GetThenNode() != nil { + bn, found := w.GetNode(*ifBlock.GetThenNode()) + if !found { + return nil, errors.Errorf(errors.IllegalStateError, node.GetID(), "Couldn't find branch node [%v]", + *ifBlock.GetThenNode()) + } + + return underlyingInterface(w, bn) + } + + return nil, errors.Errorf(errors.IllegalStateError, node.GetID(), "Empty branch detected.") + } else { + return nil, errors.Errorf(errors.IllegalStateError, node.GetID(), "Unknown interface.") + } + + return iface, nil +} + +func hierarchicalNodeID(parentNodeID, nodeID string) (string, error) { + return utils.FixedLengthUniqueIDForParts(20, parentNodeID, nodeID) +} + +func updateBindingNodeIDsWithLineage(parentNodeID string, binding *core.BindingData) (err error) { + switch b := binding.Value.(type) { + case *core.BindingData_Promise: + b.Promise.NodeId, err = hierarchicalNodeID(parentNodeID, b.Promise.NodeId) + if err != nil { + return err + } + case *core.BindingData_Collection: + for _, item := range b.Collection.Bindings { + err = updateBindingNodeIDsWithLineage(parentNodeID, item) + if err != nil { + 
return err + } + } + case *core.BindingData_Map: + for _, item := range b.Map.Bindings { + err = updateBindingNodeIDsWithLineage(parentNodeID, item) + if err != nil { + return err + } + } + } + + return nil +} + +func compileTasks(_ context.Context, tasks []*core.TaskTemplate) ([]*core.CompiledTask, error) { + compiledTasks := make([]*core.CompiledTask, 0, len(tasks)) + visitedTasks := sets.NewString() + for _, t := range tasks { + if visitedTasks.Has(t.Id.String()) { + continue + } + + ct, err := compiler.CompileTask(t) + if err != nil { + return nil, err + } + + compiledTasks = append(compiledTasks, ct) + visitedTasks.Insert(t.Id.String()) + } + + return compiledTasks, nil +} diff --git a/pkg/controller/nodes/dynamic/utils_test.go b/pkg/controller/nodes/dynamic/utils_test.go new file mode 100644 index 000000000..e5df771a6 --- /dev/null +++ b/pkg/controller/nodes/dynamic/utils_test.go @@ -0,0 +1,77 @@ +package dynamic + +import ( + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/stretchr/testify/mock" + + "github.com/stretchr/testify/assert" +) + +func TestHierarchicalNodeID(t *testing.T) { + t.Run("empty parent", func(t *testing.T) { + actual, err := hierarchicalNodeID("", "abc") + assert.NoError(t, err) + assert.Equal(t, "-abc", actual) + }) + + t.Run("long result", func(t *testing.T) { + actual, err := hierarchicalNodeID("abcdefghijklmnopqrstuvwxyz", "abc") + assert.NoError(t, err) + assert.Equal(t, "fpa3kc3y", actual) + }) +} + +func TestUnderlyingInterface(t *testing.T) { + expectedIface := &core.TypedInterface{ + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "in": { + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + }, + }, + }, + } + wf := &mocks.ExecutableWorkflow{} + + subWF := 
&mocks.ExecutableSubWorkflow{} + wf.On("FindSubWorkflow", mock.Anything).Return(subWF) + subWF.On("GetOutputs").Return(&v1alpha1.OutputVarMap{VariableMap: expectedIface.Outputs}) + + task := &mocks.ExecutableTask{} + wf.On("GetTask", mock.Anything).Return(task, nil) + task.On("CoreTask").Return(&core.TaskTemplate{ + Interface: expectedIface, + }) + + n := &mocks.ExecutableNode{} + wf.On("GetNode", mock.Anything).Return(n) + emptyStr := "" + n.On("GetTaskID").Return(&emptyStr) + + iface, err := underlyingInterface(wf, n) + assert.NoError(t, err) + assert.NotNil(t, iface) + assert.Equal(t, expectedIface, iface) + + n = &mocks.ExecutableNode{} + n.On("GetTaskID").Return(nil) + + wfNode := &mocks.ExecutableWorkflowNode{} + n.On("GetWorkflowNode").Return(wfNode) + wfNode.On("GetSubWorkflowRef").Return(&emptyStr) + + iface, err = underlyingInterface(wf, n) + assert.NoError(t, err) + assert.NotNil(t, iface) + assert.Equal(t, expectedIface, iface) +} diff --git a/pkg/controller/nodes/end/handler.go b/pkg/controller/nodes/end/handler.go new file mode 100644 index 000000000..5c53d9de1 --- /dev/null +++ b/pkg/controller/nodes/end/handler.go @@ -0,0 +1,52 @@ +package end + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" +) + +type endHandler struct { + store storage.ProtobufStore +} + +func (e *endHandler) Initialize(ctx context.Context) error { + return nil +} + +func (e *endHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + if nodeInputs != nil { + logger.Debugf(ctx, "Workflow has outputs. 
Storing them.") + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + o := v1alpha1.GetOutputsFile(nodeStatus.GetDataDir()) + so := storage.Options{} + if err := e.store.WriteProtobuf(ctx, o, so, nodeInputs); err != nil { + logger.Errorf(ctx, "Failed to store workflow outputs. Error [%s]", err) + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to store workflow outputs, as end-node") + } + } + logger.Debugf(ctx, "End node success") + return handler.StatusSuccess, nil +} + +func (e *endHandler) CheckNodeStatus(ctx context.Context, g v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + return handler.StatusSuccess, nil +} + +func (e *endHandler) HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, node.GetID(), "End node cannot enter a failing state")), nil +} + +func (e *endHandler) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + return nil +} + +func New(store storage.ProtobufStore) handler.IFace { + return &endHandler{ + store: store, + } +} diff --git a/pkg/controller/nodes/end/handler_test.go b/pkg/controller/nodes/end/handler_test.go new file mode 100644 index 000000000..a5ab02c4b --- /dev/null +++ b/pkg/controller/nodes/end/handler_test.go @@ -0,0 +1,135 @@ +package end + +import ( + "context" + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/utils" + flyteassert 
"github.com/lyft/flytepropeller/pkg/utils/assert" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + regErrors "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +var testScope = promutils.NewScope("end_test") + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func init() { + labeled.SetMetricKeys(contextutils.NodeIDKey) +} + +type TestProtoDataStore struct { + ReadProtobufCb func(ctx context.Context, reference storage.DataReference, msg proto.Message) error + WriteProtobufCb func(ctx context.Context, reference storage.DataReference, opts storage.Options, msg proto.Message) error +} + +func (t TestProtoDataStore) ReadProtobuf(ctx context.Context, reference storage.DataReference, msg proto.Message) error { + return t.ReadProtobufCb(ctx, reference, msg) +} + +func (t TestProtoDataStore) WriteProtobuf(ctx context.Context, reference storage.DataReference, opts storage.Options, msg proto.Message) error { + return t.WriteProtobufCb(ctx, reference, opts, msg) +} + +func TestEndHandler_CheckNodeStatus(t *testing.T) { + e := endHandler{} + s, err := e.CheckNodeStatus(context.TODO(), nil, nil, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) +} + +func TestEndHandler_HandleFailingNode(t *testing.T) { + e := endHandler{} + node := &v1alpha1.NodeSpec{ + ID: v1alpha1.EndNodeID, + } + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: v1alpha1.WorkflowID("w1"), + }, + } + s, err := e.HandleFailingNode(context.TODO(), w, node) + assert.NoError(t, err) + assert.Equal(t, errors.IllegalStateError, s.Err.(*errors.NodeError).Code) +} + +func TestEndHandler_Initialize(t *testing.T) { + e := endHandler{} + assert.NoError(t, e.Initialize(context.TODO())) +} + +func TestEndHandler_StartNode(t *testing.T) { + 
inMem := createInmemoryDataStore(t, testScope.NewSubScope("x")) + e := New(inMem) + ctx := context.Background() + + inputs := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral("hello"), + "y": utils.MustMakePrimitiveLiteral("blah"), + }, + } + + outputRef := v1alpha1.DataReference("testRef") + + node := &v1alpha1.NodeSpec{ + ID: v1alpha1.EndNodeID, + } + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: v1alpha1.WorkflowID("w1"), + }, + } + w.Status.NodeStatus = map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + v1alpha1.EndNodeID: { + DataDir: outputRef, + }, + } + + t.Run("NoInputs", func(t *testing.T) { + s, err := e.StartNode(ctx, w, node, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) + + outputLoc := v1alpha1.GetOutputsFile(outputRef) + t.Run("WithInputs", func(t *testing.T) { + s, err := e.StartNode(ctx, w, node, inputs) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + actual := &core.LiteralMap{} + if assert.NoError(t, inMem.ReadProtobuf(ctx, outputLoc, actual)) { + flyteassert.EqualLiteralMap(t, inputs, actual) + } + }) + + t.Run("StoreFailure", func(t *testing.T) { + store := &TestProtoDataStore{ + WriteProtobufCb: func(ctx context.Context, reference v1alpha1.DataReference, opts storage.Options, msg proto.Message) error { + return regErrors.Errorf("Fail") + }, + } + e := New(store) + s, err := e.StartNode(ctx, w, node, inputs) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.CausedByError)) + assert.Equal(t, handler.StatusUndefined, s) + }) +} diff --git a/pkg/controller/nodes/errors/codes.go b/pkg/controller/nodes/errors/codes.go new file mode 100644 index 000000000..0d0b3da09 --- /dev/null +++ b/pkg/controller/nodes/errors/codes.go @@ -0,0 +1,26 @@ +package errors + +type ErrorCode string + +const ( + NotYetImplementedError ErrorCode = "NotYetImplementedError" + DownstreamNodeNotFoundError ErrorCode = 
"DownstreamNodeNotFound" + UserProvidedError ErrorCode = "UserProvidedError" + IllegalStateError ErrorCode = "IllegalStateError" + BadSpecificationError ErrorCode = "BadSpecificationError" + UnsupportedTaskTypeError ErrorCode = "UnsupportedTaskType" + BindingResolutionError ErrorCode = "BindingResolutionError" + CausedByError ErrorCode = "CausedByError" + RuntimeExecutionError ErrorCode = "RuntimeExecutionError" + SubWorkflowExecutionFailed ErrorCode = "SubWorkflowExecutionFailed" + RemoteChildWorkflowExecutionFailed ErrorCode = "RemoteChildWorkflowExecutionFailed" + NoBranchTakenError ErrorCode = "NoBranchTakenError" + OutputsNotFoundError ErrorCode = "OutputsNotFoundError" + StorageError ErrorCode = "StorageError" + EventRecordingFailed ErrorCode = "EventRecordingFailed" + CatalogCallFailed ErrorCode = "CatalogCallFailed" +) + +func (e ErrorCode) String() string { + return string(e) +} diff --git a/pkg/controller/nodes/errors/errors.go b/pkg/controller/nodes/errors/errors.go new file mode 100644 index 000000000..05c096ddf --- /dev/null +++ b/pkg/controller/nodes/errors/errors.go @@ -0,0 +1,80 @@ +package errors + +import ( + "fmt" + + "github.com/pkg/errors" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type ErrorMessage = string + +type NodeError struct { + errors.StackTrace + Code ErrorCode + Message ErrorMessage + Node v1alpha1.NodeID +} + +func (n *NodeError) Error() string { + return fmt.Sprintf("failed at Node[%s]. 
%v: %v", n.Node, n.Code, n.Message) +} + +type NodeErrorWithCause struct { + *NodeError + cause error +} + +func (n *NodeErrorWithCause) Error() string { + return fmt.Sprintf("%v, caused by: %v", n.NodeError.Error(), errors.Cause(n)) +} + +func (n *NodeErrorWithCause) Cause() error { + return n.cause +} + +func errorf(c ErrorCode, n v1alpha1.NodeID, msgFmt string, args ...interface{}) *NodeError { + return &NodeError{ + Code: c, + Node: n, + Message: fmt.Sprintf(msgFmt, args...), + } +} + +func Errorf(c ErrorCode, n v1alpha1.NodeID, msgFmt string, args ...interface{}) error { + return errorf(c, n, msgFmt, args...) +} + +func Wrapf(c ErrorCode, n v1alpha1.NodeID, cause error, msgFmt string, args ...interface{}) error { + return &NodeErrorWithCause{ + NodeError: errorf(c, n, msgFmt, args...), + cause: cause, + } +} + +func Matches(err error, code ErrorCode) bool { + errCode, isNodeError := GetErrorCode(err) + if isNodeError { + return code == errCode + } + return false +} + +func GetErrorCode(err error) (code ErrorCode, isNodeError bool) { + isNodeError = false + e, ok := err.(*NodeError) + if ok { + code = e.Code + isNodeError = true + return + } + + e2, ok := err.(*NodeErrorWithCause) + if ok { + code = e2.Code + isNodeError = true + return + } + return +} diff --git a/pkg/controller/nodes/errors/errors_test.go b/pkg/controller/nodes/errors/errors_test.go new file mode 100644 index 000000000..2386e9205 --- /dev/null +++ b/pkg/controller/nodes/errors/errors_test.go @@ -0,0 +1,48 @@ +package errors + +import ( + "fmt" + "testing" + + extErrors "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestErrorf(t *testing.T) { + msg := "msg" + err := Errorf(IllegalStateError, "n1", "Message [%v]", msg) + assert.NotNil(t, err) + e := err.(*NodeError) + assert.Equal(t, IllegalStateError, e.Code) + assert.Equal(t, "n1", e.Node) + assert.Equal(t, fmt.Sprintf("Message [%v]", msg), e.Message) + assert.Equal(t, err, extErrors.Cause(e)) + assert.Equal(t, 
"failed at Node[n1]. IllegalStateError: Message [msg]", err.Error()) +} + +func TestErrorfWithCause(t *testing.T) { + cause := extErrors.Errorf("Some Error") + msg := "msg" + err := Wrapf(IllegalStateError, "n1", cause, "Message [%v]", msg) + assert.NotNil(t, err) + e := err.(*NodeErrorWithCause) + assert.Equal(t, IllegalStateError, e.Code) + assert.Equal(t, "n1", e.Node) + assert.Equal(t, fmt.Sprintf("Message [%v]", msg), e.Message) + assert.Equal(t, cause, extErrors.Cause(e)) + assert.Equal(t, "failed at Node[n1]. IllegalStateError: Message [msg], caused by: Some Error", err.Error()) +} + +func TestMatches(t *testing.T) { + err := Errorf(IllegalStateError, "n1", "Message ") + assert.True(t, Matches(err, IllegalStateError)) + assert.False(t, Matches(err, BadSpecificationError)) + + cause := extErrors.Errorf("Some Error") + err = Wrapf(IllegalStateError, "n1", cause, "Message ") + assert.True(t, Matches(err, IllegalStateError)) + assert.False(t, Matches(err, BadSpecificationError)) + + assert.False(t, Matches(cause, IllegalStateError)) + assert.False(t, Matches(cause, BadSpecificationError)) +} diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go new file mode 100644 index 000000000..7d0ea5ee3 --- /dev/null +++ b/pkg/controller/nodes/executor.go @@ -0,0 +1,540 @@ +package nodes + +import ( + "context" + "fmt" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/clients/go/events" + eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + 
"github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/utils" +) + +type nodeMetrics struct { + FailureDuration labeled.StopWatch + SuccessDuration labeled.StopWatch + ResolutionFailure labeled.Counter + InputsWriteFailure labeled.Counter + + // Measures the latency between the last parent node stoppedAt time and current node's queued time. + TransitionLatency labeled.StopWatch + // Measures the latency between the time a node's been queued to the time the handler reported the executable moved + // to running state + QueuingLatency labeled.StopWatch +} + +type nodeExecutor struct { + nodeHandlerFactory HandlerFactory + enqueueWorkflow v1alpha1.EnqueueWorkflow + store *storage.DataStore + nodeRecorder events.NodeEventRecorder + metrics *nodeMetrics +} + +// In this method we check if the queue is ready to be processed and if so, we prime it in Admin as queued +// Before we start the node execution, we need to transition this Node status to Queued. +// This is because a node execution has to exist before task/wf executions can start. +func (c *nodeExecutor) queueNodeIfReady(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.BaseNode, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + logger.Debugf(ctx, "Node not yet started") + // Query the nodes information to figure out if it can be executed. + predicatePhase, err := CanExecute(ctx, w, node) + if err != nil { + logger.Debugf(ctx, "Node failed in CanExecute. Error [%s]", err) + return handler.StatusUndefined, err + } + if predicatePhase == PredicatePhaseSkip { + logger.Debugf(ctx, "Node upstream node was skipped. 
Skipping!") + return handler.StatusSkipped, nil + } else if predicatePhase == PredicatePhaseNotReady { + logger.Debugf(ctx, "Node not ready for executing.") + return handler.StatusNotStarted, nil + } + + if len(nodeStatus.GetDataDir()) == 0 { + // Predicate ready, lets Resolve the data + dataDir, err := w.GetExecutionStatus().ConstructNodeDataDir(ctx, c.store, node.GetID()) + if err != nil { + return handler.StatusUndefined, err + } + + nodeStatus.SetDataDir(dataDir) + } + + return handler.StatusQueued, nil +} + +func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) { + if nodeStatus.GetPhase() == v1alpha1.NodePhaseNotYetStarted || nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued { + // Log transition latency (The most recently finished parent node endAt time to this node's queuedAt time -now-) + t, err := GetParentNodeMaxEndTime(ctx, w, node) + if err != nil { + logger.Warnf(ctx, "Failed to record transition latency for node. Error: %s", err.Error()) + return + } + if !t.IsZero() { + c.metrics.TransitionLatency.Observe(ctx, t.Time, time.Now()) + } + } else if nodeStatus.GetPhase() == v1alpha1.NodePhaseRetryableFailure && nodeStatus.GetLastUpdatedAt() != nil { + c.metrics.TransitionLatency.Observe(ctx, nodeStatus.GetLastUpdatedAt().Time, time.Now()) + } +} + +// Start the node execution. This implies that the node will start processing +func (c *nodeExecutor) startNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, h handler.IFace) (handler.Status, error) { + + // TODO: Performance problem, we may be in a retry loop and do not need to resolve the inputs again. + // For now we will do this. 
+ dataDir := nodeStatus.GetDataDir() + var nodeInputs *handler.Data + if !node.IsStartNode() { + // Can execute + var err error + nodeInputs, err = Resolve(ctx, c.nodeHandlerFactory, w, node.GetID(), node.GetInputBindings(), c.store) + // TODO we need to handle retryable, network errors here!! + if err != nil { + c.metrics.ResolutionFailure.Inc(ctx) + logger.Warningf(ctx, "Failed to resolve inputs for Node. Error [%v]", err) + return handler.StatusFailed(err), nil + } + + if nodeInputs != nil { + inputsFile := v1alpha1.GetInputsFile(dataDir) + if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { + c.metrics.InputsWriteFailure.Inc(ctx) + logger.Errorf(ctx, "Failed to store inputs for Node. Error [%v]. InputsFile [%s]", err, inputsFile) + return handler.StatusUndefined, errors.Wrapf( + errors.StorageError, node.GetID(), err, "Failed to store inputs for Node. InputsFile [%s]", inputsFile) + } + } + + logger.Debugf(ctx, "Node Data Directory [%s].", nodeStatus.GetDataDir()) + } + + // Now that we have resolved the inputs, we can record as a transition latency. This is because we have completed + // all the overhead that we have to compute. Any failures after this will incur this penalty, but it could be due + // to various external reasons - like queuing, overuse of quota, plugin overhead etc. 
+ c.RecordTransitionLatency(ctx, w, node, nodeStatus) + + // Start node execution + return h.StartNode(ctx, w, node, nodeInputs) +} + +func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + logger.Debugf(ctx, "Handling Node [%s]", node.GetID()) + defer logger.Debugf(ctx, "Completed node [%s]", node.GetID()) + // Now depending on the node type decide + h, err := c.nodeHandlerFactory.GetHandler(node.GetKind()) + if err != nil { + return handler.StatusUndefined, err + } + + // Important to note that we have special optimization for start node only (not end node) + // We specifically ignore queueing of start node and directly move the start node to "starting" + // This prevents an extra event to Admin and an extra write to etcD. This is also because of the fact that start node does not have tasks and do not need to send out task events. + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + var status handler.Status + if !node.IsStartNode() && !node.IsEndNode() && nodeStatus.GetPhase() == v1alpha1.NodePhaseNotYetStarted { + // We only send the queued event to Admin in case the node was never started and when it is not either StartNode. + // We do not do this for endNode, because endNode may still be not executable. This is because StartNode + // completes as soon as started. 
+ return c.queueNodeIfReady(ctx, w, node, nodeStatus) + } else if node.IsEndNode() { + status, err = c.queueNodeIfReady(ctx, w, node, nodeStatus) + if err == nil && status.Phase == handler.PhaseQueued { + status, err = c.startNode(ctx, w, node, nodeStatus, h) + } + } else if node.IsStartNode() || nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued || + nodeStatus.GetPhase() == v1alpha1.NodePhaseRetryableFailure { + // If the node is either StartNode or was previously queued or failed in a previous attempt, we will call + // the start method on the node handler + status, err = c.startNode(ctx, w, node, nodeStatus, h) + } else if nodeStatus.GetPhase() == v1alpha1.NodePhaseFailing { + status, err = h.HandleFailingNode(ctx, w, node) + } else { + status, err = h.CheckNodeStatus(ctx, w, node, nodeStatus) + } + + return status, err +} + +func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent) error { + err := c.nodeRecorder.RecordNodeEvent(ctx, nodeEvent) + // TODO: add unit tests for this specific path + if err != nil && eventsErr.IsAlreadyExists(err) { + logger.Infof(ctx, "Node event phase: %s, nodeId %s already exist", + nodeEvent.Phase.String(), nodeEvent.GetId().NodeId) + return nil + } + return err +} + +func (c *nodeExecutor) TransitionToPhase(ctx context.Context, execID *core.WorkflowExecutionIdentifier, + node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, toStatus handler.Status) (executors.NodeStatus, error) { + + previousNodePhase := nodeStatus.GetPhase() + // TODO GC analysis. We will create a ton of node-events but never publish them. 
We could first check for the PhaseChange and if so then do this processing + + nodeEvent := &event.NodeExecutionEvent{ + Id: &core.NodeExecutionIdentifier{ + NodeId: node.GetID(), + ExecutionId: execID, + }, + InputUri: v1alpha1.GetInputsFile(nodeStatus.GetDataDir()).String(), + } + + var returnStatus executors.NodeStatus + errMsg := "" + errCode := "NodeExecutionUnknownError" + if toStatus.Err != nil { + errMsg = toStatus.Err.Error() + code, ok := errors.GetErrorCode(toStatus.Err) + if ok { + errCode = code.String() + } + } + + // If there is a child workflow, include the execution of the child workflow in the event + if nodeStatus.GetWorkflowNodeStatus() != nil { + nodeEvent.TargetMetadata = &event.NodeExecutionEvent_WorkflowNodeMetadata{ + WorkflowNodeMetadata: &event.WorkflowNodeMetadata{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: execID.Project, + Domain: execID.Domain, + Name: nodeStatus.GetWorkflowNodeStatus().GetWorkflowExecutionName(), + }, + }, + } + } + + switch toStatus.Phase { + case handler.PhaseNotStarted: + return executors.NodeStatusPending, nil + // TODO we should not need handler.PhaseQueued since we have added Task StateMachine. Remove it. 
+ case handler.PhaseQueued: + nodeEvent.Phase = core.NodeExecution_QUEUED + nodeStatus.UpdatePhase(v1alpha1.NodePhaseQueued, v1.NewTime(toStatus.OccurredAt), "") + + returnStatus = executors.NodeStatusQueued + if !toStatus.OccurredAt.IsZero() { + nodeEvent.OccurredAt = utils.GetProtoTime(&v1.Time{Time: toStatus.OccurredAt}) + } else { + nodeEvent.OccurredAt = ptypes.TimestampNow() // TODO: add queueAt in nodeStatus + } + + case handler.PhaseRunning: + nodeEvent.Phase = core.NodeExecution_RUNNING + nodeStatus.UpdatePhase(v1alpha1.NodePhaseRunning, v1.NewTime(toStatus.OccurredAt), "") + nodeEvent.OccurredAt = utils.GetProtoTime(nodeStatus.GetStartedAt()) + returnStatus = executors.NodeStatusRunning + + if nodeStatus.GetQueuedAt() != nil && nodeStatus.GetStartedAt() != nil { + c.metrics.QueuingLatency.Observe(ctx, nodeStatus.GetQueuedAt().Time, nodeStatus.GetStartedAt().Time) + } + case handler.PhaseRetryableFailure: + maxAttempts := uint32(0) + if node.GetRetryStrategy() != nil && node.GetRetryStrategy().MinAttempts != nil { + maxAttempts = uint32(*node.GetRetryStrategy().MinAttempts) + } + + nodeEvent.OutputResult = &event.NodeExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: errCode, + Message: fmt.Sprintf("Retries [%d/%d], %s", nodeStatus.GetAttempts(), maxAttempts, errMsg), + ErrorUri: v1alpha1.GetOutputErrorFile(nodeStatus.GetDataDir()).String(), + }, + } + + if nodeStatus.IncrementAttempts() >= maxAttempts { + logger.Debugf(ctx, "All retries have exhausted, failing node. 
[%d/%d]", nodeStatus.GetAttempts(), maxAttempts) + // Failure + nodeEvent.Phase = core.NodeExecution_FAILED + nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, v1.NewTime(toStatus.OccurredAt), errMsg) + nodeEvent.OccurredAt = utils.GetProtoTime(nodeStatus.GetStoppedAt()) + returnStatus = executors.NodeStatusFailed(toStatus.Err) + c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + } else { + // retry + // TODO add a nodeEvent of retryableFailure (it is not a terminal event). + // For now, we don't send an event for node retryable failures. + nodeEvent = nil + nodeStatus.UpdatePhase(v1alpha1.NodePhaseRetryableFailure, v1.NewTime(toStatus.OccurredAt), errMsg) + returnStatus = executors.NodeStatusRunning + + // Reset all executors' state to start a fresh attempt. + nodeStatus.ClearTaskStatus() + nodeStatus.ClearWorkflowStatus() + nodeStatus.ClearDynamicNodeStatus() + + // Required for transition (backwards compatibility) + if nodeStatus.GetLastUpdatedAt() != nil { + c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetLastUpdatedAt().Time) + } + } + + case handler.PhaseSkipped: + nodeEvent.Phase = core.NodeExecution_SKIPPED + nodeStatus.UpdatePhase(v1alpha1.NodePhaseSkipped, v1.NewTime(toStatus.OccurredAt), "") + nodeEvent.OccurredAt = utils.GetProtoTime(nodeStatus.GetStoppedAt()) + returnStatus = executors.NodeStatusSuccess + + case handler.PhaseSucceeding: + nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeding, v1.NewTime(toStatus.OccurredAt), "") + // Currently we do not record events for this + return executors.NodeStatusRunning, nil + + case handler.PhaseSuccess: + nodeEvent.Phase = core.NodeExecution_SUCCEEDED + reason := "" + if nodeStatus.IsCached() { + reason = "Task Skipped due to Discovery Cache Hit." 
+ } + nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeded, v1.NewTime(toStatus.OccurredAt), reason) + nodeEvent.OccurredAt = utils.GetProtoTime(nodeStatus.GetStoppedAt()) + if metadata, err := c.store.Head(ctx, v1alpha1.GetOutputsFile(nodeStatus.GetDataDir())); err == nil && metadata.Exists() { + nodeEvent.OutputResult = &event.NodeExecutionEvent_OutputUri{ + OutputUri: v1alpha1.GetOutputsFile(nodeStatus.GetDataDir()).String(), + } + } + + returnStatus = executors.NodeStatusSuccess + c.metrics.SuccessDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + + case handler.PhaseFailing: + nodeEvent.Phase = core.NodeExecution_FAILING + nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailing, v1.NewTime(toStatus.OccurredAt), "") + nodeEvent.OccurredAt = utils.GetProtoTime(nil) + returnStatus = executors.NodeStatusRunning + + case handler.PhaseFailed: + nodeEvent.Phase = core.NodeExecution_FAILED + nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, v1.NewTime(toStatus.OccurredAt), errMsg) + nodeEvent.OccurredAt = utils.GetProtoTime(nodeStatus.GetStoppedAt()) + nodeEvent.OutputResult = &event.NodeExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: errCode, + Message: errMsg, + ErrorUri: v1alpha1.GetOutputErrorFile(nodeStatus.GetDataDir()).String(), + }, + } + returnStatus = executors.NodeStatusFailed(toStatus.Err) + c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + + case handler.PhaseUndefined: + return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, node.GetID(), "unexpected undefined state received, without an error") + } + + // We observe that the phase has changed, and so we will record this event. 
+ if nodeEvent != nil && previousNodePhase != nodeStatus.GetPhase() { + if nodeStatus.GetParentTaskID() != nil { + nodeEvent.ParentTaskMetadata = &event.ParentTaskExecutionMetadata{ + Id: nodeStatus.GetParentTaskID(), + } + } + + logger.Debugf(ctx, "Recording NodeEvent for Phase transition [%s] -> [%s]", previousNodePhase.String(), nodeStatus.GetPhase().String()) + err := c.IdempotentRecordEvent(ctx, nodeEvent) + + if err != nil && eventsErr.IsEventAlreadyInTerminalStateError(err) { + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return executors.NodeStatusFailed(errors.Wrapf(errors.IllegalStateError, node.GetID(), err, + "phase mismatch between propeller and control plane; Propeller State: %s", returnStatus.NodePhase)), nil + } else if err != nil { + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, node.GetID(), err, "failed to record node event") + } + } + return returnStatus, nil +} + +func (c *nodeExecutor) executeNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (executors.NodeStatus, error) { + handlerStatus, err := c.handleNode(ctx, w, node) + if err != nil { + logger.Warningf(ctx, "Node handling failed with an error [%v]", err.Error()) + return executors.NodeStatusUndefined, err + } + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + return c.TransitionToPhase(ctx, w.GetExecutionID().WorkflowExecutionIdentifier, node, nodeStatus, handlerStatus) +} + +// The space search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from +// the currentNode. Visit a node is the RecursiveNodeHandler. A visit may be partial, complete or may result in a failure. 
+func (c *nodeExecutor) handleDownstream(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { + logger.Debugf(ctx, "Handling downstream Nodes") + // This node is success. Handle all downstream nodes + downstreamNodes, err := w.FromNode(currentNode.GetID()) + if err != nil { + logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err) + return executors.NodeStatusFailed(err), nil + } + if len(downstreamNodes) == 0 { + logger.Debugf(ctx, "No downstream nodes found. Complete.") + return executors.NodeStatusComplete, nil + } + // If any downstream node is failed, fail, all + // Else if all are success then success + // Else if any one is running then Downstream is still running + allCompleted := true + partialNodeCompletion := false + for _, downstreamNodeName := range downstreamNodes { + downstreamNode, ok := w.GetNode(downstreamNodeName) + if !ok { + return executors.NodeStatusFailed(errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", downstreamNodeName)), nil + } + state, err := c.RecursiveNodeHandler(ctx, w, downstreamNode) + if err != nil { + return executors.NodeStatusUndefined, err + } + if state.HasFailed() { + logger.Debugf(ctx, "Some downstream node has failed, %s", state.Err.Error()) + return state, nil + } + if !state.IsComplete() { + allCompleted = false + } + + if state.PartiallyComplete() { + // This implies that one of the downstream nodes has completed and workflow is ready for propagation + // We do not propagate in current cycle to make it possible to store the state between transitions + partialNodeCompletion = true + } + } + if allCompleted { + logger.Debugf(ctx, "All downstream nodes completed") + return executors.NodeStatusComplete, nil + } + if partialNodeCompletion { + return executors.NodeStatusSuccess, nil + } + return executors.NodeStatusPending, nil +} + +func (c *nodeExecutor) SetInputsForStartNode(ctx 
context.Context, w v1alpha1.BaseWorkflowWithStatus, inputs *handler.Data) (executors.NodeStatus, error) { + startNode := w.StartNode() + if startNode == nil { + return executors.NodeStatusFailed(errors.Errorf(errors.BadSpecificationError, v1alpha1.StartNodeID, "Start node not found")), nil + } + ctx = contextutils.WithNodeID(ctx, startNode.GetID()) + if inputs == nil { + logger.Infof(ctx, "No inputs for the workflow. Skipping storing inputs") + return executors.NodeStatusComplete, nil + } + // StartNode is special. It does not have any processing step. It just takes the workflow (or subworkflow) inputs and converts to its own outputs + nodeStatus := w.GetNodeExecutionStatus(startNode.GetID()) + if nodeStatus.GetDataDir() == "" { + return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, startNode.GetID(), "no data-dir set, cannot store inputs") + } + outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetDataDir()) + so := storage.Options{} + if err := c.store.WriteProtobuf(ctx, outputFile, so, inputs); err != nil { + logger.Errorf(ctx, "Failed to write protobuf (metadata). 
Error [%v]", err) + return executors.NodeStatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to store workflow inputs (as start node)") + } + return executors.NodeStatusComplete, nil +} + +func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { + currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) + nodeStatus := w.GetNodeExecutionStatus(currentNode.GetID()) + switch nodeStatus.GetPhase() { + case v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseSucceeding: + logger.Debugf(currentNodeCtx, "Handling node Status [%v]", nodeStatus.GetPhase().String()) + return c.executeNode(currentNodeCtx, w, currentNode) + // TODO we can optimize skip state handling by iterating down the graph and marking all as skipped + // Currently we treat either Skip or Success the same way. In this approach only one node will be skipped + // at a time. 
As we iterate down, further nodes will be skipped + case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSkipped: + return c.handleDownstream(ctx, w, currentNode) + case v1alpha1.NodePhaseFailed: + logger.Debugf(currentNodeCtx, "Node Failed") + return executors.NodeStatusFailed(errors.Errorf(errors.RuntimeExecutionError, currentNode.GetID(), "Node Failed.")), nil + } + return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), "Should never reach here") +} + +func (c *nodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error { + ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) + nodeStatus := w.GetNodeExecutionStatus(currentNode.GetID()) + switch nodeStatus.GetPhase() { + case v1alpha1.NodePhaseRunning: + // Abort this node + h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) + if err != nil { + return err + } + return h.AbortNode(ctx, w, currentNode) + case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSkipped: + // Abort downstream nodes + downstreamNodes, err := w.FromNode(currentNode.GetID()) + if err != nil { + logger.Debugf(ctx, "Error when retrieving downstream nodes. 
Error [%v]", err) + return nil + } + for _, d := range downstreamNodes { + downstreamNode, ok := w.GetNode(d) + if !ok { + return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) + } + if err := c.AbortHandler(ctx, w, downstreamNode); err != nil { + return err + } + } + return nil + } + return nil +} + +func (c *nodeExecutor) Initialize(ctx context.Context) error { + logger.Infof(ctx, "Initializing Core Node Executor") + return nil +} + +func NewExecutor(ctx context.Context, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, + revalPeriod time.Duration, eventSink events.EventSink, workflowLauncher launchplan.Executor, + catalogClient catalog.Client, kubeClient executors.Client, scope promutils.Scope) (executors.Node, error) { + + nodeScope := scope.NewSubScope("node") + exec := &nodeExecutor{ + store: store, + enqueueWorkflow: enQWorkflow, + nodeRecorder: events.NewNodeEventRecorder(eventSink, nodeScope), + metrics: &nodeMetrics{ + FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + InputsWriteFailure: labeled.NewCounter("inputs_write_fail", "Indicates failure in writing node inputs to metastore", nodeScope), + ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in resolving node inputs", nodeScope), + TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler 
reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + }, + } + nodeHandlerFactory, err := NewHandlerFactory( + ctx, + exec, + eventSink, + workflowLauncher, + enQWorkflow, + revalPeriod, + store, + catalogClient, + kubeClient, + nodeScope, + ) + exec.nodeHandlerFactory = nodeHandlerFactory + return exec, err +} diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go new file mode 100644 index 000000000..2cfcc4caf --- /dev/null +++ b/pkg/controller/nodes/executor_test.go @@ -0,0 +1,1479 @@ +package nodes + +import ( + "context" + "errors" + "fmt" + "reflect" + "testing" + "time" + + mocks4 "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" + + eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + pluginV1 "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + goerrors "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + mocks3 "github.com/lyft/flytepropeller/pkg/controller/nodes/handler/mocks" + mocks2 "github.com/lyft/flytepropeller/pkg/controller/nodes/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + 
	"github.com/lyft/flytepropeller/pkg/controller/nodes/task"
	"github.com/lyft/flytepropeller/pkg/utils"
	flyteassert "github.com/lyft/flytepropeller/pkg/utils/assert"
)

// Shared fake kube client used by all executor tests.
var fakeKubeClient = mocks4.NewFakeKubeClient()

// createSingletonTaskExecutorFactory returns a stub task factory whose callbacks
// return empty/nil executors, so no real plugin is exercised in these tests.
func createSingletonTaskExecutorFactory() task.Factory {
	return &task.FactoryFuncs{
		GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginV1.Executor, error) {
			return nil, nil
		},
		ListAllTaskExecutorsCb: func() []pluginV1.Executor {
			return []pluginV1.Executor{}
		},
	}
}

func init() {
	flytek8s.InitializeFake()
}

// TestSetInputsForStartNode covers nil inputs, a successful inputs write, a
// missing data dir, and a failing datastore.
func TestSetInputsForStartNode(t *testing.T) {
	ctx := context.Background()
	mockStorage := createInmemoryDataStore(t, testScope.NewSubScope("f"))
	catalogClient := catalog.NewCatalogClient(mockStorage)
	enQWf := func(workflowID v1alpha1.WorkflowID) {}

	factory := createSingletonTaskExecutorFactory()
	task.SetTestFactory(factory)
	assert.True(t, task.IsTestModeEnabled())

	exec, err := NewExecutor(ctx, mockStorage, enQWf, time.Second, events.NewMockEventSink(), launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope())
	assert.NoError(t, err)
	inputs := &core.LiteralMap{
		Literals: map[string]*core.Literal{
			"x": utils.MustMakePrimitiveLiteral("hello"),
			"y": utils.MustMakePrimitiveLiteral("blah"),
		},
	}

	t.Run("NoInputs", func(t *testing.T) {
		w := createDummyBaseWorkflow()
		w.DummyStartNode = &v1alpha1.NodeSpec{
			ID: v1alpha1.StartNodeID,
		}
		s, err := exec.SetInputsForStartNode(ctx, w, nil)
		assert.NoError(t, err)
		assert.Equal(t, executors.NodeStatusComplete, s)
	})

	t.Run("WithInputs", func(t *testing.T) {
		w := createDummyBaseWorkflow()
		w.GetNodeExecutionStatus(v1alpha1.StartNodeID).SetDataDir("s3://test-bucket/exec/start-node/data")
		w.DummyStartNode = &v1alpha1.NodeSpec{
			ID: v1alpha1.StartNodeID,
		}
		s, err := exec.SetInputsForStartNode(ctx, w, inputs)
		assert.NoError(t, err)
		assert.Equal(t, executors.NodeStatusComplete, s)
		// Inputs must have been persisted as the start node's outputs.pb.
		actual := &core.LiteralMap{}
		if assert.NoError(t, mockStorage.ReadProtobuf(ctx, "s3://test-bucket/exec/start-node/data/outputs.pb", actual)) {
			flyteassert.EqualLiteralMap(t, inputs, actual)
		}
	})

	t.Run("DataDirNotSet", func(t *testing.T) {
		w := createDummyBaseWorkflow()
		w.DummyStartNode = &v1alpha1.NodeSpec{
			ID: v1alpha1.StartNodeID,
		}
		s, err := exec.SetInputsForStartNode(ctx, w, inputs)
		assert.Error(t, err)
		assert.Equal(t, executors.NodeStatusUndefined, s)
	})

	failStorage := createFailingDatastore(t, testScope.NewSubScope("failing"))
	execFail, err := NewExecutor(ctx, failStorage, enQWf, time.Second, events.NewMockEventSink(), launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope())
	assert.NoError(t, err)
	t.Run("StorageFailure", func(t *testing.T) {
		w := createDummyBaseWorkflow()
		w.GetNodeExecutionStatus(v1alpha1.StartNodeID).SetDataDir("s3://test-bucket/exec/start-node/data")
		w.DummyStartNode = &v1alpha1.NodeSpec{
			ID: v1alpha1.StartNodeID,
		}
		s, err := execFail.SetInputsForStartNode(ctx, w, inputs)
		assert.Error(t, err)
		assert.Equal(t, executors.NodeStatusUndefined, s)
	})
}

// TestNodeExecutor_TransitionToPhase drives TransitionToPhase through a phase
// table plus retry, event-recording-failure, and child-workflow scenarios.
func TestNodeExecutor_TransitionToPhase(t *testing.T) {
	ctx := context.Background()
	enQWf := func(workflowID v1alpha1.WorkflowID) {
	}
	mockEventSink := events.NewMockEventSink().(*events.MockEventSink)

	factory := createSingletonTaskExecutorFactory()
	task.SetTestFactory(factory)
	assert.True(t, task.IsTestModeEnabled())

	memStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
	assert.NoError(t, err)

	catalogClient := catalog.NewCatalogClient(memStore)
	execIface, err := NewExecutor(ctx, memStore, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope())
	assert.NoError(t, err)
	exec := execIface.(*nodeExecutor)

	execID := &core.WorkflowExecutionIdentifier{}
	nodeID := "n1"

	expectedErr := fmt.Errorf("test err")
	taskErr := fmt.Errorf("task failed")

	// TABLE Tests
	tests := []struct {
		name               string
		nodeStatus         v1alpha1.ExecutableNodeStatus
		toStatus           handler.Status
		expectedErr        bool
		expectedNodeStatus executors.NodeStatus
	}{
		{"notStarted", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseNotYetStarted}, handler.StatusNotStarted, false, executors.NodeStatusPending},
		{"running", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseNotYetStarted}, handler.StatusRunning, false, executors.NodeStatusRunning},
		{"runningRepeated", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusRunning, false, executors.NodeStatusRunning},
		{"success", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusSuccess, false, executors.NodeStatusSuccess},
		{"succeeding", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusSucceeding, false, executors.NodeStatusRunning},
		{"failing", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusFailing(nil), false, executors.NodeStatusRunning},
		{"failed", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseFailing}, handler.StatusFailed(taskErr), false, executors.NodeStatusFailed(taskErr)},
		{"undefined", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseNotYetStarted}, handler.StatusUndefined, true, executors.NodeStatusUndefined},
		{"skipped", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseNotYetStarted}, handler.StatusSkipped, false, executors.NodeStatusSuccess},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			node := &mocks.ExecutableNode{}
			node.On("GetID").Return(nodeID)
			n, err := exec.TransitionToPhase(ctx, execID, node, test.nodeStatus, test.toStatus)
			if test.expectedErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
			assert.Equal(t, test.expectedNodeStatus, n)
		})
	}

	// Testing retries
	// No retry strategy on the node: a retryable failure becomes terminal.
	t.Run("noRetryAttemptSet", func(t *testing.T) {
		now := v1.Now()
		status := &mocks.ExecutableNodeStatus{}
		status.On("GetPhase").Return(v1alpha1.NodePhaseRunning)
		status.On("GetAttempts").Return(uint32(0))
		status.On("GetDataDir").Return(storage.DataReference("x"))
		status.On("IncrementAttempts").Return(uint32(1))
		status.On("UpdatePhase", v1alpha1.NodePhaseFailed, mock.Anything, mock.AnythingOfType("string"))
		status.On("GetQueuedAt").Return(&now)
		status.On("GetStartedAt").Return(&now)
		status.On("GetStoppedAt").Return(&now)
		status.On("GetWorkflowNodeStatus").Return(nil)

		node := &mocks.ExecutableNode{}
		node.On("GetID").Return(nodeID)
		node.On("GetRetryStrategy").Return(nil)

		n, err := exec.TransitionToPhase(ctx, execID, node, status, handler.StatusRetryableFailure(fmt.Errorf("failed")))
		assert.NoError(t, err)
		assert.Equal(t, executors.NodePhaseFailed, n.NodePhase)
	})

	// Testing retries
	// MinAttempts of 0 also yields an immediate terminal failure.
	t.Run("maxAttempt0", func(t *testing.T) {
		now := v1.Now()
		status := &mocks.ExecutableNodeStatus{}
		status.On("GetPhase").Return(v1alpha1.NodePhaseRunning)
		status.On("GetAttempts").Return(uint32(0))
		status.On("GetDataDir").Return(storage.DataReference("x"))
		status.On("IncrementAttempts").Return(uint32(1))
		status.On("UpdatePhase", v1alpha1.NodePhaseFailed, mock.Anything, mock.AnythingOfType("string"))
		status.On("GetQueuedAt").Return(&now)
		status.On("GetStartedAt").Return(&now)
		status.On("GetStoppedAt").Return(&now)
		status.On("GetWorkflowNodeStatus").Return(nil)

		maxAttempts := 0
		node := &mocks.ExecutableNode{}
		node.On("GetID").Return(nodeID)
		node.On("GetRetryStrategy").Return(&v1alpha1.RetryStrategy{MinAttempts: &maxAttempts})

		n, err := exec.TransitionToPhase(ctx, execID, node, status, handler.StatusRetryableFailure(fmt.Errorf("failed")))
		assert.NoError(t, err)
		assert.Equal(t, executors.NodePhaseFailed, n.NodePhase)
	})

	// Testing retries
	// Attempts remain: the node moves to RetryableFailure and execution keeps running.
	t.Run("retryAttemptsRemaining", func(t *testing.T) {
		now := v1.Now()
		status := &mocks.ExecutableNodeStatus{}
		status.On("GetPhase").Return(v1alpha1.NodePhaseRunning)
		status.On("GetAttempts").Return(uint32(0))
		status.On("GetDataDir").Return(storage.DataReference("x"))
		status.On("IncrementAttempts").Return(uint32(1))
		status.On("UpdatePhase", v1alpha1.NodePhaseRetryableFailure, mock.Anything, mock.AnythingOfType("string"))
		status.On("GetQueuedAt").Return(&now)
		status.On("GetStartedAt").Return(&now)
		status.On("GetLastUpdatedAt").Return(&now)
		status.On("GetWorkflowNodeStatus").Return(nil)
		var s *v1alpha1.TaskNodeStatus
		status.On("UpdateTaskNodeStatus", s).Times(10)
		status.On("ClearTaskStatus").Return()
		status.On("ClearWorkflowStatus").Return()
		status.On("ClearDynamicNodeStatus").Return()

		maxAttempts := 2
		node := &mocks.ExecutableNode{}
		node.On("GetID").Return(nodeID)
		node.On("GetRetryStrategy").Return(&v1alpha1.RetryStrategy{MinAttempts: &maxAttempts})

		n, err := exec.TransitionToPhase(ctx, execID, node, status, handler.StatusRetryableFailure(fmt.Errorf("failed")))
		assert.NoError(t, err)
		assert.Equal(t, executors.NodePhaseRunning, n.NodePhase, "%+v", n)
	})

	// Testing retries
	// Attempts exhausted: the retryable failure becomes a terminal failure.
	t.Run("retriesExhausted", func(t *testing.T) {
		now := v1.Now()
		status := &mocks.ExecutableNodeStatus{}
		status.On("GetPhase").Return(v1alpha1.NodePhaseRunning)
		status.On("GetAttempts").Return(uint32(0))
		status.On("GetDataDir").Return(storage.DataReference("x"))
		// Change to return 3
		status.On("IncrementAttempts").Return(uint32(3))
		status.On("UpdatePhase", v1alpha1.NodePhaseFailed, mock.Anything, mock.AnythingOfType("string"))
		status.On("GetQueuedAt").Return(&now)
		status.On("GetStartedAt").Return(&now)
		status.On("GetStoppedAt").Return(&now)
		status.On("GetWorkflowNodeStatus").Return(nil)

		maxAttempts := 2
		node := &mocks.ExecutableNode{}
		node.On("GetID").Return(nodeID)
		node.On("GetRetryStrategy").Return(&v1alpha1.RetryStrategy{MinAttempts: &maxAttempts})

		n, err := exec.TransitionToPhase(ctx, execID, node, status, handler.StatusRetryableFailure(fmt.Errorf("failed")))
		assert.NoError(t, err)
		assert.Equal(t, executors.NodePhaseFailed, n.NodePhase, "%+v", n.NodePhase)
	})

	t.Run("eventSendFailure", func(t *testing.T) {
		node := &mocks.ExecutableNode{}
		node.On("GetID").Return(nodeID)
		// In case Report event fails
		mockEventSink.SinkCb = func(ctx context.Context, message proto.Message) error {
			return expectedErr
		}
		n, err := exec.TransitionToPhase(ctx, execID, node, &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusSuccess)
		assert.Error(t, err)
		assert.Equal(t, expectedErr, goerrors.Cause(err))
		assert.Equal(t, executors.NodeStatusUndefined, n)
	})

	t.Run("eventSendMismatch", func(t *testing.T) {
		node := &mocks.ExecutableNode{}
		node.On("GetID").Return(nodeID)
		// In case Report event fails
		mockEventSink.SinkCb = func(ctx context.Context, message proto.Message) error {
			return &eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError,
				Cause: errors.New("already exists"),
			}
		}
		n, err := exec.TransitionToPhase(ctx, execID, node, &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusSuccess)
		assert.NoError(t, err)
		assert.Equal(t, executors.NodePhaseFailed, n.NodePhase)
	})

	// Testing that workflow execution name is queried in running
	t.Run("childWorkflows", func(t *testing.T) {
		now := v1.Now()

		wfNodeStatus := &mocks.ExecutableWorkflowNodeStatus{}
		wfNodeStatus.On("GetWorkflowExecutionName").Return("childWfName")

		status := &mocks.ExecutableNodeStatus{}
		status.On("GetPhase").Return(v1alpha1.NodePhaseQueued)
		status.On("GetAttempts").Return(uint32(0))
		status.On("GetDataDir").Return(storage.DataReference("x"))
		status.On("IncrementAttempts").Return(uint32(1))
		status.On("UpdatePhase", v1alpha1.NodePhaseRunning, mock.Anything, mock.AnythingOfType("string"))
		status.On("GetStartedAt").Return(&now)
		status.On("GetQueuedAt").Return(&now)
		status.On("GetStoppedAt").Return(&now)
		status.On("GetOrCreateWorkflowStatus").Return(wfNodeStatus)
		status.On("ClearTaskStatus").Return()
		status.On("ClearWorkflowStatus").Return()
		status.On("GetWorkflowNodeStatus").Return(wfNodeStatus)

		node := &mocks.ExecutableNode{}
		node.On("GetID").Return(nodeID)
		node.On("GetRetryStrategy").Return(nil)

		n, err := exec.TransitionToPhase(ctx, execID, node, status, handler.StatusRunning)
		assert.NoError(t, err)
		assert.Equal(t, executors.NodePhaseRunning, n.NodePhase)
		wfNodeStatus.AssertCalled(t, "GetWorkflowExecutionName")
	})
}

// TestNodeExecutor_Initialize verifies Initialize completes without error.
func TestNodeExecutor_Initialize(t *testing.T) {
	ctx := context.Background()
	enQWf := func(workflowID v1alpha1.WorkflowID) {
	}

	mockEventSink := events.NewMockEventSink().(*events.MockEventSink)
	memStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
	assert.NoError(t, err)

	catalogClient := catalog.NewCatalogClient(memStore)

	execIface, err := NewExecutor(ctx, memStore, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope())
	assert.NoError(t, err)
	exec := execIface.(*nodeExecutor)

	assert.NoError(t, exec.Initialize(ctx))
}

// TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes recurses a workflow
// whose only node is the start node, across different starting phases.
func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) {
	ctx := context.Background()
	enQWf := func(workflowID v1alpha1.WorkflowID) {
	}
	mockEventSink := events.NewMockEventSink().(*events.MockEventSink)
	factory := createSingletonTaskExecutorFactory()
	task.SetTestFactory(factory)
	assert.True(t, task.IsTestModeEnabled())

	store := createInmemoryDataStore(t, promutils.NewTestScope())
	catalogClient := catalog.NewCatalogClient(store)

	execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink,
		launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope())
	assert.NoError(t, err)
	exec := execIface.(*nodeExecutor)

	defaultNodeID := "n1"

	createStartNodeWf := func(p v1alpha1.NodePhase, _ int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) {
		startNode := &v1alpha1.NodeSpec{
			Kind: v1alpha1.NodeKindStart,
			ID:   v1alpha1.StartNodeID,
		}
		startNodeStatus := &v1alpha1.NodeStatus{
			Phase: p,
		}
		return &v1alpha1.FlyteWorkflow{
			Status: v1alpha1.WorkflowStatus{
				NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{
					v1alpha1.StartNodeID: startNodeStatus,
				},
				DataDir: "data",
			},
			WorkflowSpec: &v1alpha1.WorkflowSpec{
				ID: "wf",
				Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{
					v1alpha1.StartNodeID: startNode,
				},
				Connections: v1alpha1.Connections{
					UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{
						defaultNodeID: {v1alpha1.StartNodeID},
					},
					DownstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{
						v1alpha1.StartNodeID: {defaultNodeID},
					},
				},
			},
		}, startNode, startNodeStatus

	}

	// Recurse Child Node Queued previously
	{
		tests := []struct {
			name              string
			currentNodePhase  v1alpha1.NodePhase
			expectedNodePhase v1alpha1.NodePhase
			expectedPhase     executors.NodePhase
			handlerReturn     func() (handler.Status, error)
			expectedError     bool
		}{
			// Starting at Queued
			{"nys->success", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) {
				return handler.StatusSuccess, nil
			}, false},
			{"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) {
				return handler.StatusSuccess, nil
			}, false},
			{"nys->error", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhaseUndefined, func() (handler.Status, error) {
				return handler.StatusUndefined, fmt.Errorf("err")
			}, true},
		}
		for _, test := range tests {
			t.Run(test.name, func(t *testing.T) {
				hf := &mocks2.HandlerFactory{}
				exec.nodeHandlerFactory = hf

				h := &mocks3.IFace{}
				h.On("StartNode",
					mock.MatchedBy(func(ctx context.Context) bool { return true }),
					mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }),
					mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }),
					mock.MatchedBy(func(o *handler.Data) bool { return true }),
				).Return(test.handlerReturn())

				hf.On("GetHandler", v1alpha1.NodeKindStart).Return(h, nil)

				mockWf, startNode, startNodeStatus := createStartNodeWf(test.currentNodePhase, 0)
				s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode)
				if test.expectedError {
					assert.Error(t, err)
				} else {
					assert.NoError(t, err)
				}
				assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String())
				assert.Equal(t, uint32(0), startNodeStatus.GetAttempts())
				assert.Equal(t, test.expectedNodePhase, startNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), startNodeStatus.GetPhase().String())
			})
		}
	}
}

// TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode exercises an end node
// whose upstream start node is in various phases, then an end node recursed
// into from a succeeded start node.
func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) {
	ctx := context.Background()
	enQWf := func(workflowID v1alpha1.WorkflowID) {
	}
	mockEventSink := events.NewMockEventSink().(*events.MockEventSink)

	factory := createSingletonTaskExecutorFactory()
	task.SetTestFactory(factory)
	assert.True(t, task.IsTestModeEnabled())

	store := createInmemoryDataStore(t, promutils.NewTestScope())
	catalogClient := catalog.NewCatalogClient(store)

	execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope())
	assert.NoError(t, err)
	exec := execIface.(*nodeExecutor)

	// Node not yet started
	{
		createSingleNodeWf := func(parentPhase v1alpha1.NodePhase, _ int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) {
			n := &v1alpha1.NodeSpec{
				ID:   v1alpha1.EndNodeID,
				Kind: v1alpha1.NodeKindEnd,
			}
			ns := &v1alpha1.NodeStatus{}

			return &v1alpha1.FlyteWorkflow{
				Status: v1alpha1.WorkflowStatus{
					NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{
						v1alpha1.EndNodeID: ns,
						v1alpha1.StartNodeID: {
							Phase: parentPhase,
						},
					},
					DataDir: "data",
				},
				WorkflowSpec: &v1alpha1.WorkflowSpec{
					ID: "wf",
					Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{
						v1alpha1.EndNodeID: n,
					},
					Connections: v1alpha1.Connections{
						UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{
							v1alpha1.EndNodeID: {v1alpha1.StartNodeID},
						},
					},
				},
			}, n, ns

		}
		tests := []struct {
			name              string
			parentNodePhase   v1alpha1.NodePhase
			expectedNodePhase v1alpha1.NodePhase
			expectedPhase     executors.NodePhase
			expectedError     bool
		}{
			{"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false},
			{"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false},
			{"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false},
			{"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false},
			{"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false},
			{"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, false},
			{"success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, false},
		}
		for _, test := range tests {
			t.Run(test.name, func(t *testing.T) {
				hf := &mocks2.HandlerFactory{}
				exec.nodeHandlerFactory = hf
				h := &mocks3.IFace{}
				hf.On("GetHandler", v1alpha1.NodeKindEnd).Return(h, nil)
				h.On("StartNode", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(handler.StatusQueued, nil)

				mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0)
				s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode)
				if test.expectedError {
					assert.Error(t, err)
				} else {
					assert.NoError(t, err)
				}
				assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String())
				assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts())
				assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String())
			})
		}
	}

	// Recurse End Node Queued previously
	{
		createSingleNodeWf := func(endNodePhase v1alpha1.NodePhase, _ int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) {
			n := &v1alpha1.NodeSpec{
				ID:   v1alpha1.EndNodeID,
				Kind: v1alpha1.NodeKindEnd,
			}
			ns := &v1alpha1.NodeStatus{
				Phase: endNodePhase,
			}

			return &v1alpha1.FlyteWorkflow{
				Status: v1alpha1.WorkflowStatus{
					NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{
						v1alpha1.EndNodeID: ns,
						v1alpha1.StartNodeID: {
							Phase: v1alpha1.NodePhaseSucceeded,
						},
					},
					DataDir: "data",
				},
				WorkflowSpec: &v1alpha1.WorkflowSpec{
					ID: "wf",
					Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{
						v1alpha1.StartNodeID: {
							ID:   v1alpha1.StartNodeID,
							Kind: v1alpha1.NodeKindStart,
						},
						v1alpha1.EndNodeID: n,
					},
					Connections: v1alpha1.Connections{
						UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{
							v1alpha1.EndNodeID: {v1alpha1.StartNodeID},
						},
						DownstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{
							v1alpha1.StartNodeID: {v1alpha1.EndNodeID},
						},
					},
				},
			}, n, ns

		}
		tests := []struct {
			name              string
			currentNodePhase  v1alpha1.NodePhase
			expectedNodePhase v1alpha1.NodePhase
			expectedPhase     executors.NodePhase
			handlerReturn     func() (handler.Status, error)
			expectedError     bool
		}{
			// Starting at Queued
			{"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) {
				return handler.StatusSuccess, nil
			}, false},

			{"queued->failed", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) {
				return handler.StatusFailed(fmt.Errorf("err")), nil
			}, false},

			{"queued->failing", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailing, executors.NodePhasePending, func() (handler.Status, error) {
				return handler.StatusFailing(fmt.Errorf("err")), nil
			}, false},

			{"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseUndefined, func() (handler.Status, error) {
				return handler.StatusUndefined, fmt.Errorf("err")
			}, true},
		}
		for _, test := range tests {
			t.Run(test.name, func(t *testing.T) {
				hf := &mocks2.HandlerFactory{}
				exec.nodeHandlerFactory = hf

				h := &mocks3.IFace{}
				h.On("StartNode",
					mock.MatchedBy(func(ctx context.Context) bool { return true }),
					mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }),
					mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }),
					mock.MatchedBy(func(o *handler.Data) bool { return true }),
				).Return(test.handlerReturn())

				hf.On("GetHandler", v1alpha1.NodeKindEnd).Return(h, nil)

				mockWf, _, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0)
				startNode := mockWf.StartNode()
				startStatus := mockWf.GetNodeExecutionStatus(startNode.GetID())
				assert.Equal(t, v1alpha1.NodePhaseSucceeded, startStatus.GetPhase())
				s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode)
				if test.expectedError {
					assert.Error(t, err)
				} else {
					assert.NoError(t, err)
				}
				assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String())
				assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts())
				assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String())
			})
		}
	}

}

func
TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + + factory := createSingletonTaskExecutorFactory() + task.SetTestFactory(factory) + assert.True(t, task.IsTestModeEnabled()) + + store := createInmemoryDataStore(t, promutils.NewTestScope()) + catalogClient := catalog.NewCatalogClient(store) + + execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + + defaultNodeID := "n1" + + createSingleNodeWf := func(p v1alpha1.NodePhase, maxAttempts int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { + n := &v1alpha1.NodeSpec{ + ID: defaultNodeID, + Kind: v1alpha1.NodeKindTask, + RetryStrategy: &v1alpha1.RetryStrategy{ + MinAttempts: &maxAttempts, + }, + } + ns := &v1alpha1.NodeStatus{ + Phase: p, + } + + startNode := &v1alpha1.NodeSpec{ + Kind: v1alpha1.NodeKindStart, + ID: v1alpha1.StartNodeID, + } + return &v1alpha1.FlyteWorkflow{ + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + defaultNodeID: ns, + v1alpha1.StartNodeID: { + Phase: v1alpha1.NodePhaseSucceeded, + }, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + defaultNodeID: n, + v1alpha1.StartNodeID: startNode, + }, + Connections: v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + defaultNodeID: {v1alpha1.StartNodeID}, + }, + DownstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + v1alpha1.StartNodeID: {defaultNodeID}, + }, + }, + }, + }, n, ns + + } + + // Recursion test with child Node not yet started + { + nodeN0 := "n0" + nodeN2 := "n2" + ctx := context.Background() + 
connections := &v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + nodeN2: {nodeN0}, + }, + } + + setupNodePhase := func(n0Phase, n2Phase, expectedN2Phase v1alpha1.NodePhase) (*mocks.ExecutableWorkflow, *mocks.ExecutableNodeStatus) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("GetPhase").Return(n2Phase) + mockN2Status.On("SetDataDir", mock.AnythingOfType(reflect.TypeOf(storage.DataReference("x")).String())) + mockN2Status.On("GetDataDir").Return(storage.DataReference("blah")) + mockN2Status.On("GetWorkflowNodeStatus").Return(nil) + mockN2Status.On("GetStoppedAt").Return(nil) + mockN2Status.On("UpdatePhase", expectedN2Phase, mock.Anything, mock.AnythingOfType("string")) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.ExecutableNode{} + mockNode.On("GetID").Return(nodeN2) + mockNode.On("GetBranchNode").Return(nil) + mockNode.On("GetKind").Return(v1alpha1.NodeKindTask) + mockNode.On("IsStartNode").Return(false) + mockNode.On("IsEndNode").Return(false) + + mockNodeN0 := &mocks.ExecutableNode{} + mockNodeN0.On("GetID").Return(nodeN0) + mockNodeN0.On("GetBranchNode").Return(nil) + mockNodeN0.On("GetKind").Return(v1alpha1.NodeKindTask) + mockNodeN0.On("IsStartNode").Return(false) + mockNodeN0.On("IsEndNode").Return(false) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(n0Phase) + mockN0Status.On("IsDirty").Return(false) + + mockWfStatus := &mocks.ExecutableWorkflowStatus{} + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("StartNode").Return(mockNodeN0) + mockWf.On("GetNode", nodeN2).Return(mockNode, true) + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + mockWf.On("FromNode", nodeN0).Return([]string{nodeN2}, nil) + 
mockWf.On("FromNode", nodeN2).Return([]string{}, fmt.Errorf("did not expect")) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{}) + mockWf.On("GetExecutionStatus").Return(mockWfStatus) + mockWfStatus.On("GetDataDir").Return(storage.DataReference("x")) + return mockWf, mockN2Status + } + + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + parentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + updateCalled bool + }{ + {"notYetStarted->notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseNotYetStarted, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusNotStarted, nil + }, false, false}, + + {"notYetStarted->skipped", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSkipped, nil + }, false, true}, + + {"notYetStarted->queued", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, _ := setupNodePhase(test.parentNodePhase, test.currentNodePhase, test.expectedNodePhase) + startNode := mockWf.StartNode() + s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + }) + } + } + + // Recurse Child Node Queued 
previously + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at Queued + {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"queued->queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false}, + + {"queued->failed", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"queued->failing", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailing, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusFailing(fmt.Errorf("err")), nil + }, false}, + + {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("StartNode", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o *handler.Data) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, _, 
mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + startNode := mockWf.StartNode() + startStatus := mockWf.GetNodeExecutionStatus(startNode.GetID()) + assert.Equal(t, v1alpha1.NodePhaseSucceeded, startStatus.GetPhase()) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Recurse Child Node started previously + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at running + {"running->running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"running->failing", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusFailing(fmt.Errorf("err")), nil + }, false}, + + {"running->failed", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"running->success", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"running->error", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhaseUndefined, func() (handler.Status, error) { + return 
handler.StatusUndefined, fmt.Errorf("err") + }, true}, + + {"previously-failed", v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false}, + + {"previously-success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSucceeded, executors.NodePhaseComplete, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("CheckNodeStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNodeStatus) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, _, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + startNode := mockWf.StartNode() + s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } +} + +func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + + factory := createSingletonTaskExecutorFactory() + task.SetTestFactory(factory) + assert.True(t, 
task.IsTestModeEnabled()) + + store := createInmemoryDataStore(t, promutils.NewTestScope()) + catalogClient := catalog.NewCatalogClient(store) + + execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + + defaultNodeID := "n1" + + createSingleNodeWf := func(p v1alpha1.NodePhase, maxAttempts int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { + n := &v1alpha1.NodeSpec{ + ID: defaultNodeID, + Kind: v1alpha1.NodeKindTask, + RetryStrategy: &v1alpha1.RetryStrategy{ + MinAttempts: &maxAttempts, + }, + } + ns := &v1alpha1.NodeStatus{ + Phase: p, + } + + return &v1alpha1.FlyteWorkflow{ + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + defaultNodeID: ns, + v1alpha1.StartNodeID: { + Phase: v1alpha1.NodePhaseSucceeded, + }, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + defaultNodeID: n, + }, + Connections: v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + defaultNodeID: {v1alpha1.StartNodeID}, + }, + }, + }, + }, n, ns + + } + + // Node not yet started + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + {"notYetStarted->running", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"notYetStarted->queued", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false}, + + {"notYetStarted->failed", 
v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"notYetStarted->success", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Node queued previously + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at Queued + {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"queued->queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false}, + + {"queued->failed", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + 
return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"queued->failing", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailing, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusFailing(fmt.Errorf("err")), nil + }, false}, + + {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("StartNode", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o *handler.Data) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Node started previously + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase 
executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at running + {"running->running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"running->failing", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusFailing(fmt.Errorf("err")), nil + }, false}, + + {"running->failed", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"running->success", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"running->error", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("CheckNodeStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNodeStatus) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received 
%s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Node started previously and is failing + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at Failing + // TODO this should be illegal + {"failing->running", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseRunning, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"failing->failed", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + // TODO this should be illegal + {"failing->success", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"failing->error", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailing, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("HandleFailingNode", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := 
createSingleNodeWf(test.currentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Node started previously and retryable failure + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at Queued + {"running->retryable", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRetryableFailure, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusRetryableFailure(fmt.Errorf("err")), nil + }, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("CheckNodeStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNodeStatus) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 2) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", 
test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(1), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Node started previously and retryable failure - but exhausted attempts + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + {"running->retryable", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusRetryableFailure(fmt.Errorf("err")), nil + }, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("CheckNodeStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNodeStatus) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 1) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(1), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } +} + +func 
TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + + factory := createSingletonTaskExecutorFactory() + task.SetTestFactory(factory) + assert.True(t, task.IsTestModeEnabled()) + + store := createInmemoryDataStore(t, promutils.NewTestScope()) + catalogClient := catalog.NewCatalogClient(store) + + execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + + defaultNodeID := "n1" + + createSingleNodeWf := func(parentPhase v1alpha1.NodePhase, maxAttempts int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { + n := &v1alpha1.NodeSpec{ + ID: defaultNodeID, + Kind: v1alpha1.NodeKindTask, + RetryStrategy: &v1alpha1.RetryStrategy{ + MinAttempts: &maxAttempts, + }, + } + ns := &v1alpha1.NodeStatus{} + + return &v1alpha1.FlyteWorkflow{ + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + defaultNodeID: ns, + v1alpha1.StartNodeID: { + Phase: parentPhase, + }, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + defaultNodeID: n, + }, + Connections: v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + defaultNodeID: {v1alpha1.StartNodeID}, + }, + }, + }, + }, n, ns + + } + + // Node not yet started + { + tests := []struct { + name string + parentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + expectedError bool + }{ + {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"running", v1alpha1.NodePhaseRunning, 
v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + h := &mocks3.IFace{} + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } +} + +func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { + testScope := promutils.NewTestScope() + type fields struct { + nodeHandlerFactory HandlerFactory + enqueueWorkflow v1alpha1.EnqueueWorkflow + store *storage.DataStore + nodeRecorder events.NodeEventRecorder + metrics *nodeMetrics + } + type args struct { + w v1alpha1.ExecutableWorkflow + node v1alpha1.ExecutableNode + nodeStatus v1alpha1.ExecutableNodeStatus + } + + nsf := func(phase v1alpha1.NodePhase, lastUpdated *time.Time) *mocks.ExecutableNodeStatus { + ns := &mocks.ExecutableNodeStatus{} + ns.On("GetPhase").Return(phase) + var t *v1.Time + if lastUpdated != nil { + t = 
&v1.Time{Time: *lastUpdated} + } + ns.On("GetLastUpdatedAt").Return(t) + return ns + } + testTime := time.Now() + tests := []struct { + name string + fields fields + args args + recordingExpected bool + }{ + { + "retryable-failure", + fields{metrics: &nodeMetrics{TransitionLatency: labeled.NewStopWatch("test", "xyz", time.Millisecond, testScope)}}, + args{nodeStatus: nsf(v1alpha1.NodePhaseRetryableFailure, &testTime)}, + true, + }, + { + "retryable-failure-notime", + fields{metrics: &nodeMetrics{TransitionLatency: labeled.NewStopWatch("test2", "xyz", time.Millisecond, testScope)}}, + args{nodeStatus: nsf(v1alpha1.NodePhaseRetryableFailure, nil)}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &nodeExecutor{ + nodeHandlerFactory: tt.fields.nodeHandlerFactory, + enqueueWorkflow: tt.fields.enqueueWorkflow, + store: tt.fields.store, + nodeRecorder: tt.fields.nodeRecorder, + metrics: tt.fields.metrics, + } + c.RecordTransitionLatency(context.TODO(), tt.args.w, tt.args.node, tt.args.nodeStatus) + + ch := make(chan prometheus.Metric, 2) + tt.fields.metrics.TransitionLatency.Collect(ch) + assert.Equal(t, len(ch) == 1, tt.recordingExpected) + }) + } +} diff --git a/pkg/controller/nodes/handler/iface.go b/pkg/controller/nodes/handler/iface.go new file mode 100644 index 000000000..4101ba09d --- /dev/null +++ b/pkg/controller/nodes/handler/iface.go @@ -0,0 +1,128 @@ +package handler + +import ( + "context" + "time" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +//go:generate mockery -all + +type Data = core.LiteralMap +type VarName = string + +type Phase int + +const ( + // Indicates that the handler was unable to Start the Node due to an internal failure + PhaseNotStarted Phase = iota + // Incase of retryable failure and should be retried + PhaseRetryableFailure + // Indicates that the node is queued because the task is queued + PhaseQueued + // 
Indicates that the node is currently executing and no errors have been observed + PhaseRunning + // PhaseFailing is currently used by SubWorkflow Only. It indicates that the Node's primary work has failed, + but, either some cleanup or exception handling condition is in progress + PhaseFailing + // This is a terminal Status and indicates that the node execution resulted in a Failure + PhaseFailed + // This is a pre-terminal state, currently unused and indicates that the Node execution has succeeded barring any cleanup + PhaseSucceeding + // This is a terminal state and indicates successful completion of the node execution. + PhaseSuccess + // This Phase indicates that the node execution can be skipped, because of either conditional failures or user defined cases + PhaseSkipped + // This phase indicates that an error occurred and is always accompanied by `error`. The execution for that node is + // in an indeterminate state and should be retried + PhaseUndefined +) + +var PhasesToString = map[Phase]string{ + PhaseNotStarted: "NotStarted", + PhaseQueued: "Queued", + PhaseRunning: "Running", + PhaseFailing: "Failing", + PhaseFailed: "Failed", + PhaseSucceeding: "Succeeding", + PhaseSuccess: "Success", + PhaseSkipped: "Skipped", + PhaseUndefined: "Undefined", + PhaseRetryableFailure: "RetryableFailure", +} + +func (p Phase) String() string { + str, found := PhasesToString[p] + if found { + return str + } + + return "Unknown" +} + +// This encapsulates the status of the node +type Status struct { + Phase Phase + Err error + OccurredAt time.Time +} + +var StatusNotStarted = Status{Phase: PhaseNotStarted} +var StatusQueued = Status{Phase: PhaseQueued} +var StatusRunning = Status{Phase: PhaseRunning} +var StatusSucceeding = Status{Phase: PhaseSucceeding} +var StatusSuccess = Status{Phase: PhaseSuccess} +var StatusUndefined = Status{Phase: PhaseUndefined} +var StatusSkipped = Status{Phase: PhaseSkipped} + +func (s Status) WithOccurredAt(t time.Time) Status { + 
s.OccurredAt = t + return s +} + +func StatusFailed(err error) Status { + return Status{Phase: PhaseFailed, Err: err} +} + +func StatusRetryableFailure(err error) Status { + return Status{Phase: PhaseRetryableFailure, Err: err} +} + +func StatusFailing(err error) Status { + return Status{Phase: PhaseFailing, Err: err} +} + +type OutputResolver interface { + // Extracts a subset of node outputs to literals. + ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, + bindToVar VarName) (values *core.Literal, err error) +} + +type PostNodeSuccessHandler interface { + HandleNodeSuccess(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (Status, error) +} + +// Interface that should be implemented for a node type. +type IFace interface { + //OutputResolver + + // Initialize should be called before invoking any other methods of this handler. Initialize will be called using one thread + // only + Initialize(ctx context.Context) error + + // StartNode is called for a node only if the recorded state indicates that the node was never started previously. + // The implementation should handle idempotency, even if the chance of invoking it more than once for an execution is rare. + StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *Data) (Status, error) + + // For any node that is not in a NEW/READY state in the recording, CheckNodeStatus will be invoked. The implementation should handle + // idempotency and return the current observed state of the node + CheckNodeStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, previousNodeStatus v1alpha1.ExecutableNodeStatus) (Status, error) + + // This is called in case a node failure is observed. 
+ HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (Status, error) + + // Abort is invoked as a way to clean up failing/aborted workflows + AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error +} diff --git a/pkg/controller/nodes/handler/mocks/IFace.go b/pkg/controller/nodes/handler/mocks/IFace.go new file mode 100644 index 000000000..b21e42da5 --- /dev/null +++ b/pkg/controller/nodes/handler/mocks/IFace.go @@ -0,0 +1,105 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// IFace is an autogenerated mock type for the IFace type +type IFace struct { + mock.Mock +} + +// AbortNode provides a mock function with given fields: ctx, w, node +func (_m *IFace) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + ret := _m.Called(ctx, w, node) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode) error); ok { + r0 = rf(ctx, w, node) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CheckNodeStatus provides a mock function with given fields: ctx, w, node, previousNodeStatus +func (_m *IFace) CheckNodeStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, previousNodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + ret := _m.Called(ctx, w, node, previousNodeStatus) + + var r0 handler.Status + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) handler.Status); ok { + r0 = rf(ctx, w, node, previousNodeStatus) + } 
else { + r0 = ret.Get(0).(handler.Status) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) error); ok { + r1 = rf(ctx, w, node, previousNodeStatus) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HandleFailingNode provides a mock function with given fields: ctx, w, node +func (_m *IFace) HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + ret := _m.Called(ctx, w, node) + + var r0 handler.Status + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode) handler.Status); ok { + r0 = rf(ctx, w, node) + } else { + r0 = ret.Get(0).(handler.Status) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode) error); ok { + r1 = rf(ctx, w, node) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Initialize provides a mock function with given fields: ctx +func (_m *IFace) Initialize(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// StartNode provides a mock function with given fields: ctx, w, node, nodeInputs +func (_m *IFace) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *core.LiteralMap) (handler.Status, error) { + ret := _m.Called(ctx, w, node, nodeInputs) + + var r0 handler.Status + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, *core.LiteralMap) handler.Status); ok { + r0 = rf(ctx, w, node, nodeInputs) + } else { + r0 = ret.Get(0).(handler.Status) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, *core.LiteralMap) error); ok { 
+ r1 = rf(ctx, w, node, nodeInputs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/nodes/handler/mocks/OutputResolver.go b/pkg/controller/nodes/handler/mocks/OutputResolver.go new file mode 100644 index 000000000..92b9560cc --- /dev/null +++ b/pkg/controller/nodes/handler/mocks/OutputResolver.go @@ -0,0 +1,37 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// OutputResolver is an autogenerated mock type for the OutputResolver type +type OutputResolver struct { + mock.Mock +} + +// ExtractOutput provides a mock function with given fields: ctx, w, n, bindToVar +func (_m *OutputResolver) ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, bindToVar string) (*core.Literal, error) { + ret := _m.Called(ctx, w, n, bindToVar) + + var r0 *core.Literal + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, string) *core.Literal); ok { + r0 = rf(ctx, w, n, bindToVar) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.Literal) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, string) error); ok { + r1 = rf(ctx, w, n, bindToVar) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/nodes/handler_factory.go b/pkg/controller/nodes/handler_factory.go new file mode 100644 index 000000000..586998897 --- /dev/null +++ b/pkg/controller/nodes/handler_factory.go @@ -0,0 +1,76 @@ +package nodes + +import ( + "context" + "time" + + "github.com/lyft/flytepropeller/pkg/controller/nodes/dynamic" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flyteidl/clients/go/events" + 
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/branch" + "github.com/lyft/flytepropeller/pkg/controller/nodes/end" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/start" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/nodes/task" + "github.com/lyft/flytestdlib/storage" + "github.com/pkg/errors" +) + +//go:generate mockery -name HandlerFactory + +type HandlerFactory interface { + GetHandler(kind v1alpha1.NodeKind) (handler.IFace, error) +} + +type handlerFactory struct { + handlers map[v1alpha1.NodeKind]handler.IFace +} + +func (f handlerFactory) GetHandler(kind v1alpha1.NodeKind) (handler.IFace, error) { + h, ok := f.handlers[kind] + if !ok { + return nil, errors.Errorf("Handler not registered for NodeKind [%v]", kind) + } + return h, nil +} + +func NewHandlerFactory(ctx context.Context, + executor executors.Node, + eventSink events.EventSink, + workflowLauncher launchplan.Executor, + enQWorkflow v1alpha1.EnqueueWorkflow, + revalPeriod time.Duration, + store *storage.DataStore, + catalogClient catalog.Client, + kubeClient executors.Client, + scope promutils.Scope, +) (HandlerFactory, error) { + + f := &handlerFactory{ + handlers: map[v1alpha1.NodeKind]handler.IFace{ + v1alpha1.NodeKindBranch: branch.New(executor, eventSink, scope), + v1alpha1.NodeKindTask: dynamic.New( + task.New(eventSink, store, enQWorkflow, revalPeriod, catalogClient, kubeClient, scope), + executor, + enQWorkflow, + store, + scope), + v1alpha1.NodeKindWorkflow: subworkflow.New(executor, eventSink, workflowLauncher, enQWorkflow, store, scope), + v1alpha1.NodeKindStart: start.New(store), + 
v1alpha1.NodeKindEnd: end.New(store), + }, + } + for _, v := range f.handlers { + if err := v.Initialize(ctx); err != nil { + return nil, err + } + } + return f, nil +} diff --git a/pkg/controller/nodes/mocks/HandlerFactory.go b/pkg/controller/nodes/mocks/HandlerFactory.go new file mode 100644 index 000000000..024d0c352 --- /dev/null +++ b/pkg/controller/nodes/mocks/HandlerFactory.go @@ -0,0 +1,36 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" +import mock "github.com/stretchr/testify/mock" + +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// HandlerFactory is an autogenerated mock type for the HandlerFactory type +type HandlerFactory struct { + mock.Mock +} + +// GetHandler provides a mock function with given fields: kind +func (_m *HandlerFactory) GetHandler(kind v1alpha1.NodeKind) (handler.IFace, error) { + ret := _m.Called(kind) + + var r0 handler.IFace + if rf, ok := ret.Get(0).(func(v1alpha1.NodeKind) handler.IFace); ok { + r0 = rf(kind) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(handler.IFace) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(v1alpha1.NodeKind) error); ok { + r1 = rf(kind) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/nodes/predicate.go b/pkg/controller/nodes/predicate.go new file mode 100644 index 000000000..cd718d8bc --- /dev/null +++ b/pkg/controller/nodes/predicate.go @@ -0,0 +1,111 @@ +package nodes + +import ( + "context" + "time" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytestdlib/logger" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// Special enum to indicate if the node under consideration is ready to be executed or should be skipped +type PredicatePhase int + +const ( + // Indicates node is not yet ready to be executed + 
PredicatePhaseNotReady PredicatePhase = iota + // Indicates node is ready to be executed - execution should proceed + PredicatePhaseReady + // Indicates that the node execution should be skipped as one of its parents was skipped or the branch was not taken + PredicatePhaseSkip + // Indicates failure during Predicate check + PredicatePhaseUndefined +) + +func CanExecute(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.BaseNode) (PredicatePhase, error) { + nodeID := node.GetID() + if nodeID == v1alpha1.StartNodeID { + logger.Debugf(ctx, "Start Node id is assumed to be ready.") + return PredicatePhaseReady, nil + } + nodeStatus := w.GetNodeExecutionStatus(nodeID) + parentNodeID := nodeStatus.GetParentNodeID() + upstreamNodes, ok := w.GetConnections().UpstreamEdges[nodeID] + if !ok { + return PredicatePhaseUndefined, errors.Errorf(errors.BadSpecificationError, nodeID, "Unable to find upstream nodes for Node") + } + skipped := false + for _, upstreamNodeID := range upstreamNodes { + upstreamNodeStatus := w.GetNodeExecutionStatus(upstreamNodeID) + + if upstreamNodeStatus.IsDirty() { + return PredicatePhaseNotReady, nil + } + + if parentNodeID != nil && *parentNodeID == upstreamNodeID { + upstreamNode, ok := w.GetNode(upstreamNodeID) + if !ok { + return PredicatePhaseUndefined, errors.Errorf(errors.BadSpecificationError, nodeID, "Upstream node [%v] of node [%v] not defined", upstreamNodeID, nodeID) + } + // This only happens if current node is the child node of a branch node + if upstreamNode.GetBranchNode() == nil || upstreamNodeStatus.GetOrCreateBranchStatus().GetPhase() != v1alpha1.BranchNodeSuccess { + logger.Debugf(ctx, "Branch sub node is expected to have parent branch node in succeeded state") + return PredicatePhaseUndefined, errors.Errorf(errors.IllegalStateError, nodeID, "Upstream node [%v] is set as parent, but is not a branch node of [%v] or in illegal state.", upstreamNodeID, nodeID) + } + continue + } + + if upstreamNodeStatus.GetPhase() == 
v1alpha1.NodePhaseSkipped { + skipped = true + } else if upstreamNodeStatus.GetPhase() != v1alpha1.NodePhaseSucceeded { + return PredicatePhaseNotReady, nil + } + } + if skipped { + return PredicatePhaseSkip, nil + } + return PredicatePhaseReady, nil +} + +func GetParentNodeMaxEndTime(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.BaseNode) (t v1.Time, err error) { + zeroTime := v1.NewTime(time.Time{}) + nodeID := node.GetID() + if nodeID == v1alpha1.StartNodeID { + logger.Debugf(ctx, "Start Node id is assumed to be ready.") + return zeroTime, nil + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + parentNodeID := nodeStatus.GetParentNodeID() + upstreamNodes, ok := w.GetConnections().UpstreamEdges[nodeID] + if !ok { + return zeroTime, errors.Errorf(errors.BadSpecificationError, nodeID, "Unable to find upstream nodes for Node") + } + + var latest v1.Time + for _, upstreamNodeID := range upstreamNodes { + upstreamNodeStatus := w.GetNodeExecutionStatus(upstreamNodeID) + if parentNodeID != nil && *parentNodeID == upstreamNodeID { + upstreamNode, ok := w.GetNode(upstreamNodeID) + if !ok { + return zeroTime, errors.Errorf(errors.BadSpecificationError, nodeID, "Upstream node [%v] of node [%v] not defined", upstreamNodeID, nodeID) + } + + // This only happens if current node is the child node of a branch node + if upstreamNode.GetBranchNode() == nil || upstreamNodeStatus.GetOrCreateBranchStatus().GetPhase() != v1alpha1.BranchNodeSuccess { + logger.Debugf(ctx, "Branch sub node is expected to have parent branch node in succeeded state") + return zeroTime, errors.Errorf(errors.IllegalStateError, nodeID, "Upstream node [%v] is set as parent, but is not a branch node of [%v] or in illegal state.", upstreamNodeID, nodeID) + } + + continue + } + + if stoppedAt := upstreamNodeStatus.GetStoppedAt(); stoppedAt != nil && stoppedAt.Unix() > latest.Unix() { + latest = *upstreamNodeStatus.GetStoppedAt() + } + } + + return latest, nil +} diff --git 
a/pkg/controller/nodes/predicate_test.go b/pkg/controller/nodes/predicate_test.go new file mode 100644 index 000000000..ae151010b --- /dev/null +++ b/pkg/controller/nodes/predicate_test.go @@ -0,0 +1,550 @@ +package nodes + +import ( + "context" + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/stretchr/testify/assert" +) + +func TestCanExecute(t *testing.T) { + nodeN0 := "n0" + nodeN1 := "n1" + nodeN2 := "n2" + ctx := context.Background() + connections := &v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + nodeN2: {nodeN0, nodeN1}, + }, + } + + // Table tests are not really helpful here, so we decided against it + + t.Run("startNode", func(t *testing.T) { + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(v1alpha1.StartNodeID) + p, err := CanExecute(ctx, nil, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseReady, p) + }) + + t.Run("noUpstreamConnection", func(t *testing.T) { + // Setup + mockNodeStatus := &mocks.ExecutableNodeStatus{} + // No parent node + mockNodeStatus.On("GetParentNodeID").Return(nil) + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockNodeStatus) + mockWf.On("GetConnections").Return(&v1alpha1.Connections{}) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, PredicatePhaseUndefined, p) + }) + + t.Run("upstreamConnectionsNotReady", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + 
mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) + + t.Run("upstreamConnectionsPartialReady", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) + + t.Run("upstreamConnectionsCompletelyReady", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + 
mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseReady, p) + }) + + t.Run("upstreamConnectionsDirty", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(true) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) + + t.Run("upstreamConnectionsPartialSkipped", 
func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSkipped) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) + + t.Run("upstreamConnectionsOneSkipped", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSkipped) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := 
CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseSkip, p) + }) + + t.Run("upstreamConnectionsAllSkipped", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSkipped) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSkipped) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseSkip, p) + }) + + // Failed should never happen for predicate check. 
Hence we return not ready + t.Run("upstreamConnectionsFailed", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseFailed) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseFailed) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) + + // Branch node tests + + // ParentNode not found? 
+ t.Run("upstreamConnectionsParentNodeNotFound", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(nil, false) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, PredicatePhaseUndefined, p) + }) + + // ParentNode has no branch node + t.Run("upstreamConnectionsParentHasNoBranch", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(nil) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) 
+ mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, PredicatePhaseUndefined, p) + }) + + // ParentNode branch not ready + t.Run("upstreamConnectionsBranchNodeNotReady", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0BranchStatus := &mocks.MutableBranchNodeStatus{} + mockN0BranchStatus.On("GetPhase").Return(v1alpha1.BranchNodeNotYetEvaluated) + mockN0BranchNode := &mocks.ExecutableBranchNode{} + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, PredicatePhaseUndefined, p) + }) + + // ParentNode branch is errored + 
t.Run("upstreamConnectionsBranchNodeError", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0BranchStatus := &mocks.MutableBranchNodeStatus{} + mockN0BranchStatus.On("GetPhase").Return(v1alpha1.BranchNodeError) + mockN0BranchNode := &mocks.ExecutableBranchNode{} + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, PredicatePhaseUndefined, p) + }) + + // ParentNode branch ready + t.Run("upstreamConnectionsBranchSuccessOtherSuccess", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0BranchStatus := &mocks.MutableBranchNodeStatus{} + mockN0BranchStatus.On("GetPhase").Return(v1alpha1.BranchNodeSuccess) + mockN0BranchNode := 
&mocks.ExecutableBranchNode{} + + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseReady, p) + }) + + // ParentNode branch ready + t.Run("upstreamConnectionsBranchSuccessOtherSkipped", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0BranchStatus := &mocks.MutableBranchNodeStatus{} + mockN0BranchStatus.On("GetPhase").Return(v1alpha1.BranchNodeSuccess) + + mockN0BranchNode := &mocks.ExecutableBranchNode{} + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + 
mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSkipped) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseSkip, p) + }) + + // ParentNode branch ready + t.Run("upstreamConnectionsBranchSuccessOtherRunning", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0BranchStatus := &mocks.MutableBranchNodeStatus{} + mockN0BranchStatus.On("GetPhase").Return(v1alpha1.BranchNodeSuccess) + + mockN0BranchNode := &mocks.ExecutableBranchNode{} + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + 
mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) +} diff --git a/pkg/controller/nodes/resolve.go b/pkg/controller/nodes/resolve.go new file mode 100644 index 000000000..c0a247267 --- /dev/null +++ b/pkg/controller/nodes/resolve.go @@ -0,0 +1,104 @@ +package nodes + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/controller/nodes/common" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/storage" +) + +func ResolveBindingData(ctx context.Context, h HandlerFactory, w v1alpha1.ExecutableWorkflow, bindingData *core.BindingData, store storage.ProtobufStore) (*core.Literal, error) { + literal := &core.Literal{} + if bindingData == nil { + return nil, nil + } + switch bindingData.GetValue().(type) { + case *core.BindingData_Collection: + literalCollection := make([]*core.Literal, 0, len(bindingData.GetCollection().GetBindings())) + for _, b := range bindingData.GetCollection().GetBindings() { + l, err := ResolveBindingData(ctx, h, w, b, store) + if err != nil { + return nil, err + } + + literalCollection = append(literalCollection, l) + } + literal.Value = &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: literalCollection, + }, + } + case *core.BindingData_Map: + literalMap := make(map[string]*core.Literal, len(bindingData.GetMap().GetBindings())) + for k, v := range bindingData.GetMap().GetBindings() { + l, err := ResolveBindingData(ctx, h, w, v, store) + if err != nil { + return nil, err + } + + literalMap[k] = l + } + literal.Value = &core.Literal_Map{ + Map: &core.LiteralMap{ + Literals: literalMap, + }, + } + case *core.BindingData_Promise: + upstreamNodeID := bindingData.GetPromise().GetNodeId() 
+ bindToVar := bindingData.GetPromise().GetVar() + if w == nil { + return nil, errors.Errorf(errors.IllegalStateError, upstreamNodeID, + "Trying to resolve output from previous node, without providing the workflow for variable [%s]", + bindToVar) + } + if upstreamNodeID == "" { + return nil, errors.Errorf(errors.BadSpecificationError, "missing", + "No nodeId (missing) specified for binding in Workflow.") + } + n, ok := w.GetNode(upstreamNodeID) + if !ok { + return nil, errors.Errorf(errors.IllegalStateError, w.GetID(), upstreamNodeID, + "Undefined node in Workflow") + } + + nodeHandler, err := h.GetHandler(n.GetKind()) + if err != nil { + return nil, errors.Wrapf(errors.CausedByError, n.GetID(), err, "Failed to find handler for node kind [%v]", n.GetKind()) + } + + resolver, casted := nodeHandler.(handler.OutputResolver) + if !casted { + // If the handler doesn't implement output resolver, use simple resolver which expects an outputs.pb at the + // output location of the task. + if store == nil { + return nil, errors.Errorf(errors.IllegalStateError, w.GetID(), n.GetID(), "System error. 
Promise lookup without store.") + } + + resolver = common.NewSimpleOutputsResolver(store) + } + + return resolver.ExtractOutput(ctx, w, n, bindToVar) + case *core.BindingData_Scalar: + literal.Value = &core.Literal_Scalar{Scalar: bindingData.GetScalar()} + } + return literal, nil +} + +func Resolve(ctx context.Context, h HandlerFactory, w v1alpha1.ExecutableWorkflow, nodeID v1alpha1.NodeID, bindings []*v1alpha1.Binding, store storage.ProtobufStore) (*handler.Data, error) { + literalMap := make(map[string]*core.Literal, len(bindings)) + for _, binding := range bindings { + l, err := ResolveBindingData(ctx, h, w, binding.GetBinding(), store) + if err != nil { + return nil, errors.Wrapf(errors.BindingResolutionError, nodeID, err, "Error binding Var [%v].[%v]", w.GetID(), binding.GetVar()) + } + literalMap[binding.GetVar()] = l + } + return &core.LiteralMap{ + Literals: literalMap, + }, nil +} diff --git a/pkg/controller/nodes/resolve_test.go b/pkg/controller/nodes/resolve_test.go new file mode 100644 index 000000000..3e5c98032 --- /dev/null +++ b/pkg/controller/nodes/resolve_test.go @@ -0,0 +1,434 @@ +package nodes + +import ( + "context" + "fmt" + "testing" + + mocks2 "github.com/lyft/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/mocks" + "github.com/stretchr/testify/mock" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/utils" + flyteassert "github.com/lyft/flytepropeller/pkg/utils/assert" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" +) + +var testScope = promutils.NewScope("test") + +type dummyBaseWorkflow struct { + DummyStartNode v1alpha1.ExecutableNode + ID v1alpha1.WorkflowID + FromNodeCb func(name v1alpha1.NodeID) ([]v1alpha1.NodeID, 
error) + GetNodeCb func(nodeId v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) + Status map[v1alpha1.NodeID]*v1alpha1.NodeStatus +} + +func (d *dummyBaseWorkflow) GetOutputBindings() []*v1alpha1.Binding { + return []*v1alpha1.Binding{} +} + +func (d *dummyBaseWorkflow) GetOnFailureNode() v1alpha1.ExecutableNode { + return nil +} + +func (d *dummyBaseWorkflow) GetNodes() []v1alpha1.NodeID { + return []v1alpha1.NodeID{d.DummyStartNode.GetID()} +} + +func (d *dummyBaseWorkflow) GetConnections() *v1alpha1.Connections { + return &v1alpha1.Connections{} +} + +func (d *dummyBaseWorkflow) GetOutputs() *v1alpha1.OutputVarMap { + return &v1alpha1.OutputVarMap{} +} + +func (d *dummyBaseWorkflow) GetExecutionID() v1alpha1.ExecutionID { + return v1alpha1.ExecutionID{ + WorkflowExecutionIdentifier: &core.WorkflowExecutionIdentifier{ + Name: "test", + }, + } +} + +func (d *dummyBaseWorkflow) GetK8sWorkflowID() types.NamespacedName { + return types.NamespacedName{ + Name: "WF_Name", + } +} + +func (d *dummyBaseWorkflow) NewControllerRef() v1.OwnerReference { + return v1.OwnerReference{} +} + +func (d *dummyBaseWorkflow) GetNamespace() string { + return d.GetK8sWorkflowID().Namespace +} + +func (d *dummyBaseWorkflow) GetCreationTimestamp() v1.Time { + return v1.Now() +} + +func (d *dummyBaseWorkflow) GetAnnotations() map[string]string { + return map[string]string{} +} + +func (d *dummyBaseWorkflow) GetLabels() map[string]string { + return map[string]string{} +} + +func (d *dummyBaseWorkflow) GetName() string { + return d.ID +} + +func (d *dummyBaseWorkflow) GetServiceAccountName() string { + return "" +} + +func (d *dummyBaseWorkflow) GetTask(id v1alpha1.TaskID) (v1alpha1.ExecutableTask, error) { + return nil, nil +} + +func (d *dummyBaseWorkflow) FindSubWorkflow(subID v1alpha1.WorkflowID) v1alpha1.ExecutableSubWorkflow { + return nil +} + +func (d *dummyBaseWorkflow) GetExecutionStatus() v1alpha1.ExecutableWorkflowStatus { + return nil +} + +func (d *dummyBaseWorkflow) 
GetNodeExecutionStatus(id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus { + n, ok := d.Status[id] + if ok { + return n + } + n = &v1alpha1.NodeStatus{} + d.Status[id] = n + return n +} + +func (d *dummyBaseWorkflow) StartNode() v1alpha1.ExecutableNode { + return d.DummyStartNode +} + +func (d *dummyBaseWorkflow) GetID() v1alpha1.WorkflowID { + return d.ID +} + +func (d *dummyBaseWorkflow) FromNode(name v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + return d.FromNodeCb(name) +} + +func (d *dummyBaseWorkflow) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + return d.GetNodeCb(nodeID) +} + +func createDummyBaseWorkflow() *dummyBaseWorkflow { + return &dummyBaseWorkflow{ + ID: "w1", + Status: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + v1alpha1.StartNodeID: {}, + }, + } +} + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func createFailingDatastore(_ testing.TB, scope promutils.Scope) *storage.DataStore { + return storage.NewCompositeDataStore(storage.URLPathConstructor{}, storage.NewDefaultProtobufStore(utils.FailingRawStore{}, scope)) +} + +func TestResolveBindingData(t *testing.T) { + ctx := context.Background() + outputRef := v1alpha1.DataReference("output-ref") + n1 := &v1alpha1.NodeSpec{ + ID: "n1", + OutputAliases: []v1alpha1.Alias{ + {Alias: core.Alias{ + Var: "x", + Alias: "m", + }}, + }, + } + + n2 := &v1alpha1.NodeSpec{ + ID: "n2", + OutputAliases: []v1alpha1.Alias{ + {Alias: core.Alias{ + Var: "x", + Alias: "m", + }}, + }, + } + + outputPath := v1alpha1.GetOutputsFile(outputRef) + + w := &dummyBaseWorkflow{ + Status: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + "n2": { + DataDir: outputRef, + }, + }, + GetNodeCb: func(nodeId v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + switch nodeId { + case "n1": + return n1, true + case "n2": + return n2, true + 
} + return nil, false + }, + } + + hf := &mocks.HandlerFactory{} + h := &mocks2.IFace{} + hf.On("GetHandler", mock.Anything).Return(h, nil) + h.On("ExtractOutput", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil) + + t.Run("StaticBinding", func(t *testing.T) { + w := &dummyBaseWorkflow{} + b := utils.MustMakePrimitiveBindingData(1) + l, err := ResolveBindingData(ctx, hf, w, b, nil) + assert.NoError(t, err) + flyteassert.EqualLiterals(t, utils.MustMakeLiteral(1), l) + }) + + t.Run("PromiseMissingNode", func(t *testing.T) { + w := &dummyBaseWorkflow{ + GetNodeCb: func(nodeId v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + return nil, false + }, + } + b := utils.MakeBindingDataPromise("n1", "x") + _, err := ResolveBindingData(ctx, nil, w, b, nil) + assert.Error(t, err) + }) + + t.Run("PromiseMissingStore", func(t *testing.T) { + b := utils.MakeBindingDataPromise("n1", "x") + _, err := ResolveBindingData(ctx, hf, w, b, nil) + assert.Error(t, err) + }) + + t.Run("PromiseMissing", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("1")) + b := utils.MakeBindingDataPromise("n1", "x") + _, err := ResolveBindingData(ctx, hf, w, b, store) + assert.Error(t, err) + }) + + t.Run("PromiseMissingWithData", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("2")) + m, err := utils.MakeLiteralMap(map[string]interface{}{"z": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + b := utils.MakeBindingDataPromise("n1", "x") + _, err = ResolveBindingData(ctx, hf, w, b, store) + assert.Error(t, err) + }) + + t.Run("PromiseFound", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("3")) + m, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + + b := utils.MakeBindingDataPromise("n2", "x") + l, err 
:= ResolveBindingData(ctx, hf, w, b, store) + if assert.NoError(t, err) { + flyteassert.EqualLiterals(t, utils.MustMakeLiteral(1), l) + } + }) + + t.Run("NullBinding", func(t *testing.T) { + l, err := ResolveBindingData(ctx, hf, w, nil, nil) + assert.NoError(t, err) + assert.Nil(t, l) + }) + + t.Run("NullWorkflowPromise", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("4")) + m, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + b := utils.MakeBindingDataPromise("n1", "x") + _, err = ResolveBindingData(ctx, nil, nil, b, store) + assert.Error(t, err) + }) + + t.Run("PromiseFoundAlias", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("5")) + m, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + b := utils.MakeBindingDataPromise("n2", "m") + l, err := ResolveBindingData(ctx, hf, w, b, store) + if assert.NoError(t, err) { + flyteassert.EqualLiterals(t, utils.MustMakeLiteral(1), l) + } + }) + + t.Run("BindingDataMap", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("6")) + // Store output of previous + m, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + m2 := &core.LiteralMap{} + assert.NoError(t, store.ReadProtobuf(ctx, outputPath, m2)) + // Output of current + b := utils.MakeBindingDataMap( + utils.NewPair("x", utils.MakeBindingDataPromise("n2", "x")), + utils.NewPair("z", utils.MustMakePrimitiveBindingData(5)), + ) + l, err := ResolveBindingData(ctx, hf, w, b, store) + if assert.NoError(t, err) { + expected, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1, "z": 5}) + assert.NoError(t, err) + 
flyteassert.EqualLiteralMap(t, expected, l.GetMap()) + } + + }) + + t.Run("BindingDataMapFailedPromise", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("7")) + // do not store anything + + // Output of current + b := utils.MakeBindingDataMap( + utils.NewPair("x", utils.MakeBindingDataPromise("n1", "x")), + utils.NewPair("z", utils.MustMakePrimitiveBindingData(5)), + ) + _, err := ResolveBindingData(ctx, hf, w, b, store) + assert.Error(t, err) + }) + + t.Run("BindingDataCollection", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("8")) + // Store random value + m, err := utils.MakeLiteralMap(map[string]interface{}{"jj": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + + // binding of current npde + b := utils.MakeBindingDataCollection( + utils.MakeBindingDataPromise("n1", "x"), + utils.MustMakePrimitiveBindingData(5), + ) + _, err = ResolveBindingData(ctx, hf, w, b, store) + assert.Error(t, err) + + }) +} + +func TestResolve(t *testing.T) { + ctx := context.Background() + outputRef := v1alpha1.DataReference("output-ref") + n1 := &v1alpha1.NodeSpec{ + ID: "n1", + OutputAliases: []v1alpha1.Alias{ + {Alias: core.Alias{ + Var: "x", + Alias: "m", + }}, + }, + } + + outputPath := v1alpha1.GetOutputsFile(outputRef) + + w := &dummyBaseWorkflow{ + Status: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + "n1": { + DataDir: outputRef, + }, + }, + GetNodeCb: func(nodeId v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + if nodeId == "n1" { + return n1, true + } + return nil, false + }, + } + + t.Run("SimpleResolve", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("9")) + // Store output of previous + m, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + + //bindings + b := []*v1alpha1.Binding{ + { + Binding: 
utils.MakeBinding("map", utils.MakeBindingDataMap( + utils.NewPair("x", utils.MakeBindingDataPromise("n1", "x")), + utils.NewPair("z", utils.MustMakePrimitiveBindingData(5)), + )), + }, + { + Binding: utils.MakeBinding("simple", utils.MustMakePrimitiveBindingData(1)), + }, + } + + hf := &mocks.HandlerFactory{} + h := &mocks2.IFace{} + hf.On("GetHandler", mock.Anything).Return(h, nil) + expected, err := utils.MakeLiteralMap(map[string]interface{}{ + "map": map[string]interface{}{"x": 1, "z": 5}, + "simple": utils.MustMakePrimitiveLiteral(1), + }) + assert.NoError(t, err) + + h.On("ExtractOutput", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expected, nil) + + l, err := Resolve(ctx, hf, w, "n2", b, store) + if assert.NoError(t, err) { + assert.NotNil(t, l) + if assert.NoError(t, err) { + flyteassert.EqualLiteralMap(t, expected, l) + } + } + }) + + t.Run("SimpleResolveFail", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("10")) + // Store has no previous output + + //bindings + b := []*v1alpha1.Binding{ + { + Binding: utils.MakeBinding("map", utils.MakeBindingDataMap( + utils.NewPair("x", utils.MakeBindingDataPromise("n1", "x")), + utils.NewPair("z", utils.MustMakePrimitiveBindingData(5)), + )), + }, + { + Binding: utils.MakeBinding("simple", utils.MustMakePrimitiveBindingData(1)), + }, + } + + hf := &mocks.HandlerFactory{} + h := &mocks2.IFace{} + hf.On("GetHandler", mock.Anything).Return(h, nil) + h.On("ExtractOutput", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("No outputs")) + + _, err := Resolve(ctx, hf, w, "n2", b, store) + assert.Error(t, err) + }) + +} diff --git a/pkg/controller/nodes/start/handler.go b/pkg/controller/nodes/start/handler.go new file mode 100644 index 000000000..61034cf4c --- /dev/null +++ b/pkg/controller/nodes/start/handler.go @@ -0,0 +1,40 @@ +package start + +import ( + "context" + + 
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/storage" +) + +type startHandler struct { + store *storage.DataStore +} + +func (s startHandler) Initialize(ctx context.Context) error { + return nil +} + +func (s *startHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + return handler.StatusSuccess, nil +} + +func (s *startHandler) CheckNodeStatus(ctx context.Context, g v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + return handler.StatusSuccess, nil +} + +func (s *startHandler) HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, node.GetID(), "start node cannot enter a failing state")), nil +} + +func (s *startHandler) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + return nil +} + +func New(store *storage.DataStore) handler.IFace { + return &startHandler{ + store: store, + } +} diff --git a/pkg/controller/nodes/start/handler_test.go b/pkg/controller/nodes/start/handler_test.go new file mode 100644 index 000000000..18004308c --- /dev/null +++ b/pkg/controller/nodes/start/handler_test.go @@ -0,0 +1,77 @@ +package start + +import ( + "context" + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + 
"github.com/stretchr/testify/assert" +) + +var testScope = promutils.NewScope("start_test") + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func init() { + labeled.SetMetricKeys(contextutils.NodeIDKey) +} + +func TestStartNodeHandler_Initialize(t *testing.T) { + h := startHandler{} + // Do nothing + assert.NoError(t, h.Initialize(context.TODO())) +} + +func TestStartNodeHandler_StartNode(t *testing.T) { + ctx := context.Background() + mockStorage := createInmemoryDataStore(t, testScope.NewSubScope("z")) + h := New(mockStorage) + node := &v1alpha1.NodeSpec{ + ID: v1alpha1.EndNodeID, + } + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: v1alpha1.WorkflowID("w1"), + }, + } + t.Run("NoInputs", func(t *testing.T) { + s, err := h.StartNode(ctx, w, node, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) + t.Run("WithInputs", func(t *testing.T) { + node := &v1alpha1.NodeSpec{ + ID: v1alpha1.NodeID("n1"), + InputBindings: []*v1alpha1.Binding{ + { + Binding: utils.MakeBinding("x", utils.MustMakePrimitiveBindingData("hello")), + }, + }, + } + s, err := h.StartNode(ctx, w, node, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) +} + +func TestStartNodeHandler_HandleNode(t *testing.T) { + ctx := context.Background() + h := startHandler{} + s, err := h.CheckNodeStatus(ctx, nil, nil, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) +} diff --git a/pkg/controller/nodes/subworkflow/handler.go b/pkg/controller/nodes/subworkflow/handler.go new file mode 100644 index 000000000..8be387941 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/handler.go @@ -0,0 +1,78 @@ +package subworkflow + +import ( + "context" + + "github.com/lyft/flyteidl/clients/go/events" + 
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" +) + +type workflowNodeHandler struct { + recorder events.WorkflowEventRecorder + lpHandler launchPlanHandler + subWfHandler subworkflowHandler +} + +func (w *workflowNodeHandler) Initialize(ctx context.Context) error { + return nil +} + +func (w *workflowNodeHandler) StartNode(ctx context.Context, wf v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + if node.GetWorkflowNode().GetSubWorkflowRef() != nil { + return w.subWfHandler.StartSubWorkflow(ctx, wf, node, nodeInputs) + } + + if node.GetWorkflowNode().GetLaunchPlanRefID() != nil { + return w.lpHandler.StartLaunchPlan(ctx, wf, node, nodeInputs) + } + + return handler.StatusFailed(errors.Errorf(errors.BadSpecificationError, node.GetID(), "SubWorkflow is incorrectly specified.")), nil +} + +func (w *workflowNodeHandler) CheckNodeStatus(ctx context.Context, wf v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, status v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + if node.GetWorkflowNode().GetSubWorkflowRef() != nil { + return w.subWfHandler.CheckSubWorkflowStatus(ctx, wf, node, status) + } + + if node.GetWorkflowNode().GetLaunchPlanRefID() != nil { + return w.lpHandler.CheckLaunchPlanStatus(ctx, wf, node, status) + } + + return handler.StatusFailed(errors.Errorf(errors.BadSpecificationError, node.GetID(), "workflow node does not have a subworkflow or child workflow reference")), nil +} + +func (w *workflowNodeHandler) HandleFailingNode(ctx context.Context, wf v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) 
(handler.Status, error) { + if node.GetWorkflowNode() != nil && node.GetWorkflowNode().GetSubWorkflowRef() != nil { + return w.subWfHandler.HandleSubWorkflowFailingNode(ctx, wf, node) + } + return handler.StatusFailed(nil), nil +} + +func (w *workflowNodeHandler) AbortNode(ctx context.Context, wf v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + if node.GetWorkflowNode().GetSubWorkflowRef() != nil { + return w.subWfHandler.HandleAbort(ctx, wf, node) + } + + if node.GetWorkflowNode().GetLaunchPlanRefID() != nil { + return w.lpHandler.HandleAbort(ctx, wf, node) + } + return nil +} + +func New(executor executors.Node, eventSink events.EventSink, workflowLauncher launchplan.Executor, enQWorkflow v1alpha1.EnqueueWorkflow, store *storage.DataStore, scope promutils.Scope) handler.IFace { + subworkflowScope := scope.NewSubScope("workflow") + return &workflowNodeHandler{ + subWfHandler: newSubworkflowHandler(executor, enQWorkflow, store), + lpHandler: launchPlanHandler{ + store: store, + launchPlan: workflowLauncher, + }, + recorder: events.NewWorkflowEventRecorder(eventSink, subworkflowScope), + } +} diff --git a/pkg/controller/nodes/subworkflow/handler_test.go b/pkg/controller/nodes/subworkflow/handler_test.go new file mode 100644 index 000000000..124079e5d --- /dev/null +++ b/pkg/controller/nodes/subworkflow/handler_test.go @@ -0,0 +1,230 @@ +package subworkflow + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mocks2 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" + "github.com/lyft/flytestdlib/contextutils" + 
"github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) { + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + mockWfNode.On("GetSubWorkflowRef").Return(nil) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return("n1") + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + wfStatus := &mocks2.MutableWorkflowNodeStatus{} + mockNodeStatus.On("GetOrCreateWorkflowStatus").Return(wfStatus) + wfStatus.On("SetWorkflowExecutionName", + mock.MatchedBy(func(name string) bool { + return name == "x-n1-1" + }), + ).Return() + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + ni := &core.LiteralMap{} + + t.Run("happy", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := New(nil, nil, mockLPExec, nil, nil, promutils.NewTestScope()) + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == parentID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain 
+ }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return ni == o }), + ).Return(nil) + + s, err := h.StartNode(ctx, mockWf, mockNode, ni) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseRunning) + }) +} + +func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + dataDir := storage.DataReference("data") + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + mockWfNode.On("GetSubWorkflowRef").Return(nil) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return("n1") + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + mockNodeStatus.On("GetDataDir").Return(dataDir) + + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + t.Run("stillRunning", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := New(nil, nil, mockLPExec, nil, nil, promutils.NewTestScope()) + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_RUNNING, + }, nil) + + s, err := h.CheckNodeStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseRunning) + }) +} + +func 
TestWorkflowNodeHandler_AbortNode(t *testing.T) { + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + dataDir := storage.DataReference("data") + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + mockWfNode.On("GetSubWorkflowRef").Return(nil) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return("n1") + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + mockNodeStatus.On("GetDataDir").Return(dataDir) + + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetName").Return("test") + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + t.Run("abort", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := New(nil, nil, mockLPExec, nil, nil, promutils.NewTestScope()) + mockLPExec.On("Kill", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.AnythingOfType(reflect.String.String()), + ).Return(nil) + + err := h.AbortNode(ctx, mockWf, mockNode) + assert.NoError(t, err) + }) + + t.Run("abort-fail", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + expectedErr := fmt.Errorf("fail") + h := New(nil, nil, mockLPExec, nil, nil, promutils.NewTestScope()) + mockLPExec.On("Kill", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + 
mock.AnythingOfType(reflect.String.String()), + ).Return(expectedErr) + + err := h.AbortNode(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, err, expectedErr) + }) +} + +func init() { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} diff --git a/pkg/controller/nodes/subworkflow/launchplan.go b/pkg/controller/nodes/subworkflow/launchplan.go new file mode 100644 index 000000000..03c0b71d9 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan.go @@ -0,0 +1,139 @@ +package subworkflow + +import ( + "context" + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" +) + +type launchPlanHandler struct { + launchPlan launchplan.Executor + store *storage.DataStore +} + +func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + childID, err := GetChildWorkflowExecutionID( + w.GetExecutionID().WorkflowExecutionIdentifier, + node.GetID(), + nodeStatus.GetAttempts(), + ) + + if err != nil { + return handler.StatusFailed(errors.Wrapf(errors.RuntimeExecutionError, node.GetID(), err, "failed to create unique ID")), nil + } + + launchCtx := launchplan.LaunchContext{ + // TODO we need to add principal and nestinglevel as annotations or labels? 
+ Principal: "unknown", + NestingLevel: 0, + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: node.GetID(), + ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier, + }, + } + err = l.launchPlan.Launch(ctx, launchCtx, childID, node.GetWorkflowNode().GetLaunchPlanRefID().Identifier, nodeInputs) + if err != nil { + if launchplan.IsAlreadyExists(err) { + logger.Info(ctx, "Execution already exists [%s].", childID.Name) + } else if launchplan.IsUserError(err) { + return handler.StatusFailed(err), nil + } else { + return handler.StatusUndefined, err + } + } else { + logger.Infof(ctx, "Launched launchplan with ID [%s]", childID.Name) + } + + nodeStatus.GetOrCreateWorkflowStatus().SetWorkflowExecutionName(childID.Name) + return handler.StatusRunning, nil +} + +func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, status v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + // Handle launch plan + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + childID, err := GetChildWorkflowExecutionID( + w.GetExecutionID().WorkflowExecutionIdentifier, + node.GetID(), + nodeStatus.GetAttempts(), + ) + + if err != nil { + // THIS SHOULD NEVER HAPPEN + return handler.StatusFailed(errors.Wrapf(errors.RuntimeExecutionError, node.GetID(), err, "failed to create unique ID")), nil + } + + wfStatusClosure, err := l.launchPlan.GetStatus(ctx, childID) + if err != nil { + if launchplan.IsNotFound(err) { //NotFound + return handler.StatusFailed(err), nil + } + + return handler.StatusUndefined, err + } + + if wfStatusClosure == nil { + logger.Info(ctx, "Retrieved Launch Plan status is nil. 
This might indicate pressure on the admin cache."+ + " Consider tweaking its size to allow for more concurrent executions to be cached.") + return handler.StatusRunning, nil + } + + var wErr error + switch wfStatusClosure.GetPhase() { + case core.WorkflowExecution_ABORTED: + wErr = fmt.Errorf("launchplan execution aborted") + return handler.StatusFailed(errors.Wrapf(errors.RemoteChildWorkflowExecutionFailed, node.GetID(), wErr, "launchplan [%s] failed", childID.Name)), nil + case core.WorkflowExecution_FAILED: + wErr = fmt.Errorf("launchplan execution failed without explicit error") + if wfStatusClosure.GetError() != nil { + wErr = fmt.Errorf(" errorCode[%s]: %s", wfStatusClosure.GetError().Code, wfStatusClosure.GetError().Message) + } + return handler.StatusFailed(errors.Wrapf(errors.RemoteChildWorkflowExecutionFailed, node.GetID(), wErr, "launchplan [%s] failed", childID.Name)), nil + case core.WorkflowExecution_SUCCEEDED: + if wfStatusClosure.GetOutputs() != nil { + outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetDataDir()) + childOutput := &core.LiteralMap{} + uri := wfStatusClosure.GetOutputs().GetUri() + if uri != "" { + // Copy remote data to local S3 path + if err := l.store.ReadProtobuf(ctx, storage.DataReference(uri), childOutput); err != nil { + if storage.IsNotFound(err) { + return handler.StatusFailed(errors.Wrapf(errors.RemoteChildWorkflowExecutionFailed, node.GetID(), err, "remote output for launchplan execution was not found, uri [%s]", uri)), nil + } + return handler.StatusUndefined, errors.Wrapf(errors.RuntimeExecutionError, node.GetID(), err, "failed to read outputs from child workflow @ [%s]", uri) + } + + } else if wfStatusClosure.GetOutputs().GetValues() != nil { + // Store data to S3Path + childOutput = wfStatusClosure.GetOutputs().GetValues() + } + if err := l.store.WriteProtobuf(ctx, outputFile, storage.Options{}, childOutput); err != nil { + logger.Debugf(ctx, "failed to write data to Storage, err: %v", err.Error()) + return 
handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to copy outputs for child workflow") + } + } + return handler.StatusSuccess, nil + } + return handler.StatusRunning, nil +} + +func (l *launchPlanHandler) HandleAbort(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + childID, err := GetChildWorkflowExecutionID( + w.GetExecutionID().WorkflowExecutionIdentifier, + node.GetID(), + nodeStatus.GetAttempts(), + ) + if err != nil { + // THIS SHOULD NEVER HAPPEN + return err + } + return l.launchPlan.Kill(ctx, childID, fmt.Sprintf("parent execution id [%s] aborted", w.GetName())) +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/admin.go b/pkg/controller/nodes/subworkflow/launchplan/admin.go new file mode 100644 index 000000000..00b52712f --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/admin.go @@ -0,0 +1,167 @@ +package launchplan + +import ( + "context" + "fmt" + "runtime/pprof" + "time" + + "github.com/lyft/flytestdlib/logger" + + "github.com/lyft/flytestdlib/contextutils" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flytestdlib/utils" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/service" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Executor for Launchplans that executes on a remote FlyteAdmin service (if configured) +type adminLaunchPlanExecutor struct { + adminClient service.AdminServiceClient + cache utils.AutoRefreshCache +} + +type executionCacheItem struct { + core.WorkflowExecutionIdentifier + ExecutionClosure *admin.ExecutionClosure + SyncError error +} + +func (e executionCacheItem) ID() string { + return e.String() +} + +func (a *adminLaunchPlanExecutor) Launch(ctx context.Context, launchCtx LaunchContext, executionID 
*core.WorkflowExecutionIdentifier, launchPlanRef *core.Identifier, inputs *core.LiteralMap) error { + req := &admin.ExecutionCreateRequest{ + Project: executionID.Project, + Domain: executionID.Domain, + Name: executionID.Name, + Spec: &admin.ExecutionSpec{ + LaunchPlan: launchPlanRef, + Metadata: &admin.ExecutionMetadata{ + Mode: admin.ExecutionMetadata_SYSTEM, + Nesting: launchCtx.NestingLevel + 1, + Principal: launchCtx.Principal, + ParentNodeExecution: launchCtx.ParentNodeExecution, + }, + Inputs: inputs, + }, + } + _, err := a.adminClient.CreateExecution(ctx, req) + if err != nil { + statusCode := status.Code(err) + switch statusCode { + case codes.AlreadyExists: + _, err := a.cache.GetOrCreate(executionCacheItem{WorkflowExecutionIdentifier: *executionID}) + if err != nil { + logger.Errorf(ctx, "Failed to add ExecID [%v] to auto refresh cache", executionID) + } + + return Wrapf(RemoteErrorAlreadyExists, err, "ExecID %s already exists", executionID.Name) + case codes.DataLoss, codes.DeadlineExceeded, codes.Internal, codes.Unknown, codes.Canceled: + return Wrapf(RemoteErrorSystem, err, "failed to launch workflow [%s], system error", launchPlanRef.Name) + default: + return Wrapf(RemoteErrorUser, err, "failed to launch workflow") + } + } + + _, err = a.cache.GetOrCreate(executionCacheItem{WorkflowExecutionIdentifier: *executionID}) + if err != nil { + logger.Info(ctx, "Failed to add ExecID [%v] to auto refresh cache", executionID) + } + + return nil +} + +func (a *adminLaunchPlanExecutor) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, error) { + if executionID == nil { + return nil, fmt.Errorf("nil executionID") + } + + obj, err := a.cache.GetOrCreate(executionCacheItem{WorkflowExecutionIdentifier: *executionID}) + if err != nil { + return nil, err + } + + item := obj.(executionCacheItem) + + return item.ExecutionClosure, item.SyncError +} + +func (a *adminLaunchPlanExecutor) Kill(ctx context.Context, 
executionID *core.WorkflowExecutionIdentifier, reason string) error { + req := &admin.ExecutionTerminateRequest{ + Id: executionID, + Cause: reason, + } + _, err := a.adminClient.TerminateExecution(ctx, req) + if err != nil { + if status.Code(err) == codes.NotFound { + return nil + } + return Wrapf(RemoteErrorSystem, err, "system error") + } + return nil +} + +func (a *adminLaunchPlanExecutor) Initialize(ctx context.Context) error { + go func() { + // Set goroutine-label... + ctx = contextutils.WithGoroutineLabel(ctx, "admin-launcher") + pprof.SetGoroutineLabels(ctx) + a.cache.Start(ctx) + }() + + return nil +} + +func (a *adminLaunchPlanExecutor) syncItem(ctx context.Context, obj utils.CacheItem) ( + newItem utils.CacheItem, result utils.CacheSyncAction, err error) { + exec := obj.(executionCacheItem) + req := &admin.WorkflowExecutionGetRequest{ + Id: &exec.WorkflowExecutionIdentifier, + } + + res, err := a.adminClient.GetExecution(ctx, req) + if err != nil { + // TODO: Define which error codes are system errors (and return the error) vs user errors. 
+ + if status.Code(err) == codes.NotFound { + err = Wrapf(RemoteErrorNotFound, err, "execID [%s] not found on remote", exec.WorkflowExecutionIdentifier.Name) + } else { + err = Wrapf(RemoteErrorSystem, err, "system error") + } + + return executionCacheItem{ + WorkflowExecutionIdentifier: exec.WorkflowExecutionIdentifier, + SyncError: err, + }, utils.Update, nil + } + + return executionCacheItem{ + WorkflowExecutionIdentifier: exec.WorkflowExecutionIdentifier, + ExecutionClosure: res.Closure, + }, utils.Update, nil +} + +func NewAdminLaunchPlanExecutor(_ context.Context, client service.AdminServiceClient, + syncPeriod time.Duration, cfg *AdminConfig, scope promutils.Scope) (Executor, error) { + exec := &adminLaunchPlanExecutor{ + adminClient: client, + } + + // TODO: make tps/burst/size configurable + cache, err := utils.NewAutoRefreshCache(exec.syncItem, utils.NewRateLimiter("adminSync", + float64(cfg.TPS), cfg.Burst), syncPeriod, cfg.MaxCacheSize, scope) + if err != nil { + return nil, err + } + + exec.cache = cache + return exec, nil +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/admin_test.go b/pkg/controller/nodes/subworkflow/launchplan/admin_test.go new file mode 100644 index 000000000..3384cedcb --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/admin_test.go @@ -0,0 +1,277 @@ +package launchplan + +import ( + "context" + "testing" + "time" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flyteidl/clients/go/admin/mocks" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestAdminLaunchPlanExecutor_GetStatus(t *testing.T) { + ctx := context.TODO() + id := &core.WorkflowExecutionIdentifier{ + Name: "n", + Domain: "d", + Project: "p", + } + var result *admin.ExecutionClosure + + t.Run("happy", func(t *testing.T) { 
+ mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Millisecond, defaultAdminConfig, promutils.NewTestScope()) + assert.NoError(t, err) + mockClient.On("GetExecution", + ctx, + mock.MatchedBy(func(o *admin.WorkflowExecutionGetRequest) bool { return true }), + ).Return(result, nil) + assert.NoError(t, err) + s, err := exec.GetStatus(ctx, id) + assert.NoError(t, err) + assert.Equal(t, result, s) + }) + + t.Run("notFound", func(t *testing.T) { + mockClient := &mocks.AdminServiceClient{} + + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { + return o.Project == "p" && o.Domain == "d" && o.Name == "n" && o.Spec.Inputs == nil + }), + ).Return(nil, nil) + + mockClient.On("GetExecution", + mock.Anything, + mock.MatchedBy(func(o *admin.WorkflowExecutionGetRequest) bool { return true }), + ).Return(nil, status.Error(codes.NotFound, "")) + + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Millisecond, defaultAdminConfig, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, exec.Initialize(ctx)) + + err = exec.Launch(ctx, + LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + }, + }, + id, + &core.Identifier{}, + nil, + ) + assert.NoError(t, err) + + // Allow for sync to be called + time.Sleep(time.Second) + + s, err := exec.GetStatus(ctx, id) + assert.Error(t, err) + assert.Nil(t, s) + assert.True(t, IsNotFound(err)) + }) + + t.Run("other", func(t *testing.T) { + mockClient := &mocks.AdminServiceClient{} + + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { + return o.Project == "p" && o.Domain == "d" && o.Name == "n" && o.Spec.Inputs == nil + }), + ).Return(nil, nil) + + mockClient.On("GetExecution", + mock.Anything, + mock.MatchedBy(func(o 
*admin.WorkflowExecutionGetRequest) bool { return true }), + ).Return(nil, status.Error(codes.Canceled, "")) + + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Millisecond, defaultAdminConfig, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, exec.Initialize(ctx)) + + err = exec.Launch(ctx, + LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + }, + }, + id, + &core.Identifier{}, + nil, + ) + assert.NoError(t, err) + + // Allow for sync to be called + time.Sleep(time.Second) + + s, err := exec.GetStatus(ctx, id) + assert.Error(t, err) + assert.Nil(t, s) + assert.False(t, IsNotFound(err)) + }) +} + +func TestAdminLaunchPlanExecutor_Launch(t *testing.T) { + ctx := context.TODO() + id := &core.WorkflowExecutionIdentifier{ + Name: "n", + Domain: "d", + Project: "p", + } + + t.Run("happy", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { + return o.Project == "p" && o.Domain == "d" && o.Name == "n" && o.Spec.Inputs == nil + }), + ).Return(nil, nil) + assert.NoError(t, err) + err = exec.Launch(ctx, + LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + }, + }, + id, + &core.Identifier{}, + nil, + ) + assert.NoError(t, err) + }) + + t.Run("notFound", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) 
bool { return true }), + ).Return(nil, status.Error(codes.AlreadyExists, "")) + assert.NoError(t, err) + err = exec.Launch(ctx, + LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + }, + }, + id, + &core.Identifier{}, + nil, + ) + assert.Error(t, err) + assert.True(t, IsAlreadyExists(err)) + }) + + t.Run("other", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { return true }), + ).Return(nil, status.Error(codes.Canceled, "")) + assert.NoError(t, err) + err = exec.Launch(ctx, + LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + }, + }, + id, + &core.Identifier{}, + nil, + ) + assert.Error(t, err) + assert.False(t, IsAlreadyExists(err)) + }) +} + +func TestAdminLaunchPlanExecutor_Kill(t *testing.T) { + ctx := context.TODO() + id := &core.WorkflowExecutionIdentifier{ + Name: "n", + Domain: "d", + Project: "p", + } + + const reason = "reason" + t.Run("happy", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("TerminateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionTerminateRequest) bool { return o.Id == id && o.Cause == reason }), + ).Return(&admin.ExecutionTerminateResponse{}, nil) + assert.NoError(t, err) + err = exec.Kill(ctx, id, reason) + assert.NoError(t, err) + }) + + t.Run("notFound", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := 
NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("TerminateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionTerminateRequest) bool { return o.Id == id && o.Cause == reason }), + ).Return(nil, status.Error(codes.NotFound, "")) + assert.NoError(t, err) + err = exec.Kill(ctx, id, reason) + assert.NoError(t, err) + }) + + t.Run("other", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("TerminateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionTerminateRequest) bool { return o.Id == id && o.Cause == reason }), + ).Return(nil, status.Error(codes.Canceled, "")) + assert.NoError(t, err) + err = exec.Kill(ctx, id, reason) + assert.Error(t, err) + assert.False(t, IsNotFound(err)) + }) +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/adminconfig.go b/pkg/controller/nodes/subworkflow/launchplan/adminconfig.go new file mode 100644 index 000000000..ed7b8d517 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/adminconfig.go @@ -0,0 +1,33 @@ +package launchplan + +import ( + ctrlConfig "github.com/lyft/flytepropeller/pkg/controller/config" +) + +//go:generate pflags AdminConfig --default-var defaultAdminConfig + +var ( + defaultAdminConfig = &AdminConfig{ + TPS: 5, + Burst: 10, + MaxCacheSize: 10000, + } + + adminConfigSection = ctrlConfig.ConfigSection.MustRegisterSection("admin-launcher", defaultAdminConfig) +) + +type AdminConfig struct { + // TPS indicates the maximum transactions per second to flyte admin from this client. + // If it's zero, the created client will use DefaultTPS: 5 + TPS int64 `json:"tps" pflag:",The maximum number of transactions per second to flyte admin from this client."` + + // Maximum burst for throttle. + // If it's zero, the created client will use DefaultBurst: 10. 
+ Burst int `json:"burst" pflag:",Maximum burst for throttle"` + + MaxCacheSize int `json:"cacheSize" pflag:",Maximum cache in terms of number of items stored."` +} + +func GetAdminConfig() *AdminConfig { + return adminConfigSection.GetConfig().(*AdminConfig) +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go b/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go new file mode 100755 index 000000000..9b85965e7 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go @@ -0,0 +1,48 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package launchplan + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (AdminConfig) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (AdminConfig) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in AdminConfig and its nested types. The format of the +// flags is json-name.json-sub-name... etc. 
+func (cfg AdminConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("AdminConfig", pflag.ExitOnError) + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "tps"), defaultAdminConfig.TPS, "The maximum number of transactions per second to flyte admin from this client.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "burst"), defaultAdminConfig.Burst, "Maximum burst for throttle") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cacheSize"), defaultAdminConfig.MaxCacheSize, "Maximum cache in terms of number of items stored.") + return cmdFlags +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go b/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go new file mode 100755 index 000000000..79de09463 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go @@ -0,0 +1,168 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package launchplan + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsAdminConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementAdminConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsAdminConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. 
+func jsonUnmarshalerHookAdminConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementAdminConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_AdminConfig(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookAdminConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_AdminConfig(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_AdminConfig(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_AdminConfig(val, result)) +} + +func testDecodeSlice_AdminConfig(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_AdminConfig(vStringSlice, result)) +} + +func TestAdminConfig_GetPFlagSet(t *testing.T) { + val := AdminConfig{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestAdminConfig_SetFlags(t 
*testing.T) { + actual := AdminConfig{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_tps", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt64, err := cmdFlags.GetInt64("tps"); err == nil { + assert.Equal(t, int64(defaultAdminConfig.TPS), vInt64) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("tps", testValue) + if vInt64, err := cmdFlags.GetInt64("tps"); err == nil { + testDecodeJson_AdminConfig(t, fmt.Sprintf("%v", vInt64), &actual.TPS) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_burst", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("burst"); err == nil { + assert.Equal(t, int(defaultAdminConfig.Burst), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("burst", testValue) + if vInt, err := cmdFlags.GetInt("burst"); err == nil { + testDecodeJson_AdminConfig(t, fmt.Sprintf("%v", vInt), &actual.Burst) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_cacheSize", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("cacheSize"); err == nil { + assert.Equal(t, int(defaultAdminConfig.MaxCacheSize), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("cacheSize", testValue) + if vInt, err := cmdFlags.GetInt("cacheSize"); err == nil { + testDecodeJson_AdminConfig(t, fmt.Sprintf("%v", vInt), &actual.MaxCacheSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/errors.go 
b/pkg/controller/nodes/subworkflow/launchplan/errors.go new file mode 100644 index 000000000..02b9203fe --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/errors.go @@ -0,0 +1,57 @@ +package launchplan + +import "fmt" + +type ErrorCode string + +const ( + RemoteErrorAlreadyExists ErrorCode = "AlreadyExists" + RemoteErrorNotFound ErrorCode = "NotFound" + RemoteErrorSystem = "SystemError" // timeouts, network error etc + RemoteErrorUser = "UserError" // Incase of bad specification, invalid arguments, etc +) + +type RemoteError struct { + Code ErrorCode + Cause error + Message string +} + +func (r RemoteError) Error() string { + return fmt.Sprintf("%s: %s, caused by [%s]", r.Code, r.Message, r.Cause.Error()) +} + +func Wrapf(code ErrorCode, cause error, msg string, args ...interface{}) error { + return &RemoteError{ + Code: code, + Cause: cause, + Message: fmt.Sprintf(msg, args...), + } +} + +// Checks if the error is of type RemoteError and the ErrorCode is of type RemoteErrorAlreadyExists +func IsAlreadyExists(err error) bool { + e, ok := err.(*RemoteError) + if ok { + return e.Code == RemoteErrorAlreadyExists + } + return false +} + +// Checks if the error is of type RemoteError and the ErrorCode is of type RemoteErrorUser +func IsUserError(err error) bool { + e, ok := err.(*RemoteError) + if ok { + return e.Code == RemoteErrorUser + } + return false +} + +// Checks if the error is of type RemoteError and the ErrorCode is of type RemoteErrorNotFound +func IsNotFound(err error) bool { + e, ok := err.(*RemoteError) + if ok { + return e.Code == RemoteErrorNotFound + } + return false +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/errors_test.go b/pkg/controller/nodes/subworkflow/launchplan/errors_test.go new file mode 100644 index 000000000..4519218fc --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/errors_test.go @@ -0,0 +1,36 @@ +package launchplan + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + 
+func TestRemoteError(t *testing.T) { + t.Run("alreadyExists", func(t *testing.T) { + e := Wrapf(RemoteErrorAlreadyExists, fmt.Errorf("blah"), "error") + assert.Error(t, e) + assert.True(t, IsAlreadyExists(e)) + }) + + t.Run("notfound", func(t *testing.T) { + e := Wrapf(RemoteErrorNotFound, fmt.Errorf("blah"), "error") + assert.Error(t, e) + assert.True(t, IsNotFound(e)) + }) + + t.Run("userError", func(t *testing.T) { + e := Wrapf(RemoteErrorUser, fmt.Errorf("blah"), "error") + assert.Error(t, e) + assert.True(t, IsUserError(e)) + }) + + t.Run("system", func(t *testing.T) { + e := Wrapf(RemoteErrorSystem, fmt.Errorf("blah"), "error") + assert.Error(t, e) + assert.False(t, IsAlreadyExists(e)) + assert.False(t, IsNotFound(e)) + assert.False(t, IsUserError(e)) + }) +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/launchplan.go b/pkg/controller/nodes/subworkflow/launchplan/launchplan.go new file mode 100644 index 000000000..413f4a2b8 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/launchplan.go @@ -0,0 +1,36 @@ +package launchplan + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +//go:generate mockery -name Executor + +// A simple context that is used to start an execution of a LaunchPlan. 
It encapsulates enough parent information +// to tie the executions +type LaunchContext struct { + // Nesting level of the current workflow (parent) + NestingLevel uint32 + // Principal of the current workflow, so that billing can be tied correctly + Principal string + // If a node launched the execution, this specifies which node execution + ParentNodeExecution *core.NodeExecutionIdentifier +} + +// Interface to be implemented by the remote system that can allow workflow launching capabilities +type Executor interface { + // Start an execution of a launchplan + Launch(ctx context.Context, launchCtx LaunchContext, executionID *core.WorkflowExecutionIdentifier, launchPlanRef *core.Identifier, inputs *core.LiteralMap) error + + // Retrieve status of a LaunchPlan execution + GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, error) + + // Kill a remote execution + Kill(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, reason string) error + + // Initializes Executor. + Initialize(ctx context.Context) error +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go b/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go new file mode 100644 index 000000000..83d28be3b --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go @@ -0,0 +1,79 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import admin "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +import context "context" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import launchplan "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" +import mock "github.com/stretchr/testify/mock" + +// Executor is an autogenerated mock type for the Executor type +type Executor struct { + mock.Mock +} + +// GetStatus provides a mock function with given fields: ctx, executionID +func (_m *Executor) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, error) { + ret := _m.Called(ctx, executionID) + + var r0 *admin.ExecutionClosure + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier) *admin.ExecutionClosure); ok { + r0 = rf(ctx, executionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.ExecutionClosure) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier) error); ok { + r1 = rf(ctx, executionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Initialize provides a mock function with given fields: ctx +func (_m *Executor) Initialize(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Kill provides a mock function with given fields: ctx, executionID, reason +func (_m *Executor) Kill(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, reason string) error { + ret := _m.Called(ctx, executionID, reason) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, string) error); ok { + r0 = rf(ctx, executionID, reason) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Launch provides a mock function with given fields: ctx, launchCtx, executionID, launchPlanRef, inputs +func 
(_m *Executor) Launch(ctx context.Context, launchCtx launchplan.LaunchContext, executionID *core.WorkflowExecutionIdentifier, launchPlanRef *core.Identifier, inputs *core.LiteralMap) error { + ret := _m.Called(ctx, launchCtx, executionID, launchPlanRef, inputs) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, launchplan.LaunchContext, *core.WorkflowExecutionIdentifier, *core.Identifier, *core.LiteralMap) error); ok { + r0 = rf(ctx, launchCtx, executionID, launchPlanRef, inputs) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/noop.go b/pkg/controller/nodes/subworkflow/launchplan/noop.go new file mode 100644 index 000000000..f913a8f34 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/noop.go @@ -0,0 +1,37 @@ +package launchplan + +import ( + "context" + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" +) + +type failFastWorkflowLauncher struct { +} + +func (failFastWorkflowLauncher) Launch(ctx context.Context, launchCtx LaunchContext, executionID *core.WorkflowExecutionIdentifier, launchPlanRef *core.Identifier, inputs *core.LiteralMap) error { + logger.Infof(ctx, "Fail: Launch Workflow requested with ExecID [%s], LaunchPlan [%s]", executionID.Name, fmt.Sprintf("%s:%s:%s", launchPlanRef.Project, launchPlanRef.Domain, launchPlanRef.Name)) + return Wrapf(RemoteErrorUser, fmt.Errorf("badly configured system"), "please enable admin workflow launch to use launchplans") +} + +func (failFastWorkflowLauncher) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, error) { + logger.Infof(ctx, "NOOP: Workflow Status ExecID [%s]", executionID.Name) + return nil, Wrapf(RemoteErrorUser, fmt.Errorf("badly configured system"), "please enable admin workflow launch to use launchplans") +} + +func (failFastWorkflowLauncher) Kill(ctx 
context.Context, executionID *core.WorkflowExecutionIdentifier, reason string) error { + return nil +} + +// Initializes Executor. +func (failFastWorkflowLauncher) Initialize(ctx context.Context) error { + return nil +} + +func NewFailFastLaunchPlanExecutor() Executor { + logger.Infof(context.TODO(), "created failFast workflow launcher, will not launch subworkflows.") + return &failFastWorkflowLauncher{} +} diff --git a/pkg/controller/nodes/subworkflow/launchplan/noop_test.go b/pkg/controller/nodes/subworkflow/launchplan/noop_test.go new file mode 100644 index 000000000..f75d54250 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan/noop_test.go @@ -0,0 +1,51 @@ +package launchplan + +import ( + "context" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestFailFastWorkflowLauncher(t *testing.T) { + ctx := context.TODO() + f := NewFailFastLaunchPlanExecutor() + t.Run("getStatus", func(t *testing.T) { + a, err := f.GetStatus(ctx, &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", + }) + assert.Nil(t, a) + assert.Error(t, err) + }) + + t.Run("launch", func(t *testing.T) { + err := f.Launch(ctx, LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", + }, + }, + }, &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", + }, &core.Identifier{}, + nil) + assert.Error(t, err) + }) + + t.Run("kill", func(t *testing.T) { + err := f.Kill(ctx, &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", + }, "reason") + assert.NoError(t, err) + }) +} diff --git a/pkg/controller/nodes/subworkflow/launchplan_test.go b/pkg/controller/nodes/subworkflow/launchplan_test.go new file mode 100644 index 000000000..400203e68 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/launchplan_test.go @@ -0,0 +1,640 @@ 
+package subworkflow + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mocks2 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func createInmemoryStore(t testing.TB) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + + d, err := storage.NewDataStore(&cfg, promutils.NewTestScope()) + assert.NoError(t, err) + + return d +} + +func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return("n1") + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + ni := &core.LiteralMap{} 
+ + t.Run("happy", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == parentID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return ni == o }), + ).Return(nil) + + wfStatus := &mocks2.MutableWorkflowNodeStatus{} + mockNodeStatus.On("GetOrCreateWorkflowStatus").Return(wfStatus) + wfStatus.On("SetWorkflowExecutionName", + mock.MatchedBy(func(name string) bool { + return name == "x-n1-1" + }), + ).Return() + + s, err := h.StartLaunchPlan(ctx, mockWf, mockNode, ni) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseRunning) + }) + + t.Run("alreadyExists", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == parentID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return ni == o }), + ).Return(launchplan.Wrapf(launchplan.RemoteErrorAlreadyExists, fmt.Errorf("blah"), "failed")) + + s, err := h.StartLaunchPlan(ctx, mockWf, mockNode, ni) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseRunning) + }) + + t.Run("systemError", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + 
launchPlan: mockLPExec, + } + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == parentID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return ni == o }), + ).Return(launchplan.Wrapf(launchplan.RemoteErrorSystem, fmt.Errorf("blah"), "failed")) + + s, err := h.StartLaunchPlan(ctx, mockWf, mockNode, ni) + assert.Error(t, err) + assert.Equal(t, s.Phase, handler.PhaseUndefined) + }) + + t.Run("userError", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == parentID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return ni == o }), + ).Return(launchplan.Wrapf(launchplan.RemoteErrorUser, fmt.Errorf("blah"), "failed")) + + s, err := h.StartLaunchPlan(ctx, mockWf, mockNode, ni) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) +} + +func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { + + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + dataDir := storage.DataReference("data") + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + 
mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return("n1") + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + mockNodeStatus.On("GetDataDir").Return(dataDir) + + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + t.Run("stillRunning", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_RUNNING, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseRunning) + }) + + t.Run("successNoOutputs", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseSuccess) + }) + + t.Run("successOutputURI", func(t *testing.T) { + + mockStore := createInmemoryStore(t) + mockLPExec := &mocks.Executor{} + uri := storage.DataReference("uri") + + h := 
launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + + op := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral(1), + }, + } + err := mockStore.WriteProtobuf(ctx, uri, storage.Options{}, op) + assert.NoError(t, err) + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + OutputResult: &admin.ExecutionClosure_Outputs{ + Outputs: &admin.LiteralMapBlob{ + Data: &admin.LiteralMapBlob_Uri{ + Uri: uri.String(), + }, + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseSuccess) + final := &core.LiteralMap{} + assert.NoError(t, mockStore.ReadProtobuf(ctx, v1alpha1.GetOutputsFile(dataDir), final)) + v, ok := final.GetLiterals()["x"] + assert.True(t, ok) + assert.Equal(t, int64(1), v.GetScalar().GetPrimitive().GetInteger()) + }) + + t.Run("successOutputs", func(t *testing.T) { + + mockStore := createInmemoryStore(t) + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + + op := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral(1), + }, + } + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + OutputResult: &admin.ExecutionClosure_Outputs{ + Outputs: &admin.LiteralMapBlob{ + Data: &admin.LiteralMapBlob_Values{ + Values: op, + }, + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseSuccess) + final := &core.LiteralMap{} + 
assert.NoError(t, mockStore.ReadProtobuf(ctx, v1alpha1.GetOutputsFile(dataDir), final)) + v, ok := final.GetLiterals()["x"] + assert.True(t, ok) + assert.Equal(t, int64(1), v.GetScalar().GetPrimitive().GetInteger()) + }) + + t.Run("failureError", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_FAILED, + OutputResult: &admin.ExecutionClosure_Error{ + Error: &core.ExecutionError{ + Message: "msg", + Code: "code", + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) + + t.Run("failureNoError", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_FAILED, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) + + t.Run("aborted", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_ABORTED, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) + + 
t.Run("notFound", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(nil, launchplan.Wrapf(launchplan.RemoteErrorNotFound, fmt.Errorf("some error"), "not found")) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) + + t.Run("systemError", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(nil, launchplan.Wrapf(launchplan.RemoteErrorSystem, fmt.Errorf("some error"), "not found")) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.Error(t, err) + assert.Equal(t, s.Phase, handler.PhaseUndefined) + }) + + t.Run("dataStoreFailure", func(t *testing.T) { + + mockStore := storage.NewCompositeDataStore(storage.URLPathConstructor{}, storage.NewDefaultProtobufStore(utils.FailingRawStore{}, promutils.NewTestScope())) + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + + op := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral(1), + }, + } + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + OutputResult: &admin.ExecutionClosure_Outputs{ + Outputs: &admin.LiteralMapBlob{ + Data: &admin.LiteralMapBlob_Values{ + Values: op, + }, + }, + }, + }, nil) + + s, err 
:= h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.Error(t, err) + assert.Equal(t, s.Phase, handler.PhaseUndefined) + }) + + t.Run("outputURINotFound", func(t *testing.T) { + + mockStore := createInmemoryStore(t) + mockLPExec := &mocks.Executor{} + uri := storage.DataReference("uri") + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + OutputResult: &admin.ExecutionClosure_Outputs{ + Outputs: &admin.LiteralMapBlob{ + Data: &admin.LiteralMapBlob_Uri{ + Uri: uri.String(), + }, + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) + + t.Run("outputURISystemError", func(t *testing.T) { + + mockStore := storage.NewCompositeDataStore(storage.URLPathConstructor{}, storage.NewDefaultProtobufStore(utils.FailingRawStore{}, promutils.NewTestScope())) + mockLPExec := &mocks.Executor{} + uri := storage.DataReference("uri") + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + OutputResult: &admin.ExecutionClosure_Outputs{ + Outputs: &admin.LiteralMapBlob{ + Data: &admin.LiteralMapBlob_Uri{ + Uri: uri.String(), + }, + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.Error(t, err) + assert.Equal(t, s.Phase, handler.PhaseUndefined) + }) +} + +func TestLaunchPlanHandler_HandleAbort(t *testing.T) { + + ctx := context.TODO() + + nodeID := "n1" + attempts := 
uint32(1) + dataDir := storage.DataReference("data") + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return(nodeID) + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + mockNodeStatus.On("GetDataDir").Return(dataDir) + + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetName").Return("test") + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + t.Run("abort-success", func(t *testing.T) { + mockLPExec := &mocks.Executor{} + mockStore := storage.NewCompositeDataStore(storage.URLPathConstructor{}, storage.NewDefaultProtobufStore(utils.FailingRawStore{}, promutils.NewTestScope())) + mockLPExec.On("Kill", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.AnythingOfType(reflect.String.String()), + ).Return(nil) + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + err := h.HandleAbort(ctx, mockWf, mockNode) + assert.NoError(t, err) + }) + + t.Run("abort-fail", func(t *testing.T) { + expectedErr := fmt.Errorf("fail") + mockLPExec := &mocks.Executor{} + mockStore := storage.NewCompositeDataStore(storage.URLPathConstructor{}, storage.NewDefaultProtobufStore(utils.FailingRawStore{}, promutils.NewTestScope())) + mockLPExec.On("Kill", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + 
return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.AnythingOfType(reflect.String.String()), + ).Return(expectedErr) + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + err := h.HandleAbort(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, err, expectedErr) + }) +} diff --git a/pkg/controller/nodes/subworkflow/sub_workflow.go b/pkg/controller/nodes/subworkflow/sub_workflow.go new file mode 100644 index 000000000..a712d991f --- /dev/null +++ b/pkg/controller/nodes/subworkflow/sub_workflow.go @@ -0,0 +1,195 @@ +package subworkflow + +import ( + "context" + "fmt" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" +) + +//TODO Add unit tests for subworkflow handler + +// Subworkflow handler handles inline subworkflows +type subworkflowHandler struct { + nodeExecutor executors.Node + enqueueWorkflow v1alpha1.EnqueueWorkflow + store *storage.DataStore +} + +func (s *subworkflowHandler) DoInlineSubWorkflow(ctx context.Context, w v1alpha1.ExecutableWorkflow, + parentNodeStatus v1alpha1.ExecutableNodeStatus, startNode v1alpha1.ExecutableNode) (handler.Status, error) { + + //TODO we need to handle failing and success nodes + state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, w, startNode) + if err != nil { + return handler.StatusUndefined, err + } + + if state.HasFailed() { + if w.GetOnFailureNode() != nil { + return handler.StatusFailing(state.Err), nil + } + return handler.StatusFailed(state.Err), nil + } + + if state.IsComplete() { + nodeID := "" + if parentNodeStatus.GetParentNodeID() != nil { + nodeID = *parentNodeStatus.GetParentNodeID() + } + + // If the WF interface has outputs, validate that the outputs file was 
written. + if outputBindings := w.GetOutputBindings(); len(outputBindings) > 0 { + endNodeStatus := w.GetNodeExecutionStatus(v1alpha1.EndNodeID) + if endNodeStatus == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, nodeID, + "No end node found in subworkflow.")), nil + } + + sourcePath := v1alpha1.GetOutputsFile(endNodeStatus.GetDataDir()) + if metadata, err := s.store.Head(ctx, sourcePath); err == nil { + if !metadata.Exists() { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, nodeID, + "Subworkflow is expected to produce outputs but no outputs file was written to %v.", + sourcePath)), nil + } + } else { + return handler.StatusUndefined, err + } + + destinationPath := v1alpha1.GetOutputsFile(parentNodeStatus.GetDataDir()) + if err := s.store.CopyRaw(ctx, sourcePath, destinationPath, storage.Options{}); err != nil { + return handler.StatusFailed(errors.Wrapf(errors.OutputsNotFoundError, nodeID, + err, "Failed to copy subworkflow outputs from [%v] to [%v]", + sourcePath, destinationPath)), nil + } + } + + return handler.StatusSuccess, nil + } + + if state.PartiallyComplete() { + // Re-enqueue the workflow + s.enqueueWorkflow(w.GetK8sWorkflowID().String()) + } + + return handler.StatusRunning, nil +} + +func (s *subworkflowHandler) DoInFailureHandling(ctx context.Context, w v1alpha1.ExecutableWorkflow) (handler.Status, error) { + if w.GetOnFailureNode() != nil { + state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, w, w.GetOnFailureNode()) + if err != nil { + return handler.StatusUndefined, err + } + if state.HasFailed() { + return handler.StatusFailed(state.Err), nil + } + if state.IsComplete() { + // Re-enqueue the workflow + s.enqueueWorkflow(w.GetK8sWorkflowID().String()) + return handler.StatusFailed(nil), nil + } + return handler.StatusFailing(nil), nil + } + return handler.StatusFailed(nil), nil +} + +func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, w 
v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + subID := *node.GetWorkflowNode().GetSubWorkflowRef() + subWorkflow := w.FindSubWorkflow(subID) + if subWorkflow == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, node.GetID(), "No subWorkflow [%s], workflow.", subID)), nil + } + + status := w.GetNodeExecutionStatus(node.GetID()) + contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, status) + startNode := contextualSubWorkflow.StartNode() + if startNode == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, "", "No start node found in subworkflow.")), nil + } + + // Before starting the subworkflow, lets set the inputs for the Workflow. The inputs for a SubWorkflow are essentially + // Copy of the inputs to the Node + nodeStatus := contextualSubWorkflow.GetNodeExecutionStatus(startNode.GetID()) + if len(nodeStatus.GetDataDir()) == 0 { + dataDir, err := contextualSubWorkflow.GetExecutionStatus().ConstructNodeDataDir(ctx, s.store, startNode.GetID()) + if err != nil { + logger.Errorf(ctx, "Failed to create metadata store key. 
Error [%v]", err) + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to create metadata store key.") + } + + nodeStatus.SetDataDir(dataDir) + startStatus, err := s.nodeExecutor.SetInputsForStartNode(ctx, contextualSubWorkflow, nodeInputs) + if err != nil { + // TODO we are considering an error when setting inputs are retryable + return handler.StatusUndefined, err + } + + if startStatus.HasFailed() { + return handler.StatusFailed(startStatus.Err), nil + } + } + + return s.DoInlineSubWorkflow(ctx, contextualSubWorkflow, status, startNode) +} + +func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, status v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + // Handle subworkflow + subID := *node.GetWorkflowNode().GetSubWorkflowRef() + subWorkflow := w.FindSubWorkflow(subID) + if subWorkflow == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, node.GetID(), "No subWorkflow [%s], workflow.", subID)), nil + } + + contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, status) + startNode := w.StartNode() + if startNode == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, node.GetID(), "No start node found in subworkflow")), nil + } + + parentNodeStatus := w.GetNodeExecutionStatus(node.GetID()) + return s.DoInlineSubWorkflow(ctx, contextualSubWorkflow, parentNodeStatus, startNode) +} + +func (s *subworkflowHandler) HandleSubWorkflowFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + status := w.GetNodeExecutionStatus(node.GetID()) + subID := *node.GetWorkflowNode().GetSubWorkflowRef() + subWorkflow := w.FindSubWorkflow(subID) + if subWorkflow == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, node.GetID(), "No subWorkflow [%s], workflow.", 
subID)), nil + } + contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, status) + return s.DoInFailureHandling(ctx, contextualSubWorkflow) +} + +func (s *subworkflowHandler) HandleAbort(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + subID := *node.GetWorkflowNode().GetSubWorkflowRef() + subWorkflow := w.FindSubWorkflow(subID) + if subWorkflow == nil { + return fmt.Errorf("no sub workflow [%s] found in node [%s]", subID, node.GetID()) + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, nodeStatus) + + startNode := w.StartNode() + if startNode == nil { + return fmt.Errorf("no sub workflow [%s] found in node [%s]", subID, node.GetID()) + } + + return s.nodeExecutor.AbortHandler(ctx, contextualSubWorkflow, startNode) +} + +func newSubworkflowHandler(nodeExecutor executors.Node, enqueueWorkflow v1alpha1.EnqueueWorkflow, store *storage.DataStore) subworkflowHandler { + return subworkflowHandler{ + nodeExecutor: nodeExecutor, + enqueueWorkflow: enqueueWorkflow, + store: store, + } +} diff --git a/pkg/controller/nodes/subworkflow/util.go b/pkg/controller/nodes/subworkflow/util.go new file mode 100644 index 000000000..973a2e0b9 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/util.go @@ -0,0 +1,24 @@ +package subworkflow + +import ( + "strconv" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/utils" +) + +const maxLengthForSubWorkflow = 20 + +func GetChildWorkflowExecutionID(parentID *core.WorkflowExecutionIdentifier, id v1alpha1.NodeID, attempt uint32) (*core.WorkflowExecutionIdentifier, error) { + name, err := utils.FixedLengthUniqueIDForParts(maxLengthForSubWorkflow, parentID.Name, id, strconv.Itoa(int(attempt))) + if err != nil { + return nil, err + } + // Restriction on name is 20 chars + return 
&core.WorkflowExecutionIdentifier{ + Project: parentID.Project, + Domain: parentID.Domain, + Name: name, + }, nil +} diff --git a/pkg/controller/nodes/subworkflow/util_test.go b/pkg/controller/nodes/subworkflow/util_test.go new file mode 100644 index 000000000..a3e126f94 --- /dev/null +++ b/pkg/controller/nodes/subworkflow/util_test.go @@ -0,0 +1,19 @@ +package subworkflow + +import ( + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestGetChildWorkflowExecutionID(t *testing.T) { + id, err := GetChildWorkflowExecutionID( + &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "first-name-is-pretty-large", + }, "hello-world", 1) + assert.Equal(t, id.Name, "fav2uxxi") + assert.NoError(t, err) +} diff --git a/pkg/controller/nodes/task/factory.go b/pkg/controller/nodes/task/factory.go new file mode 100644 index 000000000..456590b69 --- /dev/null +++ b/pkg/controller/nodes/task/factory.go @@ -0,0 +1,72 @@ +package task + +import ( + "time" + + v1 "github.com/lyft/flyteplugins/go/tasks/v1" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/pkg/errors" +) + +var testModeEnabled = false +var testTaskFactory Factory + +type Factory interface { + GetTaskExecutor(taskType v1alpha1.TaskType) (types.Executor, error) + ListAllTaskExecutors() []types.Executor +} + +// We create a simple facade so that if required we could make a Readonly cache of the Factory without any mutexes +// TODO decide if we want to make this a cache +type sealedTaskFactory struct { +} + +func (sealedTaskFactory) GetTaskExecutor(taskType v1alpha1.TaskType) (types.Executor, error) { + return v1.GetTaskExecutor(taskType) +} + +func (sealedTaskFactory) ListAllTaskExecutors() []types.Executor { + return v1.ListAllTaskExecutors() +} + +func NewFactory(revalPeriod time.Duration) Factory { + if testModeEnabled { + return 
testTaskFactory + } + + return sealedTaskFactory{} +} + +func SetTestFactory(tf Factory) { + testModeEnabled = true + testTaskFactory = tf +} + +func IsTestModeEnabled() bool { + return testModeEnabled +} + +func DisableTestMode() { + testTaskFactory = nil + testModeEnabled = false +} + +type FactoryFuncs struct { + GetTaskExecutorCb func(taskType v1alpha1.TaskType) (types.Executor, error) + ListAllTaskExecutorsCb func() []types.Executor +} + +func (t *FactoryFuncs) GetTaskExecutor(taskType v1alpha1.TaskType) (types.Executor, error) { + if t.GetTaskExecutorCb != nil { + return t.GetTaskExecutorCb(taskType) + } + return nil, errors.Errorf("No implementation provided") +} + +func (t *FactoryFuncs) ListAllTaskExecutors() []types.Executor { + if t.ListAllTaskExecutorsCb != nil { + return t.ListAllTaskExecutorsCb() + } + return nil +} diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go new file mode 100644 index 000000000..b118107ed --- /dev/null +++ b/pkg/controller/nodes/task/handler.go @@ -0,0 +1,439 @@ +package task + +import ( + "context" + "fmt" + "reflect" + "runtime/debug" + "strconv" + "time" + + "github.com/lyft/flytepropeller/pkg/controller/executors" + + "sigs.k8s.io/controller-runtime/pkg/runtime/inject" + + "github.com/lyft/flytestdlib/promutils" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + pluginsV1 "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + errors2 "github.com/pkg/errors" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + 
"github.com/lyft/flytepropeller/pkg/utils" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +const IDMaxLength = 50 + +// TODO handle retries +type taskContext struct { + taskExecutionID taskExecutionID + dataDir storage.DataReference + workflow v1alpha1.WorkflowMeta + node v1alpha1.ExecutableNode + status v1alpha1.ExecutableTaskNodeStatus + serviceAccountName string +} + +func (t *taskContext) GetCustomState() pluginsV1.CustomState { + return t.status.GetCustomState() +} + +func (t *taskContext) GetPhase() pluginsV1.TaskPhase { + return t.status.GetPhase() +} + +func (t *taskContext) GetPhaseVersion() uint32 { + return t.status.GetPhaseVersion() +} + +type taskExecutionID struct { + execName string + id core.TaskExecutionIdentifier +} + +func (te taskExecutionID) GetID() core.TaskExecutionIdentifier { + return te.id +} + +func (te taskExecutionID) GetGeneratedName() string { + return te.execName +} + +func (t *taskContext) GetOwnerID() types.NamespacedName { + return t.workflow.GetK8sWorkflowID() +} + +func (t *taskContext) GetTaskExecutionID() pluginsV1.TaskExecutionID { + return t.taskExecutionID +} + +func (t *taskContext) GetDataDir() storage.DataReference { + return t.dataDir +} + +func (t *taskContext) GetInputsFile() storage.DataReference { + return v1alpha1.GetInputsFile(t.dataDir) +} + +func (t *taskContext) GetOutputsFile() storage.DataReference { + return v1alpha1.GetOutputsFile(t.dataDir) +} + +func (t *taskContext) GetErrorFile() storage.DataReference { + return v1alpha1.GetOutputErrorFile(t.dataDir) +} + +func (t *taskContext) GetNamespace() string { + return t.workflow.GetNamespace() +} + +func (t *taskContext) GetOwnerReference() v1.OwnerReference { + return t.workflow.NewControllerRef() +} + +func (t *taskContext) GetOverrides() pluginsV1.TaskOverrides { + return t.node +} + +func (t *taskContext) GetLabels() map[string]string { + return t.workflow.GetLabels() +} + +func (t *taskContext) GetAnnotations() 
map[string]string { + return t.workflow.GetAnnotations() +} + +func (t *taskContext) GetK8sServiceAccount() string { + return t.serviceAccountName +} + +type metrics struct { + pluginPanics labeled.Counter + unsupportedTaskType labeled.Counter + discoveryPutFailureCount labeled.Counter + discoveryGetFailureCount labeled.Counter + discoveryMissCount labeled.Counter + discoveryHitCount labeled.Counter + + // TODO We should have a metric to capture custom state size +} + +type taskHandler struct { + taskFactory Factory + recorder events.TaskEventRecorder + enqueueWf v1alpha1.EnqueueWorkflow + store *storage.DataStore + scope promutils.Scope + catalogClient catalog.Client + kubeClient executors.Client + metrics *metrics +} + +func (h *taskHandler) GetTaskExecutorContext(ctx context.Context, w v1alpha1.ExecutableWorkflow, + node v1alpha1.ExecutableNode) (pluginsV1.Executor, v1alpha1.ExecutableTask, pluginsV1.TaskContext, error) { + + taskID := node.GetTaskID() + if taskID == nil { + return nil, nil, nil, errors.Errorf(errors.BadSpecificationError, node.GetID(), "Task Id not set for NodeKind `Task`") + } + task, err := w.GetTask(*taskID) + if err != nil { + return nil, nil, nil, errors.Wrapf(errors.BadSpecificationError, node.GetID(), err, "Unable to find task for taskId: [%v]", *taskID) + } + + exec, err := h.taskFactory.GetTaskExecutor(task.TaskType()) + if err != nil { + h.metrics.unsupportedTaskType.Inc(ctx) + return nil, nil, nil, errors.Wrapf(errors.UnsupportedTaskTypeError, node.GetID(), err, + "Unable to find taskExecutor for taskId: [%v]. 
TaskType: [%v]", *taskID, task.TaskType()) + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + id := core.TaskExecutionIdentifier{ + TaskId: task.CoreTask().Id, + RetryAttempt: nodeStatus.GetAttempts(), + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: node.GetID(), + ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier, + }, + } + + uniqueID, err := utils.FixedLengthUniqueIDForParts(IDMaxLength, w.GetName(), node.GetID(), strconv.Itoa(int(id.RetryAttempt))) + if err != nil { + // SHOULD never really happen + return nil, nil, nil, err + } + + taskNodeStatus := nodeStatus.GetTaskNodeStatus() + if taskNodeStatus == nil { + mutableTaskNodeStatus := nodeStatus.GetOrCreateTaskStatus() + taskNodeStatus = mutableTaskNodeStatus + } + + return exec, task, &taskContext{ + taskExecutionID: taskExecutionID{execName: uniqueID, id: id}, + dataDir: nodeStatus.GetDataDir(), + workflow: w, + node: node, + status: taskNodeStatus, + serviceAccountName: w.GetServiceAccountName(), + }, nil +} + +func (h *taskHandler) ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, + bindToVar handler.VarName) (values *core.Literal, err error) { + t, task, taskCtx, err := h.GetTaskExecutorContext(ctx, w, n) + if err != nil { + return nil, errors.Wrapf(errors.CausedByError, n.GetID(), err, "failed to create TaskCtx") + } + + l, err := t.ResolveOutputs(ctx, taskCtx, bindToVar) + if err != nil { + return nil, errors.Wrapf(errors.CausedByError, n.GetID(), err, + "failed to resolve output [%v] from task of type [%v]", bindToVar, task.TaskType()) + } + + return l[bindToVar], nil +} + +func (h *taskHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + t, task, taskCtx, err := h.GetTaskExecutorContext(ctx, w, node) + if err != nil { + return handler.StatusFailed(errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to create 
TaskCtx")), nil + } + + logger.Infof(ctx, "Executor type: [%v]. Properties: finalizer[%v]. disable[%v].", reflect.TypeOf(t).String(), t.GetProperties().RequiresFinalizer, t.GetProperties().DisableNodeLevelCaching) + if task.CoreTask().Metadata.Discoverable { + if t.GetProperties().DisableNodeLevelCaching { + logger.Infof(ctx, "Executor has Node-Level caching disabled. Skipping.") + } else if resp, err := h.catalogClient.Get(ctx, task.CoreTask(), taskCtx.GetInputsFile()); err != nil { + if taskStatus, ok := status.FromError(err); ok && taskStatus.Code() == codes.NotFound { + h.metrics.discoveryMissCount.Inc(ctx) + logger.Infof(ctx, "Artifact not found in Discovery. Executing Task.") + } else { + h.metrics.discoveryGetFailureCount.Inc(ctx) + logger.Errorf(ctx, "Discovery check failed. Executing Task. Err: %v", err.Error()) + } + } else if resp != nil { + h.metrics.discoveryHitCount.Inc(ctx) + if iface := task.CoreTask().Interface; iface != nil && iface.Outputs != nil && len(iface.Outputs.Variables) > 0 { + if err := h.store.WriteProtobuf(ctx, taskCtx.GetOutputsFile(), storage.Options{}, resp); err != nil { + logger.Errorf(ctx, "failed to write data to Storage, err: %v", err.Error()) + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to copy cached results for task.") + } + } + // SetCached. + w.GetNodeExecutionStatus(node.GetID()).SetCached() + return handler.StatusSuccess, nil + } else { + // Nil response and Nil error + h.metrics.discoveryGetFailureCount.Inc(ctx) + return handler.StatusUndefined, errors.Wrapf(errors.CatalogCallFailed, node.GetID(), err, "Nil catalog response. Failed to check Catalog for previous results") + } + } + + var taskStatus pluginsV1.TaskStatus + func() { + defer func() { + if r := recover(); r != nil { + h.metrics.pluginPanics.Inc(ctx) + stack := debug.Stack() + err = fmt.Errorf("panic when executing a plugin for TaskType [%s]. 
Stack: [%s]", task.TaskType(), string(stack)) + logger.Errorf(ctx, "Panic in plugin for TaskType [%s]", task.TaskType()) + } + }() + taskStatus, err = t.StartTask(ctx, taskCtx, task.CoreTask(), nodeInputs) + }() + + if err != nil { + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to start task [retry attempt: %d]", taskCtx.GetTaskExecutionID().GetID().RetryAttempt) + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + taskNodeStatus := nodeStatus.GetOrCreateTaskStatus() + taskNodeStatus.SetPhase(taskStatus.Phase) + taskNodeStatus.SetPhaseVersion(taskStatus.PhaseVersion) + taskNodeStatus.SetCustomState(taskStatus.State) + + logger.Debugf(ctx, "Started Task Node") + return ConvertTaskPhaseToHandlerStatus(taskStatus) +} + +func ConvertTaskPhaseToHandlerStatus(taskStatus pluginsV1.TaskStatus) (handler.Status, error) { + // TODO handle retryable failure + switch taskStatus.Phase { + case pluginsV1.TaskPhaseNotReady: + return handler.StatusQueued.WithOccurredAt(taskStatus.OccurredAt), nil + case pluginsV1.TaskPhaseQueued, pluginsV1.TaskPhaseRunning: + return handler.StatusRunning.WithOccurredAt(taskStatus.OccurredAt), nil + case pluginsV1.TaskPhasePermanentFailure: + return handler.StatusFailed(taskStatus.Err).WithOccurredAt(taskStatus.OccurredAt), nil + case pluginsV1.TaskPhaseRetryableFailure: + return handler.StatusRetryableFailure(taskStatus.Err).WithOccurredAt(taskStatus.OccurredAt), nil + case pluginsV1.TaskPhaseSucceeded: + return handler.StatusSuccess.WithOccurredAt(taskStatus.OccurredAt), nil + default: + return handler.StatusUndefined, errors.Errorf(errors.IllegalStateError, "received unknown task phase. 
[%s]", taskStatus.Phase.String()) + } +} + +func (h *taskHandler) CheckNodeStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, prevNodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + t, task, taskCtx, err := h.GetTaskExecutorContext(ctx, w, node) + if err != nil { + return handler.StatusFailed(errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to create TaskCtx")), nil + } + + var taskStatus pluginsV1.TaskStatus + func() { + defer func() { + if r := recover(); r != nil { + h.metrics.pluginPanics.Inc(ctx) + stack := debug.Stack() + err = fmt.Errorf("panic when executing a plugin for TaskType [%s]. Stack: [%s]", task.TaskType(), string(stack)) + logger.Errorf(ctx, "Panic in plugin for TaskType [%s]", task.TaskType()) + } + }() + taskStatus, err = t.CheckTaskStatus(ctx, taskCtx, task.CoreTask()) + }() + + if err != nil { + logger.Warnf(ctx, "Failed to check status") + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to check status") + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + taskNodeStatus := nodeStatus.GetOrCreateTaskStatus() + taskNodeStatus.SetPhase(taskStatus.Phase) + taskNodeStatus.SetPhaseVersion(taskStatus.PhaseVersion) + taskNodeStatus.SetCustomState(taskStatus.State) + + return ConvertTaskPhaseToHandlerStatus(taskStatus) +} + +func (h *taskHandler) HandleNodeSuccess(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + t, task, taskCtx, err := h.GetTaskExecutorContext(ctx, w, node) + if err != nil { + return handler.StatusFailed(errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to create TaskCtx")), nil + } + + // If the task interface has outputs, validate that the outputs file was written. 
+ if iface := task.CoreTask().Interface; task.TaskType() != "container_array" && iface != nil && iface.Outputs != nil && len(iface.Outputs.Variables) > 0 { + if metadata, err := h.store.Head(ctx, taskCtx.GetOutputsFile()); err != nil { + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to HEAD task outputs file.") + } else if !metadata.Exists() { + return handler.StatusRetryableFailure(errors.Errorf(errors.OutputsNotFoundError, node.GetID(), + "Outputs not found for task type %s, looking for output file %s", task.TaskType(), taskCtx.GetOutputsFile())), nil + } + + // ignores discovery write failures + if task.CoreTask().Metadata.Discoverable && !t.GetProperties().DisableNodeLevelCaching { + taskExecutionID := taskCtx.GetTaskExecutionID().GetID() + if err2 := h.catalogClient.Put(ctx, task.CoreTask(), &taskExecutionID, taskCtx.GetInputsFile(), taskCtx.GetOutputsFile()); err2 != nil { + h.metrics.discoveryPutFailureCount.Inc(ctx) + logger.Errorf(ctx, "Failed to write results to catalog. 
Err: %v", err2) + } else { + logger.Debugf(ctx, "Successfully cached results to discovery - Task [%s]", task.CoreTask().GetId()) + } + } + } + return handler.StatusSuccess, nil +} + +func (h *taskHandler) HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, node.GetID(), "A regular Task node cannot enter a failing state")), nil +} + +func (h *taskHandler) Initialize(ctx context.Context) error { + logger.Infof(ctx, "Initializing taskHandler") + enqueueFn := func(ownerId types.NamespacedName) error { + h.enqueueWf(ownerId.String()) + return nil + } + + initParams := pluginsV1.ExecutorInitializationParameters{ + CatalogClient: h.catalogClient, + EventRecorder: h.recorder, + DataStore: h.store, + EnqueueOwner: enqueueFn, + OwnerKind: v1alpha1.FlyteWorkflowKind, + MetricsScope: h.scope, + } + + for _, r := range h.taskFactory.ListAllTaskExecutors() { + logger.Infof(ctx, "Initializing Executor [%v]", r.GetID()) + // Inject a RuntimeClient if the executor needs one. 
+ if _, err := inject.ClientInto(h.kubeClient.GetClient(), r); err != nil { + return errors2.Wrapf(err, "Failed to initialize [%v]", r.GetID()) + } + + if _, err := inject.CacheInto(h.kubeClient.GetCache(), r); err != nil { + return errors2.Wrapf(err, "Failed to initialize [%v]", r.GetID()) + } + + err := r.Initialize(ctx, initParams) + if err != nil { + return errors2.Wrapf(err, "Failed to Initialize TaskExecutor [%v]", r.GetID()) + } + } + + logger.Infof(ctx, "taskHandler Initialization complete") + return nil +} + +func (h *taskHandler) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + t, _, taskCtx, err := h.GetTaskExecutorContext(ctx, w, node) + if err != nil { + return errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to create TaskCtx") + } + + err = t.KillTask(ctx, taskCtx, "Node aborted") + if err != nil { + return errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to abort task") + } + // TODO: Do we need to update the Node status to Failed here as well ? 
+ logger.Infof(ctx, "Invoked KillTask on Task Node.") + return nil +} + +func NewTaskHandlerForFactory(eventSink events.EventSink, store *storage.DataStore, enqueueWf v1alpha1.EnqueueWorkflow, + tf Factory, catalogClient catalog.Client, kubeClient executors.Client, scope promutils.Scope) handler.IFace { + + // create a recorder for the plugins + eventsRecorder := utils.NewPluginTaskEventRecorder(events.NewTaskEventRecorder(eventSink, scope)) + return &taskHandler{ + taskFactory: tf, + recorder: eventsRecorder, + enqueueWf: enqueueWf, + store: store, + scope: scope, + catalogClient: catalogClient, + kubeClient: kubeClient, + metrics: &metrics{ + pluginPanics: labeled.NewCounter("plugin_panic", "Task plugin paniced when trying to execute a task.", scope), + unsupportedTaskType: labeled.NewCounter("unsupported_tasktype", "No task plugin configured for task type", scope), + discoveryHitCount: labeled.NewCounter("discovery_hit_count", "Task cached in Discovery", scope), + discoveryMissCount: labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", scope), + discoveryPutFailureCount: labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", scope), + discoveryGetFailureCount: labeled.NewCounter("discovery_get_failure_count", "Discovery Get faillure count", scope), + }, + } +} + +func New(eventSink events.EventSink, store *storage.DataStore, enqueueWf v1alpha1.EnqueueWorkflow, revalPeriod time.Duration, + catalogClient catalog.Client, kubeClient executors.Client, scope promutils.Scope) handler.IFace { + + return NewTaskHandlerForFactory(eventSink, store, enqueueWf, NewFactory(revalPeriod), + catalogClient, kubeClient, scope.NewSubScope("task")) +} diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go new file mode 100644 index 000000000..1832e82f7 --- /dev/null +++ b/pkg/controller/nodes/task/handler_test.go @@ -0,0 +1,769 @@ +package task + +import ( + "context" + "fmt" + "reflect" + 
"testing" + + mocks2 "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + pluginsV1 "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flytestdlib/storage" + regErrors "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + typesV1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" +) + +const DataDir = storage.DataReference("test-data") +const NodeID = "n1" + +var ( + enqueueWfFunc = func(id string) {} + fakeKubeClient = mocks2.NewFakeKubeClient() +) + +func mockCatalogClient() catalog.Client { + return &catalog.MockCatalogClient{ + GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + return nil, nil + }, + PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return nil + }, + } +} + +func createWf(id string, execID string, project string, domain string, name string) *v1alpha1.FlyteWorkflow { + return &v1alpha1.FlyteWorkflow{ + ExecutionID: v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: &core.WorkflowExecutionIdentifier{ + Project: project, + Domain: domain, + Name: execID, + }, + }, + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + NodeID: { + DataDir: DataDir, + }, + 
}, + }, + ObjectMeta: v1.ObjectMeta{ + Name: name, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: id, + }, + } +} + +func createStartNode() *v1alpha1.NodeSpec { + return &v1alpha1.NodeSpec{ + ID: NodeID, + Kind: v1alpha1.NodeKindStart, + Resources: &typesV1.ResourceRequirements{ + Requests: typesV1.ResourceList{ + typesV1.ResourceCPU: resource.MustParse("1"), + }, + }, + } +} + +func createTask(id string, ttype string, discoverable bool) *v1alpha1.TaskSpec { + return &v1alpha1.TaskSpec{ + TaskTemplate: &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: ttype, + Metadata: &core.TaskMetadata{Discoverable: discoverable}, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{}, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + }, + }, + } +} + +func createDummyExec() *mocks.Executor { + dummyExec := &mocks.Executor{} + dummyExec.On("Initialize", + mock.AnythingOfType(reflect.TypeOf(context.TODO()).String()), + mock.AnythingOfType(reflect.TypeOf(pluginsV1.ExecutorInitializationParameters{}).String()), + ).Return(nil) + dummyExec.On("GetID").Return("test") + + return dummyExec +} + +func TestTaskHandler_Initialize(t *testing.T) { + ctx := context.TODO() + t.Run("NoHandlers", func(t *testing.T) { + d := &FactoryFuncs{} + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, th.Initialize(context.TODO())) + }) + + t.Run("SomeHandler", func(t *testing.T) { + d := &FactoryFuncs{ + ListAllTaskExecutorsCb: func() []pluginsV1.Executor { + return []pluginsV1.Executor{ + createDummyExec(), + } + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, th.Initialize(ctx)) + }) +} + 
+func TestTaskHandler_HandleFailingNode(t *testing.T) { + ctx := context.Background() + d := &FactoryFuncs{} + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + w := createWf("w1", "w2-exec", "project", "domain", "execName1") + n := createStartNode() + s, err := th.HandleFailingNode(ctx, w, n) + assert.NoError(t, err) + assert.Equal(t, handler.PhaseFailed, s.Phase) + assert.Error(t, s.Err) +} + +func TestTaskHandler_GetTaskExecutorContext(t *testing.T) { + ctx := context.Background() + const execName = "w1-exec" + t.Run("NoTaskId", func(t *testing.T) { + w := createWf("w1", execName, "project", "domain", "execName1") + n := createStartNode() + d := &FactoryFuncs{} + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()).(*taskHandler) + + _, _, _, err := th.GetTaskExecutorContext(ctx, w, n) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.BadSpecificationError)) + }) + + t.Run("NoTaskMatch", func(t *testing.T) { + taskID := "t1" + w := createWf("w1", execName, "project", "domain", "execName1") + n := createStartNode() + n.TaskRef = &taskID + + d := &FactoryFuncs{} + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()).(*taskHandler) + _, _, _, err := th.GetTaskExecutorContext(ctx, w, n) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.BadSpecificationError)) + }) + + t.Run("TaskMatchNoExecutor", func(t *testing.T) { + taskID := "t1" + task := createTask(taskID, "dynamic", false) + + w := createWf("w1", execName, "project", "domain", "execName1") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + + n := createStartNode() + n.TaskRef = &taskID + + d := &FactoryFuncs{} + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, 
enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()).(*taskHandler) + _, _, _, err := th.GetTaskExecutorContext(ctx, w, n) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.UnsupportedTaskTypeError)) + }) + + t.Run("TaskMatch", func(t *testing.T) { + taskID := "t1" + task := createTask(taskID, "container", false) + w := createWf("w1", execName, "project", "domain", "execName1") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + w.ServiceAccountName = "service-account" + n := createStartNode() + n.TaskRef = &taskID + + taskExec := &mocks.Executor{} + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()).(*taskHandler) + te, receivedTask, tc, err := th.GetTaskExecutorContext(ctx, w, n) + if assert.NoError(t, err) { + assert.Equal(t, taskExec, te) + if assert.NotNil(t, tc) { + assert.Equal(t, "execName1-n1-0", tc.GetTaskExecutionID().GetGeneratedName()) + assert.Equal(t, DataDir, tc.GetDataDir()) + assert.NotNil(t, tc.GetOverrides()) + assert.NotNil(t, tc.GetOverrides().GetResources()) + assert.NotEmpty(t, tc.GetOverrides().GetResources().Requests) + assert.Equal(t, "service-account", tc.GetK8sServiceAccount()) + } + assert.Equal(t, task, receivedTask) + } + }) + + t.Run("TaskMatchAttempt>0", func(t *testing.T) { + taskID := "t1" + task := createTask(taskID, "container", false) + w := createWf("w1", execName, "project", "domain", "execName1") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + n := createStartNode() + n.TaskRef = &taskID + + status := w.Status.GetNodeExecutionStatus(n.ID).(*v1alpha1.NodeStatus) + status.Attempts = 2 + + taskExec := &mocks.Executor{} + d := 
&FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()).(*taskHandler) + te, receivedTask, tc, err := th.GetTaskExecutorContext(ctx, w, n) + if assert.NoError(t, err) { + assert.Equal(t, taskExec, te) + if assert.NotNil(t, tc) { + assert.Equal(t, "execName1-n1-2", tc.GetTaskExecutionID().GetGeneratedName()) + assert.Equal(t, DataDir, tc.GetDataDir()) + assert.NotNil(t, tc.GetOverrides()) + assert.NotNil(t, tc.GetOverrides().GetResources()) + assert.NotEmpty(t, tc.GetOverrides().GetResources().Requests) + } + assert.Equal(t, task, receivedTask) + } + }) + +} + +func TestTaskHandler_StartNode(t *testing.T) { + ctx := context.Background() + taskID := "t1" + task := createTask(taskID, "container", false) + w := createWf("w2", "w2-exec", "project", "domain", "execName") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + n := createStartNode() + n.TaskRef = &taskID + + t.Run("NoTaskExec", func(t *testing.T) { + taskExec := &mocks.Executor{} + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return nil, regErrors.New("No match") + } + return taskExec, nil + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Error(t, s.Err) + assert.True(t, errors.Matches(s.Err, errors.CausedByError)) + }) + + t.Run("TaskExecStartFail", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o 
pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginsV1.TaskStatusPermanentFailure(regErrors.New("Failed")), nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Equal(t, handler.PhaseFailed, s.Phase) + }) + + t.Run("TaskExecStartPanic", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return( + func(ctx context.Context, taskCtx pluginsV1.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (pluginsV1.TaskStatus, error) { + panic("failed in execution") + }, + ) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + s, err := th.StartNode(ctx, w, n, nil) + assert.Error(t, err) + assert.Equal(t, handler.PhaseUndefined, s.Phase) + }) + + t.Run("TaskExecStarted", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + 
mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginsV1.TaskStatusRunning, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusRunning, s) + }) +} + +func TestTaskHandler_StartNodeDiscoverable(t *testing.T) { + ctx := context.Background() + taskID := "t1" + task := createTask(taskID, "container", true) + task.Id.Project = "flytekit" + w := createWf("w2", "w2-exec", "flytekit", "domain", "execName") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + n := createStartNode() + n.TaskRef = &taskID + + t.Run("TaskExecStartNodeDiscoveryFail", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginsV1.TaskStatusRunning, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + mockCatalog := catalog.MockCatalogClient{ + GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + return nil, regErrors.Errorf("error") + }, + PutFunc: func(ctx 
context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return nil + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, &mockCatalog, fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusRunning, s) + }) + + t.Run("TaskExecStartNodeDiscoveryMiss", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginsV1.TaskStatusRunning, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + mockCatalog := catalog.MockCatalogClient{ + GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + return nil, status.Errorf(codes.NotFound, "not found") + }, + PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return nil + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, &mockCatalog, fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusRunning, s) + }) + + t.Run("TaskExecStartNodeDiscoveryHit", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o 
pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginsV1.TaskStatusRunning, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + mockCatalog := catalog.MockCatalogClient{ + GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + paramsMap := make(map[string]*core.Literal) + paramsMap["out1"] = newIntegerLiteral(100) + + return &core.LiteralMap{ + Literals: paramsMap, + }, nil + }, + PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return nil + }, + } + store := createInmemoryDataStore(t, testScope.NewSubScope("12")) + th := NewTaskHandlerForFactory(events.NewMockEventSink(), store, enqueueWfFunc, d, &mockCatalog, fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) +} + +func TestTaskHandler_AbortNode(t *testing.T) { + ctx := context.Background() + taskID := "t1" + task := createTask(taskID, "container", false) + w := createWf("w2", "w2-exec", "project", "domain", "execName") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + n := createStartNode() + n.TaskRef = &taskID + + t.Run("NoTaskExec", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return nil, regErrors.New("No match") + } + return taskExec, nil + }, + } + th := 
NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + err := th.AbortNode(ctx, w, n) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.CausedByError)) + }) + + t.Run("TaskExecKillFail", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("KillTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.Anything, + ).Return(regErrors.New("Failed")) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + err := th.AbortNode(ctx, w, n) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.CausedByError)) + }) + + t.Run("TaskExecKilled", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("KillTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.Anything, + ).Return(nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + err := th.AbortNode(ctx, w, n) + assert.NoError(t, err) + }) +} + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + 
+func newIntegerPrimitive(value int64) *core.Primitive { + return &core.Primitive{Value: &core.Primitive_Integer{Integer: value}} +} + +func newScalarInteger(value int64) *core.Scalar { + return &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: newIntegerPrimitive(value), + }, + } +} + +func newIntegerLiteral(value int64) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: newScalarInteger(value), + }, + } +} + +var testScope = promutils.NewScope("test_wfexec") + +func TestTaskHandler_CheckNodeStatus(t *testing.T) { + ctx := context.Background() + + taskID := "t1" + task := createTask(taskID, "container", false) + w := createWf("w1", "w2-exec", "projTest", "domainTest", "checkNodeTestName") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + n := createStartNode() + n.TaskRef = &taskID + + t.Run("NoTaskExec", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return nil, regErrors.New("No match") + } + return taskExec, nil + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseNotYetStarted} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.True(t, errors.Matches(s.Err, errors.CausedByError)) + }) + + t.Run("TaskExecStartFail", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginsV1.TaskStatusPermanentFailure(regErrors.New("Failed")), 
nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.Equal(t, handler.PhaseFailed, s.Phase) + }) + + t.Run("TaskExecCheckPanic", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(func(ctx context.Context, taskCtx pluginsV1.TaskContext, task *core.TaskTemplate) (status pluginsV1.TaskStatus, err error) { + panic("failed in execution") + }) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.Error(t, err) + assert.Equal(t, handler.PhaseUndefined, s.Phase) + }) + + t.Run("TaskExecRunning", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + 
).Return(pluginsV1.TaskStatusRunning, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.Equal(t, handler.StatusRunning, s) + }) + + t.Run("TaskExecDone", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginsV1.TaskStatusSucceeded, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + + store := createInmemoryDataStore(t, testScope.NewSubScope("4")) + paramsMap := make(map[string]*core.Literal) + paramsMap["out1"] = newIntegerLiteral(100) + err1 := store.WriteProtobuf(ctx, "test-data/inputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap}) + err2 := store.WriteProtobuf(ctx, "test-data/outputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap}) + assert.NoError(t, err1) + assert.NoError(t, err2) + + th := NewTaskHandlerForFactory(events.NewMockEventSink(), store, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) 
+} + +func TestConvertTaskPhaseToHandlerStatus(t *testing.T) { + expectedErr := fmt.Errorf("failed") + tests := []struct { + name string + status pluginsV1.TaskStatus + hs handler.Status + isError bool + }{ + {"undefined", pluginsV1.TaskStatusUndefined, handler.StatusUndefined, true}, + {"running", pluginsV1.TaskStatusRunning, handler.StatusRunning, false}, + {"queued", pluginsV1.TaskStatusQueued, handler.StatusRunning, false}, + {"succeeded", pluginsV1.TaskStatusSucceeded, handler.StatusSuccess, false}, + {"unknown", pluginsV1.TaskStatusUnknown, handler.StatusUndefined, true}, + {"retryable", pluginsV1.TaskStatusRetryableFailure(expectedErr), handler.StatusRetryableFailure(expectedErr), false}, + {"failed", pluginsV1.TaskStatusPermanentFailure(expectedErr), handler.StatusFailed(expectedErr), false}, + {"undefined", pluginsV1.TaskStatusUndefined, handler.StatusUndefined, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hs, err := ConvertTaskPhaseToHandlerStatus(test.status) + assert.Equal(t, hs, test.hs) + if test.isError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/controller/workers.go b/pkg/controller/workers.go new file mode 100644 index 000000000..7b6303c81 --- /dev/null +++ b/pkg/controller/workers.go @@ -0,0 +1,177 @@ +package controller + +import ( + "context" + "fmt" + "runtime/pprof" + "time" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/tools/cache" +) + +type Handler interface { + // Initialize the Handler + Initialize(ctx context.Context) error + // Handle method that should handle the object and try to converge the desired and the actual state + Handle(ctx context.Context, namespace, key string) error +} + +type workerPoolMetrics struct { + Scope promutils.Scope + FreeWorkers 
prometheus.Gauge
	PerRoundTimer    promutils.StopWatch
	RoundError       prometheus.Counter
	RoundSuccess     prometheus.Counter
	WorkersRestarted prometheus.Counter
}

// WorkerPool drains a CompositeWorkQueue, handing each dequeued key to the
// configured Handler until the queue is shut down.
type WorkerPool struct {
	workQueue CompositeWorkQueue
	metrics   workerPoolMetrics
	handler   Handler
}

// processNextWorkItem pops one item off the work queue and dispatches it to
// the handler. It returns false only once the queue has been shut down.
func (w *WorkerPool) processNextWorkItem(ctx context.Context) bool {
	item, quit := w.workQueue.Get()

	// Account for this worker being busy for the duration of the round.
	w.metrics.FreeWorkers.Dec()
	defer w.metrics.FreeWorkers.Inc()

	if quit {
		return false
	}

	// Wrapped in a closure so workQueue.Done always fires for this item,
	// whichever branch returns.
	err := func(item interface{}) error {
		// Done marks processing finished; Forget additionally clears the
		// rate-limiter history. We deliberately skip Forget on transient
		// errors so the item is retried after a back-off.
		defer w.workQueue.Done(item)

		// Queue entries are expected to be "namespace/name" strings; only
		// the key travels through the queue because the informer cache may
		// be more current than the enqueued object was.
		key, ok := item.(string)
		if !ok {
			// The entry can never become valid — Forget it to avoid an
			// endless reprocessing loop.
			w.workQueue.Forget(item)
			runtime.HandleError(fmt.Errorf("expected string in workqueue but got %#v", item))
			return nil
		}

		stopWatch := w.metrics.PerRoundTimer.Start()
		defer stopWatch.Stop()

		// Split the key into its namespace and execution-id halves.
		namespace, name, err := cache.SplitMetaNamespaceKey(key)
		if err != nil {
			logger.Errorf(ctx, "Unable to split enqueued key into namespace/execId. Error[%v]", err)
			return nil
		}
		ctx = contextutils.WithNamespace(ctx, namespace)
		ctx = contextutils.WithExecutionID(ctx, name)

		// Ask the handler to reconcile this workflow.
		if err := w.handler.Handle(ctx, namespace, name); err != nil {
			w.metrics.RoundError.Inc()
			return fmt.Errorf("error syncing '%s': %s", key, err.Error())
		}
		w.metrics.RoundSuccess.Inc()

		// Success: drop the item from the rate limiter so it is only seen
		// again on a fresh event.
		w.workQueue.Forget(item)
		logger.Infof(ctx, "Successfully synced '%s'", key)
		return nil
	}(item)

	// Errors are surfaced through the shared handler; the loop keeps going
	// either way.
	if err != nil {
		runtime.HandleError(err)
	}
	return true
}

// runWorker loops over processNextWorkItem until the queue shuts down.
func (w *WorkerPool) runWorker(ctx context.Context) {
	logger.Infof(ctx, "Started Worker")
	defer logger.Infof(ctx, "Exiting Worker")
	for w.processNextWorkItem(ctx) {
	}
}

// Initialize delegates one-time setup to the wrapped Handler.
func (w *WorkerPool) Initialize(ctx context.Context) error {
	return w.handler.Initialize(ctx)
}

// Run waits for the supplied informer caches to sync, then launches
// `threadiness` worker goroutines and blocks until ctx is cancelled, at which
// point the queue is shut down and workers drain.
func (w *WorkerPool) Run(ctx context.Context, threadiness int, synced ...cache.InformerSynced) error {
	defer runtime.HandleCrash()
	defer w.workQueue.ShutdownAll()

	logger.Info(ctx, "Starting FlyteWorkflow controller")
	w.metrics.WorkersRestarted.Inc()

	// Workers must not observe a partially-populated cache.
	logger.Info(ctx, "Waiting for informer caches to sync")
	if ok := cache.WaitForCacheSync(ctx.Done(), synced...); !ok {
		return fmt.Errorf("failed to wait for caches to sync")
	}

	logger.Infof(ctx, "Starting workers [%d]", threadiness)
	for workerID := 0; workerID < threadiness; workerID++ {
		w.metrics.FreeWorkers.Inc()
		logger.Infof(ctx, "Starting worker [%d]", workerID)
		label := fmt.Sprintf("worker-%v", workerID)
		go func() {
			// Label the goroutine so it is identifiable in pprof dumps.
			workerCtx := contextutils.WithGoroutineLabel(ctx, label)
			pprof.SetGoroutineLabels(workerCtx)
			w.runWorker(workerCtx)
		}()
	}

	w.workQueue.Start(ctx)
	logger.Info(ctx, "Started workers")
	<-ctx.Done()
	logger.Info(ctx, "Shutting down workers")

	return nil
}

// NewWorkerPool assembles a WorkerPool together with its metrics, rooted at
// the given scope (per-round metrics live under the "round" sub-scope).
func NewWorkerPool(ctx context.Context, scope promutils.Scope, workQueue CompositeWorkQueue, handler Handler) *WorkerPool {
	roundScope := scope.NewSubScope("round")
	metrics := workerPoolMetrics{
		Scope:            scope,
		FreeWorkers:      scope.MustNewGauge("free_workers_count", "Number of workers free"),
		PerRoundTimer:    roundScope.MustNewStopWatch("round_total", "Latency per round", time.Millisecond),
		RoundSuccess:     roundScope.MustNewCounter("success_count", "Round succeeded"),
		RoundError:       roundScope.MustNewCounter("error_count", "Round failed"),
		WorkersRestarted: scope.MustNewCounter("workers_restarted", "Propeller worker-pool was restarted"),
	}
	return &WorkerPool{
		workQueue: workQueue,
		metrics:   metrics,
		handler:   handler,
	}
}
diff --git a/pkg/controller/workers_test.go b/pkg/controller/workers_test.go new file mode 100644
index 000000000..aae7e54eb --- /dev/null +++ b/pkg/controller/workers_test.go @@ -0,0 +1,93 @@ +package controller + +import ( + "context" + "sync" + "testing" + + "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +var testLocalScope2 = promutils.NewScope("worker_pool") + +type testHandler struct { + InitCb func(ctx context.Context) error + HandleCb func(ctx context.Context, namespace, key string) error +} + +func (t *testHandler) Initialize(ctx context.Context) error { + return t.InitCb(ctx) +} + +func (t *testHandler) Handle(ctx context.Context, namespace, key string) error { + return t.HandleCb(ctx, namespace, key) +} + +func simpleWorkQ(ctx context.Context, t *testing.T, testScope promutils.Scope) CompositeWorkQueue { + cfg := config.CompositeQueueConfig{} + q, err := NewCompositeWorkQueue(ctx, cfg, testScope) + assert.NoError(t, err) + assert.NotNil(t, q) + return q +} + +func TestWorkerPool_Run(t *testing.T) { + ctx := context.TODO() + l := testLocalScope2.NewSubScope("new") + h := &testHandler{} + q := simpleWorkQ(ctx, t, l) + w := NewWorkerPool(ctx, l, q, h) + assert.NotNil(t, w) + + t.Run("initcalled", func(t *testing.T) { + + initCalled := false + h.InitCb = func(ctx context.Context) error { + initCalled = true + return nil + } + + assert.NoError(t, w.Initialize(ctx)) + assert.True(t, initCalled) + }) + + // Bad TEST :(. 
We create 2 waitgroups, one will wait for the Run function to exit (called wg) + // Other is called handleReceived, waits for receiving a handle + // The flow is, + // - start the poll loop + // - add a key `x` + // - wait for `x` to be handled + // - cancel the loop + // - wait for loop to exit + t.Run("run", func(t *testing.T) { + childCtx, cancel := context.WithCancel(ctx) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + assert.NoError(t, w.Run(childCtx, 1, func() bool { + return true + })) + wg.Done() + }() + + handleReceived := sync.WaitGroup{} + handleReceived.Add(1) + + h.HandleCb = func(ctx context.Context, namespace, key string) error { + if key == "x" { + handleReceived.Done() + } else { + assert.FailNow(t, "x expected") + } + return nil + } + q.Add("x") + handleReceived.Wait() + + cancel() + wg.Wait() + }) +} diff --git a/pkg/controller/workflow/errors/codes.go b/pkg/controller/workflow/errors/codes.go new file mode 100644 index 000000000..f4a84685a --- /dev/null +++ b/pkg/controller/workflow/errors/codes.go @@ -0,0 +1,15 @@ +package errors + +type ErrorCode string + +const ( + IllegalStateError ErrorCode = "IllegalStateError" + BadSpecificationError ErrorCode = "BadSpecificationError" + CausedByError ErrorCode = "CausedByError" + RuntimeExecutionError ErrorCode = "RuntimeExecutionError" + EventRecordingError ErrorCode = "ErrorRecordingError" +) + +func (e ErrorCode) String() string { + return string(e) +} diff --git a/pkg/controller/workflow/errors/errors.go b/pkg/controller/workflow/errors/errors.go new file mode 100644 index 000000000..edf56224b --- /dev/null +++ b/pkg/controller/workflow/errors/errors.go @@ -0,0 +1,80 @@ +package errors + +import ( + "fmt" + + "github.com/pkg/errors" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type ErrorMessage = string + +type WorkflowError struct { + errors.StackTrace + Code ErrorCode + Message ErrorMessage + Workflow v1alpha1.WorkflowID +} + +func (w *WorkflowError) Error() 
string { + return fmt.Sprintf("Workflow[%s] failed. %v: %v", w.Workflow, w.Code, w.Message) +} + +type WorkflowErrorWithCause struct { + *WorkflowError + cause error +} + +func (w *WorkflowErrorWithCause) Error() string { + return fmt.Sprintf("%v, caused by: %v", w.WorkflowError.Error(), errors.Cause(w)) +} + +func (w *WorkflowErrorWithCause) Cause() error { + return w.cause +} + +func errorf(c ErrorCode, w v1alpha1.WorkflowID, msgFmt string, args ...interface{}) *WorkflowError { + return &WorkflowError{ + Code: c, + Workflow: w, + Message: fmt.Sprintf(msgFmt, args...), + } +} + +func Errorf(c ErrorCode, w v1alpha1.WorkflowID, msgFmt string, args ...interface{}) error { + return errorf(c, w, msgFmt, args...) +} + +func Wrapf(c ErrorCode, w v1alpha1.WorkflowID, cause error, msgFmt string, args ...interface{}) error { + return &WorkflowErrorWithCause{ + WorkflowError: errorf(c, w, msgFmt, args...), + cause: cause, + } +} + +func Matches(err error, code ErrorCode) bool { + errCode, isWorkflowError := GetErrorCode(err) + if isWorkflowError { + return code == errCode + } + return false +} + +func GetErrorCode(err error) (code ErrorCode, isWorkflowError bool) { + isWorkflowError = false + e, ok := err.(*WorkflowError) + if ok { + code = e.Code + isWorkflowError = true + return + } + + e2, ok := err.(*WorkflowErrorWithCause) + if ok { + code = e2.Code + isWorkflowError = true + return + } + return +} diff --git a/pkg/controller/workflow/errors/errors_test.go b/pkg/controller/workflow/errors/errors_test.go new file mode 100644 index 000000000..c773fa776 --- /dev/null +++ b/pkg/controller/workflow/errors/errors_test.go @@ -0,0 +1,48 @@ +package errors + +import ( + "fmt" + "testing" + + extErrors "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestErrorf(t *testing.T) { + msg := "msg" + err := Errorf(IllegalStateError, "w1", "Message [%v]", msg) + assert.NotNil(t, err) + e := err.(*WorkflowError) + assert.Equal(t, IllegalStateError, e.Code) + 
assert.Equal(t, "w1", e.Workflow) + assert.Equal(t, fmt.Sprintf("Message [%v]", msg), e.Message) + assert.Equal(t, err, extErrors.Cause(e)) + assert.Equal(t, "Workflow[w1] failed. IllegalStateError: Message [msg]", err.Error()) +} + +func TestErrorfWithCause(t *testing.T) { + cause := extErrors.Errorf("Some Error") + msg := "msg" + err := Wrapf(IllegalStateError, "w1", cause, "Message [%v]", msg) + assert.NotNil(t, err) + e := err.(*WorkflowErrorWithCause) + assert.Equal(t, IllegalStateError, e.Code) + assert.Equal(t, "w1", e.Workflow) + assert.Equal(t, fmt.Sprintf("Message [%v]", msg), e.Message) + assert.Equal(t, cause, extErrors.Cause(e)) + assert.Equal(t, "Workflow[w1] failed. IllegalStateError: Message [msg], caused by: Some Error", err.Error()) +} + +func TestMatches(t *testing.T) { + err := Errorf(IllegalStateError, "w1", "Message ") + assert.True(t, Matches(err, IllegalStateError)) + assert.False(t, Matches(err, BadSpecificationError)) + + cause := extErrors.Errorf("Some Error") + err = Wrapf(IllegalStateError, "w1", cause, "Message ") + assert.True(t, Matches(err, IllegalStateError)) + assert.False(t, Matches(err, BadSpecificationError)) + + assert.False(t, Matches(cause, IllegalStateError)) + assert.False(t, Matches(cause, BadSpecificationError)) +} diff --git a/pkg/controller/workflow/executor.go b/pkg/controller/workflow/executor.go new file mode 100644 index 000000000..ddbcc03ec --- /dev/null +++ b/pkg/controller/workflow/executor.go @@ -0,0 +1,420 @@ +package workflow + +import ( + "context" + "fmt" + "time" + + "github.com/lyft/flyteidl/clients/go/events" + eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + corev1 "k8s.io/api/core/v1" + 
"k8s.io/client-go/tools/record" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/workflow/errors" + "github.com/lyft/flytepropeller/pkg/utils" +) + +type workflowMetrics struct { + AcceptedWorkflows labeled.Counter + FailureDuration labeled.StopWatch + SuccessDuration labeled.StopWatch + IncompleteWorkflowAborted labeled.Counter + + // Measures the time between when we receive service call to create an execution and when it has moved to running state. + AcceptanceLatency labeled.StopWatch + // Measures the time between when the WF moved to succeeding/failing state and when it finally moved to a terminal state. + CompletionLatency labeled.StopWatch +} + +type Status struct { + TransitionToPhase v1alpha1.WorkflowPhase + Err error +} + +var StatusReady = Status{TransitionToPhase: v1alpha1.WorkflowPhaseReady} +var StatusRunning = Status{TransitionToPhase: v1alpha1.WorkflowPhaseRunning} +var StatusSucceeding = Status{TransitionToPhase: v1alpha1.WorkflowPhaseSucceeding} +var StatusSuccess = Status{TransitionToPhase: v1alpha1.WorkflowPhaseSuccess} + +func StatusFailing(err error) Status { + return Status{TransitionToPhase: v1alpha1.WorkflowPhaseFailing, Err: err} +} + +func StatusFailed(err error) Status { + return Status{TransitionToPhase: v1alpha1.WorkflowPhaseFailed, Err: err} +} + +type workflowExecutor struct { + enqueueWorkflow v1alpha1.EnqueueWorkflow + store *storage.DataStore + wfRecorder events.WorkflowEventRecorder + k8sRecorder record.EventRecorder + metadataPrefix storage.DataReference + nodeExecutor executors.Node + metrics *workflowMetrics +} + +func (c *workflowExecutor) constructWorkflowMetadataPrefix(ctx context.Context, w *v1alpha1.FlyteWorkflow) (storage.DataReference, error) { + if w.GetExecutionID().WorkflowExecutionIdentifier != nil { + execID := fmt.Sprintf("%v-%v-%v", w.GetExecutionID().GetProject(), 
w.GetExecutionID().GetDomain(), w.GetExecutionID().GetName()) + return c.store.ConstructReference(ctx, c.metadataPrefix, execID) + } + // TODO should we use a random guid as the prefix? Otherwise we may get collisions + logger.Warningf(ctx, "Workflow has no ExecutionID. Using the name as the storage-prefix. This maybe unsafe!") + return c.store.ConstructReference(ctx, c.metadataPrefix, w.Name) +} + +func (c *workflowExecutor) handleReadyWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { + + startNode := w.StartNode() + if startNode == nil { + return StatusFailing(errors.Errorf(errors.BadSpecificationError, w.GetID(), "StartNode not found.")), nil + } + + ref, err := c.constructWorkflowMetadataPrefix(ctx, w) + if err != nil { + return StatusFailing(errors.Wrapf(errors.CausedByError, w.GetID(), err, "failed to create metadata prefix.")), nil + } + w.GetExecutionStatus().SetDataDir(ref) + var inputs *core.LiteralMap + if w.Inputs != nil { + inputs = w.Inputs.LiteralMap + } + // Before starting the subworkflow, lets set the inputs for the Workflow. 
The inputs for a SubWorkflow are essentially + // Copy of the inputs to the Node + nodeStatus := w.GetNodeExecutionStatus(startNode.GetID()) + dataDir, err := c.store.ConstructReference(ctx, ref, startNode.GetID(), "data") + if err != nil { + return StatusFailing(errors.Wrapf(errors.CausedByError, w.GetID(), err, "failed to create metadata prefix for start node.")), nil + } + + logger.Infof(ctx, "Setting the MetadataDir for StartNode [%v]", dataDir) + nodeStatus.SetDataDir(dataDir) + s, err := c.nodeExecutor.SetInputsForStartNode(ctx, w, inputs) + if err != nil { + return StatusReady, err + } + + if s.HasFailed() { + return StatusFailing(errors.Wrapf(errors.CausedByError, w.GetID(), err, "failed to set inputs for Start node.")), nil + } + return StatusRunning, nil +} + +func (c *workflowExecutor) handleRunningWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { + contextualWf := executors.NewBaseContextualWorkflow(w) + startNode := contextualWf.StartNode() + if startNode == nil { + return StatusFailed(errors.Errorf(errors.IllegalStateError, w.GetID(), "StartNode not found in running workflow?")), nil + } + state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, contextualWf, startNode) + if err != nil { + return StatusRunning, err + } + if state.HasFailed() { + logger.Infof(ctx, "Workflow has failed. Error [%s]", state.Err.Error()) + return StatusFailing(state.Err), nil + } + if state.IsComplete() { + return StatusSucceeding, nil + } + if state.PartiallyComplete() { + c.enqueueWorkflow(contextualWf.GetK8sWorkflowID().String()) + } + return StatusRunning, nil +} + +func (c *workflowExecutor) handleFailingWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { + contextualWf := executors.NewBaseContextualWorkflow(w) + // Best effort clean-up. + if err := c.cleanupRunningNodes(ctx, contextualWf); err != nil { + logger.Errorf(ctx, "Failed to propagate Abort for workflow:%v. 
Error: %v", w.ExecutionID.WorkflowExecutionIdentifier, err) + } + + errorNode := contextualWf.GetOnFailureNode() + if errorNode != nil { + state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, contextualWf, errorNode) + if err != nil { + return StatusFailing(nil), err + } + if state.HasFailed() { + return StatusFailed(state.Err), nil + } + if state.PartiallyComplete() { + // Re-enqueue the workflow + c.enqueueWorkflow(contextualWf.GetK8sWorkflowID().String()) + return StatusFailing(nil), nil + } + // Fallthrough to handle state is complete + } + return StatusFailed(errors.Errorf(errors.CausedByError, w.ID, contextualWf.GetExecutionStatus().GetMessage())), nil +} + +func (c *workflowExecutor) handleSucceedingWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) Status { + logger.Infof(ctx, "Workflow completed successfully") + endNodeStatus := w.GetNodeExecutionStatus(v1alpha1.EndNodeID) + if endNodeStatus.GetPhase() == v1alpha1.NodePhaseSucceeded { + if endNodeStatus.GetDataDir() != "" { + w.Status.SetOutputReference(v1alpha1.GetOutputsFile(endNodeStatus.GetDataDir())) + } + } + return StatusSuccess +} + +func convertToExecutionError(err error, alternateErr string) *event.WorkflowExecutionEvent_Error { + if err != nil { + if code, isWorkflowErr := errors.GetErrorCode(err); isWorkflowErr { + return &event.WorkflowExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: code.String(), + Message: err.Error(), + }, + } + } + } else { + err = fmt.Errorf(alternateErr) + } + return &event.WorkflowExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: errors.RuntimeExecutionError.String(), + Message: err.Error(), + }, + } +} + +func (c *workflowExecutor) IdempotentReportEvent(ctx context.Context, e *event.WorkflowExecutionEvent) error { + err := c.wfRecorder.RecordWorkflowEvent(ctx, e) + if err != nil && eventsErr.IsAlreadyExists(err) { + logger.Infof(ctx, "Workflow event phase: %s, executionId %s already exist", + e.Phase.String(), e.ExecutionId) + return nil 
+ } + return err +} + +func (c *workflowExecutor) TransitionToPhase(ctx context.Context, execID *core.WorkflowExecutionIdentifier, wStatus v1alpha1.ExecutableWorkflowStatus, toStatus Status) error { + if wStatus.GetPhase() != toStatus.TransitionToPhase { + logger.Debugf(ctx, "Transitioning/Recording event for workflow state transition [%s] -> [%s]", wStatus.GetPhase().String(), toStatus.TransitionToPhase.String()) + + wfEvent := &event.WorkflowExecutionEvent{ + ExecutionId: execID, + } + previousMsg := wStatus.GetMessage() + switch toStatus.TransitionToPhase { + case v1alpha1.WorkflowPhaseReady: + // Do nothing + return nil + case v1alpha1.WorkflowPhaseRunning: + wfEvent.Phase = core.WorkflowExecution_RUNNING + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseRunning, fmt.Sprintf("Workflow Started")) + wfEvent.OccurredAt = utils.GetProtoTime(wStatus.GetStartedAt()) + case v1alpha1.WorkflowPhaseFailing: + wfEvent.Phase = core.WorkflowExecution_FAILING + e := convertToExecutionError(toStatus.Err, previousMsg) + wfEvent.OutputResult = e + // Completion latency is only observed when a workflow completes successfully + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseFailing, e.Error.Message) + wfEvent.OccurredAt = utils.GetProtoTime(nil) + case v1alpha1.WorkflowPhaseFailed: + wfEvent.Phase = core.WorkflowExecution_FAILED + e := convertToExecutionError(toStatus.Err, previousMsg) + wfEvent.OutputResult = e + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseFailed, e.Error.Message) + wfEvent.OccurredAt = utils.GetProtoTime(wStatus.GetStoppedAt()) + c.metrics.FailureDuration.Observe(ctx, wStatus.GetStartedAt().Time, wStatus.GetStoppedAt().Time) + case v1alpha1.WorkflowPhaseSucceeding: + wfEvent.Phase = core.WorkflowExecution_SUCCEEDING + endNodeStatus := wStatus.GetNodeExecutionStatus(v1alpha1.EndNodeID) + // Workflow completion latency is recorded as the time it takes for the workflow to transition from end + // node started time to workflow success being sent to the control plane. 
+ if endNodeStatus != nil && endNodeStatus.GetStartedAt() != nil { + c.metrics.CompletionLatency.Observe(ctx, endNodeStatus.GetStartedAt().Time, time.Now()) + } + + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseSucceeding, "") + wfEvent.OccurredAt = utils.GetProtoTime(nil) + case v1alpha1.WorkflowPhaseSuccess: + wfEvent.Phase = core.WorkflowExecution_SUCCEEDED + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseSuccess, "") + // Not all workflows have outputs + if wStatus.GetOutputReference() != "" { + wfEvent.OutputResult = &event.WorkflowExecutionEvent_OutputUri{ + OutputUri: wStatus.GetOutputReference().String(), + } + } + wfEvent.OccurredAt = utils.GetProtoTime(wStatus.GetStoppedAt()) + c.metrics.SuccessDuration.Observe(ctx, wStatus.GetStartedAt().Time, wStatus.GetStoppedAt().Time) + case v1alpha1.WorkflowPhaseAborted: + wfEvent.Phase = core.WorkflowExecution_ABORTED + if wStatus.GetLastUpdatedAt() != nil { + c.metrics.CompletionLatency.Observe(ctx, wStatus.GetLastUpdatedAt().Time, time.Now()) + } + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseAborted, "") + wfEvent.OccurredAt = utils.GetProtoTime(wStatus.GetStoppedAt()) + default: + return errors.Errorf(errors.IllegalStateError, "", "Illegal transition from [%v] -> [%v]", wStatus.GetPhase().String(), toStatus.TransitionToPhase.String()) + } + + if recordingErr := c.IdempotentReportEvent(ctx, wfEvent); recordingErr != nil { + if eventsErr.IsEventAlreadyInTerminalStateError(recordingErr) { + // Move to WorkflowPhaseFailed for state mis-match + msg := fmt.Sprintf("workflow state mismatch between propeller and control plane; Propeller State: %s, ExecutionId %s", wfEvent.Phase.String(), wfEvent.ExecutionId) + logger.Warningf(ctx, msg) + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseFailed, msg) + return nil + } + logger.Warningf(ctx, "Event recording failed. 
Error [%s]", recordingErr.Error()) + return errors.Wrapf(errors.EventRecordingError, "", recordingErr, "failed to publish event") + } + } + return nil +} + +func (c *workflowExecutor) Initialize(ctx context.Context) error { + logger.Infof(ctx, "Initializing Core Workflow Executor") + return c.nodeExecutor.Initialize(ctx) +} + +func (c *workflowExecutor) HandleFlyteWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + logger.Infof(ctx, "Handling Workflow [%s], id: [%s], Phase [%s]", w.GetName(), w.GetExecutionID(), w.GetExecutionStatus().GetPhase().String()) + defer logger.Infof(ctx, "Handling Workflow [%s] Done", w.GetName()) + + wStatus := w.GetExecutionStatus() + // Initialize the Status if not already initialized + switch wStatus.GetPhase() { + case v1alpha1.WorkflowPhaseReady: + newStatus, err := c.handleReadyWorkflow(ctx, w) + if err != nil { + return err + } + c.metrics.AcceptedWorkflows.Inc(ctx) + if err := c.TransitionToPhase(ctx, w.ExecutionID.WorkflowExecutionIdentifier, wStatus, newStatus); err != nil { + return err + } + c.k8sRecorder.Event(w, corev1.EventTypeNormal, v1alpha1.WorkflowPhaseRunning.String(), "Workflow began execution") + + // TODO: Consider annotating with the newStatus. 
+ acceptedAt := w.GetCreationTimestamp().Time + if w.AcceptedAt != nil && !w.AcceptedAt.IsZero() { + acceptedAt = w.AcceptedAt.Time + } + + c.metrics.AcceptanceLatency.Observe(ctx, acceptedAt, time.Now()) + return nil + + case v1alpha1.WorkflowPhaseRunning: + newStatus, err := c.handleRunningWorkflow(ctx, w) + if err != nil { + logger.Warningf(ctx, "Error in handling running workflow [%v]", err.Error()) + return err + } + if err := c.TransitionToPhase(ctx, w.ExecutionID.WorkflowExecutionIdentifier, wStatus, newStatus); err != nil { + return err + } + return nil + case v1alpha1.WorkflowPhaseSucceeding: + newStatus := c.handleSucceedingWorkflow(ctx, w) + + if err := c.TransitionToPhase(ctx, w.ExecutionID.WorkflowExecutionIdentifier, wStatus, newStatus); err != nil { + return err + } + c.k8sRecorder.Event(w, corev1.EventTypeNormal, v1alpha1.WorkflowPhaseSuccess.String(), "Workflow completed.") + return nil + case v1alpha1.WorkflowPhaseFailing: + newStatus, err := c.handleFailingWorkflow(ctx, w) + if err != nil { + return err + } + if err := c.TransitionToPhase(ctx, w.ExecutionID.WorkflowExecutionIdentifier, wStatus, newStatus); err != nil { + return err + } + c.k8sRecorder.Event(w, corev1.EventTypeWarning, v1alpha1.WorkflowPhaseFailed.String(), "Workflow failed.") + return nil + default: + return errors.Errorf(errors.IllegalStateError, w.ID, "Unsupported state [%s] for workflow", w.GetExecutionStatus().GetPhase().String()) + } +} + +func (c *workflowExecutor) HandleAbortedWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + if !w.Status.IsTerminated() { + c.metrics.IncompleteWorkflowAborted.Inc(ctx) + var err error + if w.Status.FailedAttempts > maxRetries { + err = errors.Errorf(errors.RuntimeExecutionError, w.GetID(), "max number of system retry attempts [%d/%d] exhausted. Last known status message: %v", w.Status.FailedAttempts, maxRetries, w.Status.Message) + } + + // Best effort clean-up. 
+ contextualWf := executors.NewBaseContextualWorkflow(w) + if err2 := c.cleanupRunningNodes(ctx, contextualWf); err2 != nil { + logger.Errorf(ctx, "Failed to propagate Abort for workflow:%v. Error: %v", w.ExecutionID.WorkflowExecutionIdentifier, err2) + } + + var status Status + if err != nil { + // This workflow failed, record that phase and corresponding error message. + status = StatusFailed(err) + } else { + // Otherwise, this workflow is aborted. + status = Status{ + TransitionToPhase: v1alpha1.WorkflowPhaseAborted, + } + } + + if err := c.TransitionToPhase(ctx, w.ExecutionID.WorkflowExecutionIdentifier, w.GetExecutionStatus(), status); err != nil { + return err + } + } + return nil +} + +func (c *workflowExecutor) cleanupRunningNodes(ctx context.Context, w v1alpha1.ExecutableWorkflow) error { + startNode := w.StartNode() + if startNode == nil { + return errors.Errorf(errors.IllegalStateError, w.GetID(), "StartNode not found in running workflow?") + } + + if err := c.nodeExecutor.AbortHandler(ctx, w, startNode); err != nil { + return errors.Errorf(errors.CausedByError, w.GetID(), "Failed to propagate Abort for workflow. 
Error: %v", err) + } + + return nil +} + +func NewExecutor(ctx context.Context, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, k8sEventRecorder record.EventRecorder, metadataPrefix string, nodeExecutor executors.Node, scope promutils.Scope) (executors.Workflow, error) { + basePrefix := store.GetBaseContainerFQN(ctx) + if metadataPrefix != "" { + var err error + basePrefix, err = store.ConstructReference(ctx, basePrefix, metadataPrefix) + if err != nil { + return nil, err + } + } + logger.Infof(ctx, "Metadata will be stored in container path: [%s]", basePrefix) + + workflowScope := scope.NewSubScope("workflow") + + return &workflowExecutor{ + nodeExecutor: nodeExecutor, + store: store, + enqueueWorkflow: enQWorkflow, + wfRecorder: events.NewWorkflowEventRecorder(eventSink, workflowScope), + k8sRecorder: k8sEventRecorder, + metadataPrefix: basePrefix, + metrics: &workflowMetrics{ + AcceptedWorkflows: labeled.NewCounter("accepted", "Number of workflows accepted by propeller", workflowScope), + FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, workflowScope), + SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, workflowScope), + IncompleteWorkflowAborted: labeled.NewCounter("workflow_aborted", "Indicates an inprogress execution was aborted", workflowScope), + AcceptanceLatency: labeled.NewStopWatch("acceptance_latency", "Delay between workflow creation and moving it to running state.", time.Millisecond, workflowScope, labeled.EmitUnlabeledMetric), + CompletionLatency: labeled.NewStopWatch("completion_latency", "Measures the time between when the WF moved to succeeding/failing state and when it finally moved to a terminal state.", time.Millisecond, workflowScope, labeled.EmitUnlabeledMetric), + }, + }, nil +} diff --git 
a/pkg/controller/workflow/executor_test.go b/pkg/controller/workflow/executor_test.go new file mode 100644 index 000000000..492e4067d --- /dev/null +++ b/pkg/controller/workflow/executor_test.go @@ -0,0 +1,615 @@ +package workflow + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "strconv" + "testing" + + mocks2 "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" + + eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + wfErrors "github.com/lyft/flytepropeller/pkg/controller/workflow/errors" + + "time" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + pluginV1 "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/nodes" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/nodes/task" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/lyft/flytestdlib/yamlutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "k8s.io/client-go/tools/record" +) + +var ( + testScope = promutils.NewScope("test_wfexec") + fakeKubeClient = mocks2.NewFakeKubeClient() +) + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func StdOutEventRecorder() 
record.EventRecorder { + eventChan := make(chan string) + recorder := &record.FakeRecorder{ + Events: eventChan, + } + + go func() { + defer close(eventChan) + for { + s := <-eventChan + if s == "" { + return + } + fmt.Printf("Event: [%v]\n", s) + } + }() + return recorder +} + +func createHappyPathTaskExecutor(t assert.TestingT, store *storage.DataStore, enableAsserts bool) pluginV1.Executor { + exec := &mocks.Executor{} + exec.On("GetID").Return("task") + exec.On("GetProperties").Return(pluginV1.ExecutorProperties{}) + exec.On("Initialize", + mock.AnythingOfType(reflect.TypeOf(context.TODO()).String()), + mock.AnythingOfType(reflect.TypeOf(pluginV1.ExecutorInitializationParameters{}).String()), + ).Return(nil) + + exec.On("ResolveOutputs", + mock.Anything, + mock.Anything, + mock.Anything, + ). + Return(func(ctx context.Context, taskCtx pluginV1.TaskContext, varNames ...string) (values map[string]*core.Literal) { + d := &handler.Data{} + outputsFileRef := v1alpha1.GetOutputsFile(taskCtx.GetDataDir()) + assert.NoError(t, store.ReadProtobuf(ctx, outputsFileRef, d)) + assert.NotNil(t, d.Literals) + + values = make(map[string]*core.Literal, len(varNames)) + for _, varName := range varNames { + l, ok := d.Literals[varName] + assert.True(t, ok, "Expect var %v in task outputs.", varName) + + values[varName] = l + } + + return values + }, func(ctx context.Context, taskCtx pluginV1.TaskContext, varNames ...string) error { + return nil + }) + + startFn := func(ctx context.Context, taskCtx pluginV1.TaskContext, task *core.TaskTemplate, _ *core.LiteralMap) pluginV1.TaskStatus { + outputVars := task.GetInterface().Outputs.Variables + o := &core.LiteralMap{ + Literals: make(map[string]*core.Literal, len(outputVars)), + } + for k, v := range outputVars { + l, err := utils.MakeDefaultLiteralForType(v.Type) + if enableAsserts && !assert.NoError(t, err) { + assert.FailNow(t, "Failed to create default output for node [%v] Type [%v]", taskCtx.GetTaskExecutionID(), v.Type) + } + 
o.Literals[k] = l + } + assert.NoError(t, store.WriteProtobuf(ctx, v1alpha1.GetOutputsFile(taskCtx.GetDataDir()), storage.Options{}, o)) + + return pluginV1.TaskStatusRunning + } + + exec.On("StartTask", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(startFn, nil) + + checkStatusFn := func(_ context.Context, taskCtx pluginV1.TaskContext, _ *core.TaskTemplate) pluginV1.TaskStatus { + if enableAsserts { + assert.NotEmpty(t, taskCtx.GetDataDir()) + } + return pluginV1.TaskStatusSucceeded + } + exec.On("CheckTaskStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(checkStatusFn, nil) + + return exec +} + +func createFailingTaskExecutor() pluginV1.Executor { + exec := &mocks.Executor{} + exec.On("GetID").Return("task") + exec.On("GetProperties").Return(pluginV1.ExecutorProperties{}) + exec.On("Initialize", + mock.AnythingOfType(reflect.TypeOf(context.TODO()).String()), + mock.AnythingOfType(reflect.TypeOf(pluginV1.ExecutorInitializationParameters{}).String()), + ).Return(nil) + + exec.On("StartTask", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginV1.TaskStatusRunning, nil) + + exec.On("CheckTaskStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + 
).Return(pluginV1.TaskStatusPermanentFailure(errors.New("failed")), nil) + + exec.On("KillTask", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.Anything, + ).Return(nil) + + return exec +} + +func createTaskExecutorErrorInCheck() pluginV1.Executor { + exec := &mocks.Executor{} + exec.On("GetID").Return("task") + exec.On("GetProperties").Return(pluginV1.ExecutorProperties{}) + exec.On("Initialize", + mock.AnythingOfType(reflect.TypeOf(context.TODO()).String()), + mock.AnythingOfType(reflect.TypeOf(pluginV1.ExecutorInitializationParameters{}).String()), + ).Return(nil) + + exec.On("StartTask", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginV1.TaskStatusRunning, nil) + + exec.On("CheckTaskStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginV1.TaskStatusUndefined, errors.New("check failed")) + + return exec +} + +func createSingletonTaskExecutorFactory(te pluginV1.Executor) task.Factory { + return &task.FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginV1.Executor, error) { + return te, nil + }, + ListAllTaskExecutorsCb: func() []pluginV1.Executor { + return []pluginV1.Executor{te} + }, + } +} + +func init() { + flytek8s.InitializeFake() +} + +func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { + ctx := context.Background() + store := createInmemoryDataStore(t, testScope.NewSubScope("12")) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(t, err) + + te := 
createTaskExecutorErrorInCheck() + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(t, task.IsTestModeEnabled()) + + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + eventSink := events.NewMockEventSink() + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, executor.Initialize(ctx)) + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + if assert.NoError(t, err) { + w := &v1alpha1.FlyteWorkflow{} + if assert.NoError(t, json.Unmarshal(wJSON, w)) { + // For benchmark workflow, we know how many rounds it needs + // Number of rounds = 7 + 1 + + for i := 0; i < 11; i++ { + err := executor.HandleFlyteWorkflow(ctx, w) + for k, v := range w.Status.NodeStatus { + fmt.Printf("Node[%v=%v],", k, v.Phase.String()) + // Reset dirty manually for tests. 
+ v.ResetDirty() + } + fmt.Printf("\n") + + if i < 4 { + assert.NoError(t, err, "Round %d", i) + } else { + assert.Error(t, err, "Round %d", i) + } + } + assert.Equal(t, v1alpha1.WorkflowPhaseRunning.String(), w.Status.Phase.String(), "Message: [%v]", w.Status.Message) + } + } +} + +func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { + ctx := context.Background() + store := createInmemoryDataStore(t, testScope.NewSubScope("13")) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(t, err) + + te := createHappyPathTaskExecutor(t, store, true) + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(t, task.IsTestModeEnabled()) + + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + eventSink := events.NewMockEventSink() + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, executor.Initialize(ctx)) + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + if assert.NoError(t, err) { + w := &v1alpha1.FlyteWorkflow{} + if assert.NoError(t, json.Unmarshal(wJSON, w)) { + // For benchmark workflow, we know how many rounds it needs + // Number of rounds = 12 ? + for i := 0; i < 12; i++ { + err := executor.HandleFlyteWorkflow(ctx, w) + if err != nil { + t.Log(err) + } + + assert.NoError(t, err) + fmt.Printf("Round[%d] Workflow[%v]\n", i, w.Status.Phase.String()) + for k, v := range w.Status.NodeStatus { + fmt.Printf("Node[%v=%v],", k, v.Phase.String()) + // Reset dirty manually for tests. 
+ v.ResetDirty() + } + fmt.Printf("\n") + } + + assert.Equal(t, v1alpha1.WorkflowPhaseSuccess.String(), w.Status.Phase.String(), "Message: [%v]", w.Status.Message) + } + } +} + +func BenchmarkWorkflowExecutor(b *testing.B) { + scope := promutils.NewScope("test3") + ctx := context.Background() + store := createInmemoryDataStore(b, scope.NewSubScope(strconv.Itoa(b.N))) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(b, err) + + te := createHappyPathTaskExecutor(b, store, false) + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(b, task.IsTestModeEnabled()) + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + eventSink := events.NewMockEventSink() + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, scope) + assert.NoError(b, err) + + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) + assert.NoError(b, err) + + assert.NoError(b, executor.Initialize(ctx)) + b.ReportAllocs() + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + if err != nil { + assert.FailNow(b, "Got error reading the testdata") + } + w := &v1alpha1.FlyteWorkflow{} + err = json.Unmarshal(wJSON, w) + if err != nil { + assert.FailNow(b, "Got error unmarshalling the testdata") + } + + // Current benchmark 2ms/op + for i := 0; i < b.N; i++ { + deepW := w.DeepCopy() + // For benchmark workflow, we know how many rounds it needs + // Number of rounds = 7 + 1 + for i := 0; i < 8; i++ { + err := executor.HandleFlyteWorkflow(ctx, deepW) + if err != nil { + assert.FailNow(b, "Run the unit test first. 
Benchmark should not fail") + } + } + if deepW.Status.Phase != v1alpha1.WorkflowPhaseSuccess { + assert.FailNow(b, "Workflow did not end in the expected state") + } + } +} + +func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { + ctx := context.Background() + store := createInmemoryDataStore(t, promutils.NewTestScope()) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(t, err) + + te := createFailingTaskExecutor() + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(t, task.IsTestModeEnabled()) + + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + recordedRunning := false + recordedFailed := false + recordedFailing := true + eventSink := events.NewMockEventSink() + mockSink := eventSink.(*events.MockEventSink) + mockSink.SinkCb = func(ctx context.Context, message proto.Message) error { + e, ok := message.(*event.WorkflowExecutionEvent) + + if ok { + assert.True(t, ok) + switch e.Phase { + case core.WorkflowExecution_RUNNING: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedRunning = true + case core.WorkflowExecution_FAILING: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedFailing = true + case core.WorkflowExecution_FAILED: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedFailed = true + default: + return fmt.Errorf("MockWorkflowRecorder should not have entered into any other states [%v]", e.Phase) + } + } + return nil + } + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, 
promutils.NewTestScope()) + assert.NoError(t, err) + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, executor.Initialize(ctx)) + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + if assert.NoError(t, err) { + w := &v1alpha1.FlyteWorkflow{} + if assert.NoError(t, json.Unmarshal(wJSON, w)) { + // For benchmark workflow, we will run into the first failure on round 6 + + for i := 0; i < 6; i++ { + err := executor.HandleFlyteWorkflow(ctx, w) + assert.Nil(t, err, "Round [%v]", i) + fmt.Printf("Round[%d] Workflow[%v]\n", i, w.Status.Phase.String()) + for k, v := range w.Status.NodeStatus { + fmt.Printf("Node[%v=%v],", k, v.Phase.String()) + // Reset dirty manually for tests. + v.ResetDirty() + } + fmt.Printf("\n") + + if i == 5 { + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, w.Status.Phase) + } else { + assert.NotEqual(t, v1alpha1.WorkflowPhaseFailed, w.Status.Phase, "For Round [%v] got phase [%v]", i, w.Status.Phase.String()) + } + + } + + assert.Equal(t, v1alpha1.WorkflowPhaseFailed.String(), w.Status.Phase.String(), "Message: [%v]", w.Status.Message) + } + } + assert.True(t, recordedRunning) + assert.True(t, recordedFailing) + assert.True(t, recordedFailed) +} + +func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { + ctx := context.Background() + store := createInmemoryDataStore(t, promutils.NewTestScope()) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(t, err) + + te := createHappyPathTaskExecutor(t, store, true) + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(t, task.IsTestModeEnabled()) + + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + recordedRunning := false + recordedSuccess := false + recordedFailing := true + eventSink := events.NewMockEventSink() + mockSink := 
eventSink.(*events.MockEventSink) + mockSink.SinkCb = func(ctx context.Context, message proto.Message) error { + e, ok := message.(*event.WorkflowExecutionEvent) + if ok { + switch e.Phase { + case core.WorkflowExecution_RUNNING: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedRunning = true + case core.WorkflowExecution_SUCCEEDING: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedFailing = true + case core.WorkflowExecution_SUCCEEDED: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedSuccess = true + default: + return fmt.Errorf("MockWorkflowRecorder should not have entered into any other states, received [%v]", e.Phase.String()) + } + } + return nil + } + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, executor.Initialize(ctx)) + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + if assert.NoError(t, err) { + w := &v1alpha1.FlyteWorkflow{} + if assert.NoError(t, json.Unmarshal(wJSON, w)) { + // For benchmark workflow, we know how many rounds it needs + // Number of rounds = 12 ? 
+ for i := 0; i < 12; i++ { + err := executor.HandleFlyteWorkflow(ctx, w) + assert.NoError(t, err) + fmt.Printf("Round[%d] Workflow[%v]\n", i, w.Status.Phase.String()) + for k, v := range w.Status.NodeStatus { + fmt.Printf("Node[%v=%v],", k, v.Phase.String()) + // Reset dirty manually for tests. + v.ResetDirty() + } + fmt.Printf("\n") + } + + assert.Equal(t, v1alpha1.WorkflowPhaseSuccess.String(), w.Status.Phase.String(), "Message: [%v]", w.Status.Message) + } + } + assert.True(t, recordedRunning) + assert.True(t, recordedFailing) + assert.True(t, recordedSuccess) +} + +func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { + ctx := context.Background() + store := createInmemoryDataStore(t, promutils.NewTestScope()) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(t, err) + + te := createHappyPathTaskExecutor(t, store, true) + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(t, task.IsTestModeEnabled()) + + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + assert.NoError(t, err) + + nodeEventSink := events.NewMockEventSink() + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, nodeEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + + t.Run("EventAlreadyInTerminalStateError", func(t *testing.T) { + + eventSink := events.NewMockEventSink() + mockSink := eventSink.(*events.MockEventSink) + mockSink.SinkCb = func(ctx context.Context, message proto.Message) error { + return &eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, + Cause: errors.New("already exists"), + } + } + executor, err := NewExecutor(ctx, store, enqueueWorkflow, mockSink, recorder, "metadata", nodeExec, promutils.NewTestScope()) + 
assert.NoError(t, err) + w := &v1alpha1.FlyteWorkflow{} + assert.NoError(t, json.Unmarshal(wJSON, w)) + + assert.NoError(t, executor.Initialize(ctx)) + err = executor.HandleFlyteWorkflow(ctx, w) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed.String(), w.Status.Phase.String()) + + assert.NoError(t, err) + }) + + t.Run("EventSinkAlreadyExistsError", func(t *testing.T) { + eventSink := events.NewMockEventSink() + mockSink := eventSink.(*events.MockEventSink) + mockSink.SinkCb = func(ctx context.Context, message proto.Message) error { + return &eventsErr.EventError{Code: eventsErr.AlreadyExists, + Cause: errors.New("already exists"), + } + } + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + w := &v1alpha1.FlyteWorkflow{} + assert.NoError(t, json.Unmarshal(wJSON, w)) + + err = executor.HandleFlyteWorkflow(ctx, w) + assert.NoError(t, err) + }) + + t.Run("EventSinkGenericError", func(t *testing.T) { + eventSink := events.NewMockEventSink() + mockSink := eventSink.(*events.MockEventSink) + mockSink.SinkCb = func(ctx context.Context, message proto.Message) error { + return &eventsErr.EventError{Code: eventsErr.EventSinkError, + Cause: errors.New("generic exists"), + } + } + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + w := &v1alpha1.FlyteWorkflow{} + assert.NoError(t, json.Unmarshal(wJSON, w)) + + err = executor.HandleFlyteWorkflow(ctx, w) + assert.Error(t, err) + assert.True(t, wfErrors.Matches(err, wfErrors.EventRecordingError)) + }) + +} diff --git a/pkg/controller/workflow/testdata/benchmark_wf.yaml b/pkg/controller/workflow/testdata/benchmark_wf.yaml new file mode 100644 index 000000000..d53059b33 --- /dev/null +++ b/pkg/controller/workflow/testdata/benchmark_wf.yaml @@ -0,0 +1,378 @@ +kind: flyteworkflow +metadata: + creationTimestamp: null + generateName: 
dummy-workflow-1-0- + labels: + execution-id: "exec-id" + workflow-id: dummy-workflow-1-0 + namespace: myflytenamespace + name: "test-wf" +inputs: + literals: + triggered_date: + scalar: + primitive: + datetime: 2018-08-08T22:16:36.860016587Z +spec: + connections: + add-one-and-print-0: + - sum-non-none-0 + add-one-and-print-1: + - add-one-and-print-2 + - add-one-and-print-2 + - sum-and-print-0 + - sum-and-print-0 + add-one-and-print-2: + - sum-and-print-0 + - sum-and-print-0 + add-one-and-print-3: + - sum-non-none-0 + - sum-non-none-0 + start-node: + - print-every-time-0 + - add-one-and-print-0 + - add-one-and-print-3 + sum-and-print-0: + - print-every-time-0 + - print-every-time-0 + - print-every-time-0 + - print-every-time-0 + sum-non-none-0: + - add-one-and-print-1 + - add-one-and-print-1 + - sum-and-print-0 + id: dummy-workflow-1-0 + nodes: + add-one-and-print-0: + activeDeadlineSeconds: 0 + id: add-one-and-print-0 + inputBindings: + - binding: + scalar: + primitive: + integer: "3" + var: value_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: add-one-and-print + add-one-and-print-1: + activeDeadlineSeconds: 0 + id: add-one-and-print-1 + inputBindings: + - binding: + promise: + nodeId: sum-non-none-0 + var: out + var: value_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: add-one-and-print + add-one-and-print-2: + activeDeadlineSeconds: 0 + id: add-one-and-print-2 + inputBindings: + - binding: + promise: + nodeId: add-one-and-print-1 + var: out + var: value_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: add-one-and-print + add-one-and-print-3: + activeDeadlineSeconds: 0 + id: add-one-and-print-3 + inputBindings: + - binding: + scalar: + primitive: + integer: "101" + var: value_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: add-one-and-print + end-node: 
+ id: end-node + kind: end + resources: {} + status: + phase: 0 + print-every-time-0: + activeDeadlineSeconds: 0 + id: print-every-time-0 + inputBindings: + - binding: + promise: + nodeId: start-node + var: triggered_date + var: date_triggered + - binding: + promise: + nodeId: sum-and-print-0 + var: out_blob + var: in_blob + - binding: + promise: + nodeId: sum-and-print-0 + var: multi_blob + var: multi_blob + - binding: + promise: + nodeId: sum-and-print-0 + var: out + var: value_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: print-every-time + start-node: + id: start-node + kind: start + resources: {} + status: + phase: 0 + sum-and-print-0: + activeDeadlineSeconds: 0 + id: sum-and-print-0 + inputBindings: + - binding: + collection: + bindings: + - promise: + nodeId: sum-non-none-0 + var: out + - promise: + nodeId: add-one-and-print-1 + var: out + - promise: + nodeId: add-one-and-print-2 + var: out + - scalar: + primitive: + integer: "100" + var: values_to_add + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: sum-and-print + sum-non-none-0: + activeDeadlineSeconds: 0 + id: sum-non-none-0 + inputBindings: + - binding: + collection: + bindings: + - promise: + nodeId: add-one-and-print-0 + var: out + - promise: + nodeId: add-one-and-print-3 + var: out + var: values_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: sum-non-none +status: + phase: 0 +tasks: + add-one-and-print: + container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=add_one_and_print + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflytecontainer:abc123 + resources: + requests: + - name: 1 + value: "2.000" + - name: 3 + value: 2048Mi + - name: 2 + value: "0.000" + id: + name: add-one-and-print + interface: + inputs: + variables: + value_to_print: + type: + simple: INTEGER + 
outputs: + variables: + out: + type: + simple: INTEGER + metadata: + runtime: + type: 1 + version: 1.19.0b10 + timeout: 0s + type: "7" + print-every-time: + container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=print_every_time + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflytecontainer:abc123 + resources: + requests: + - name: 1 + value: "2.000" + - name: 3 + value: 2048Mi + - name: 2 + value: "0.000" + id: + name: print-every-time + interface: + inputs: + variables: + date_triggered: + type: + simple: DATETIME + in_blob: + type: + blob: + dimensionality: SINGLE + multi_blob: + type: + blob: + dimensionality: 1 + value_to_print: + type: + simple: INTEGER + outputs: + variables: {} + metadata: + runtime: + type: 1 + version: 1.19.0b10 + timeout: 0s + type: "7" + sum-and-print: + container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=sum_and_print + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflytecontainer:abc123 + resources: + requests: + - name: 1 + value: "2.000" + - name: 3 + value: 2048Mi + - name: 2 + value: "0.000" + id: + name: sum-and-print + interface: + inputs: + variables: + values_to_add: + type: + collectionType: + simple: INTEGER + outputs: + variables: + multi_blob: + type: + blob: + dimensionality: 1 + out: + type: + blob: + dimensionality: 0 + out_blob: + type: + blob: + dimensionality: 0 + metadata: + runtime: + type: 1 + version: 1.19.0b10 + timeout: 0s + type: "7" + sum-non-none: + container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=sum_non_none + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflytecontainer:abc123 + resources: + requests: + - name: 1 + value: "2.000" + - name: 3 + value: 2048Mi + - name: 2 + value: "0.000" + id: + name: sum-non-none + interface: + inputs: + variables: + 
values_to_print: + type: + collectionType: + simple: INTEGER + outputs: + variables: + out: + type: + simple: INTEGER + metadata: + runtime: + type: 1 + version: 1.19.0b10 + timeout: 0s + type: "7" + diff --git a/pkg/controller/workflowstore/errors.go b/pkg/controller/workflowstore/errors.go new file mode 100644 index 000000000..572508850 --- /dev/null +++ b/pkg/controller/workflowstore/errors.go @@ -0,0 +1,18 @@ +package workflowstore + +import ( + "fmt" + + "github.com/pkg/errors" +) + +var errStaleWorkflowError = fmt.Errorf("stale Workflow Found error") +var errWorkflowNotFound = fmt.Errorf("workflow not-found error") + +func IsNotFound(err error) bool { + return errors.Cause(err) == errWorkflowNotFound +} + +func IsWorkflowStale(err error) bool { + return errors.Cause(err) == errStaleWorkflowError +} diff --git a/pkg/controller/workflowstore/iface.go b/pkg/controller/workflowstore/iface.go new file mode 100644 index 000000000..7e40f26b5 --- /dev/null +++ b/pkg/controller/workflowstore/iface.go @@ -0,0 +1,20 @@ +package workflowstore + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type PriorityClass int + +const ( + PriorityClassCritical PriorityClass = iota + PriorityClassRegular +) + +type FlyteWorkflow interface { + Get(ctx context.Context, namespace, name string) (*v1alpha1.FlyteWorkflow, error) + UpdateStatus(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error + Update(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error +} diff --git a/pkg/controller/workflowstore/inmemory.go b/pkg/controller/workflowstore/inmemory.go new file mode 100644 index 000000000..906e1cc75 --- /dev/null +++ b/pkg/controller/workflowstore/inmemory.go @@ -0,0 +1,70 @@ +package workflowstore + +import ( + "context" + "fmt" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + kubeerrors "k8s.io/apimachinery/pkg/api/errors" +) + +type 
InmemoryWorkflowStore struct { + store map[string]map[string]*v1alpha1.FlyteWorkflow +} + +func (i *InmemoryWorkflowStore) Create(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + if w != nil { + if w.Name != "" && w.Namespace != "" { + if _, ok := i.store[w.Namespace]; !ok { + i.store[w.Namespace] = map[string]*v1alpha1.FlyteWorkflow{} + } + i.store[w.Namespace][w.Name] = w + return nil + } + } + return kubeerrors.NewBadRequest(fmt.Sprintf("Workflow object with Namespace [%v] & Name [%v] is required", w.Namespace, w.Name)) +} + +func (i *InmemoryWorkflowStore) Delete(ctx context.Context, namespace, name string) error { + if m, ok := i.store[namespace]; ok { + if _, ok := m[name]; ok { + delete(m, name) + return nil + } + } + return nil +} + +func (i *InmemoryWorkflowStore) Get(ctx context.Context, namespace, name string) (*v1alpha1.FlyteWorkflow, error) { + if m, ok := i.store[namespace]; ok { + if v, ok := m[name]; ok { + return v, nil + } + } + return nil, errWorkflowNotFound +} + +func (i *InmemoryWorkflowStore) UpdateStatus(ctx context.Context, w *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + if w != nil { + if w.Name != "" && w.Namespace != "" { + if m, ok := i.store[w.Namespace]; ok { + if _, ok := m[w.Name]; ok { + m[w.Name] = w + return nil + } + } + return nil + } + } + return kubeerrors.NewBadRequest("Workflow object with Namespace & Name is required") +} + +func (i *InmemoryWorkflowStore) Update(ctx context.Context, w *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + return i.UpdateStatus(ctx, w, priorityClass) +} + +func NewInMemoryWorkflowStore() *InmemoryWorkflowStore { + return &InmemoryWorkflowStore{ + store: map[string]map[string]*v1alpha1.FlyteWorkflow{}, + } +} diff --git a/pkg/controller/workflowstore/passthrough.go b/pkg/controller/workflowstore/passthrough.go new file mode 100644 index 000000000..882242b34 --- /dev/null +++ b/pkg/controller/workflowstore/passthrough.go @@ -0,0 +1,106 @@ +package 
workflowstore + +import ( + "context" + "time" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + v1alpha12 "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + listers "github.com/lyft/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" + kubeerrors "k8s.io/apimachinery/pkg/api/errors" +) + +type workflowstoreMetrics struct { + workflowUpdateCount prometheus.Counter + workflowUpdateFailedCount prometheus.Counter + workflowUpdateSuccessCount prometheus.Counter + workflowUpdateConflictCount prometheus.Counter + workflowUpdateLatency promutils.StopWatch +} + +type passthroughWorkflowStore struct { + wfLister listers.FlyteWorkflowLister + wfClientSet v1alpha12.FlyteworkflowV1alpha1Interface + metrics *workflowstoreMetrics +} + +func (p *passthroughWorkflowStore) Get(ctx context.Context, namespace, name string) (*v1alpha1.FlyteWorkflow, error) { + w, err := p.wfLister.FlyteWorkflows(namespace).Get(name) + if err != nil { + // The FlyteWorkflow resource may no longer exist, in which case we stop + // processing. + if kubeerrors.IsNotFound(err) { + logger.Warningf(ctx, "Workflow not found in cache.") + return nil, errWorkflowNotFound + } + return nil, err + } + return w, nil +} + +func (p *passthroughWorkflowStore) UpdateStatus(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + p.metrics.workflowUpdateCount.Inc() + // Something has changed. Lets save + logger.Debugf(ctx, "Observed FlyteWorkflow State change. 
[%v] -> [%v]", workflow.Status.Phase.String(), workflow.Status.Phase.String()) + t := p.metrics.workflowUpdateLatency.Start() + _, err := p.wfClientSet.FlyteWorkflows(workflow.Namespace).Update(workflow) + if err != nil { + if kubeerrors.IsNotFound(err) { + return nil + } + if kubeerrors.IsConflict(err) { + p.metrics.workflowUpdateConflictCount.Inc() + } + p.metrics.workflowUpdateFailedCount.Inc() + logger.Errorf(ctx, "Failed to update workflow status. Error [%v]", err) + return err + } + t.Stop() + p.metrics.workflowUpdateSuccessCount.Inc() + logger.Debugf(ctx, "Updated workflow status.") + return nil +} + +func (p *passthroughWorkflowStore) Update(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + p.metrics.workflowUpdateCount.Inc() + // Something has changed. Lets save + logger.Debugf(ctx, "Observed FlyteWorkflow Update (maybe finalizer)") + t := p.metrics.workflowUpdateLatency.Start() + _, err := p.wfClientSet.FlyteWorkflows(workflow.Namespace).Update(workflow) + if err != nil { + if kubeerrors.IsNotFound(err) { + return nil + } + if kubeerrors.IsConflict(err) { + p.metrics.workflowUpdateConflictCount.Inc() + } + p.metrics.workflowUpdateFailedCount.Inc() + logger.Errorf(ctx, "Failed to update workflow. 
Error [%v]", err) + return err + } + t.Stop() + p.metrics.workflowUpdateSuccessCount.Inc() + logger.Debugf(ctx, "Updated workflow.") + return nil +} + +func NewPassthroughWorkflowStore(_ context.Context, scope promutils.Scope, wfClient v1alpha12.FlyteworkflowV1alpha1Interface, + flyteworkflowLister listers.FlyteWorkflowLister) FlyteWorkflow { + + metrics := &workflowstoreMetrics{ + workflowUpdateCount: scope.MustNewCounter("wf_updated", "Total number of status updates"), + workflowUpdateFailedCount: scope.MustNewCounter("wf_update_failed", "Failure to update ETCd"), + workflowUpdateConflictCount: scope.MustNewCounter("wf_update_conflict", "Failure to update ETCd because of conflict"), + workflowUpdateSuccessCount: scope.MustNewCounter("wf_update_success", "Success in updating ETCd"), + workflowUpdateLatency: scope.MustNewStopWatch("wf_update_latency", "Time taken to complete update/updatestatus", time.Millisecond), + } + + return &passthroughWorkflowStore{ + wfLister: flyteworkflowLister, + wfClientSet: wfClient, + metrics: metrics, + } +} diff --git a/pkg/controller/workflowstore/passthrough_test.go b/pkg/controller/workflowstore/passthrough_test.go new file mode 100644 index 000000000..9d77993e0 --- /dev/null +++ b/pkg/controller/workflowstore/passthrough_test.go @@ -0,0 +1,130 @@ +package workflowstore + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + listers "github.com/lyft/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/fake" + + kubeerrors "k8s.io/apimachinery/pkg/api/errors" +) + +type mockWFNamespaceLister struct { + listers.FlyteWorkflowNamespaceLister + GetCb func(name string) (*v1alpha1.FlyteWorkflow, error) +} + +func (m *mockWFNamespaceLister) Get(name string) (*v1alpha1.FlyteWorkflow, 
error) { + return m.GetCb(name) +} + +type mockWFLister struct { + listers.FlyteWorkflowLister + V listers.FlyteWorkflowNamespaceLister +} + +func (m *mockWFLister) FlyteWorkflows(namespace string) listers.FlyteWorkflowNamespaceLister { + return m.V +} + +func TestPassthroughWorkflowStore_Get(t *testing.T) { + ctx := context.TODO() + + mockClient := fake.NewSimpleClientset().FlyteworkflowV1alpha1() + + l := &mockWFNamespaceLister{} + wfStore := NewPassthroughWorkflowStore(ctx, promutils.NewTestScope(), mockClient, &mockWFLister{V: l}) + + t.Run("notFound", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, kubeerrors.NewNotFound(v1alpha1.Resource(v1alpha1.FlyteWorkflowKind), "name") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.True(t, IsNotFound(err)) + assert.Nil(t, w) + }) + + t.Run("alreadyExists?", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, kubeerrors.NewAlreadyExists(v1alpha1.Resource(v1alpha1.FlyteWorkflowKind), "name") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.Nil(t, w) + }) + + t.Run("unknownError", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, fmt.Errorf("error") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.Nil(t, w) + }) + + t.Run("success", func(t *testing.T) { + expW := &v1alpha1.FlyteWorkflow{} + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return expW, nil + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.NoError(t, err) + assert.Equal(t, expW, w) + }) +} + +func dummyWf(namespace, name string) *v1alpha1.FlyteWorkflow { + return &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + } +} + +func TestPassthroughWorkflowStore_UpdateStatus(t *testing.T) { + + ctx := context.TODO() + + mockClient := 
fake.NewSimpleClientset().FlyteworkflowV1alpha1() + l := &mockWFNamespaceLister{} + wfStore := NewPassthroughWorkflowStore(ctx, promutils.NewTestScope(), mockClient, &mockWFLister{V: l}) + + const namespace = "test-ns" + t.Run("notFound", func(t *testing.T) { + wf := dummyWf(namespace, "x") + err := wfStore.UpdateStatus(ctx, wf, PriorityClassCritical) + assert.NoError(t, err) + updated, err := mockClient.FlyteWorkflows(namespace).Get("x", v1.GetOptions{}) + assert.Error(t, err) + assert.Nil(t, updated) + }) + + t.Run("Found-Updated", func(t *testing.T) { + n := mockClient.FlyteWorkflows(namespace) + wf := dummyWf(namespace, "x") + wf.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseSucceeding, "") + wf.ResourceVersion = "r1" + _, err := n.Create(wf) + assert.NoError(t, err) + updated, err := n.Get("x", v1.GetOptions{}) + if assert.NoError(t, err) { + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, updated.GetExecutionStatus().GetPhase()) + wf.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseFailed, "") + err := wfStore.UpdateStatus(ctx, wf, PriorityClassCritical) + assert.NoError(t, err) + newVal, err := n.Get("x", v1.GetOptions{}) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, newVal.GetExecutionStatus().GetPhase()) + } + }) + +} diff --git a/pkg/controller/workflowstore/resource_version_caching.go b/pkg/controller/workflowstore/resource_version_caching.go new file mode 100644 index 000000000..2bb84854c --- /dev/null +++ b/pkg/controller/workflowstore/resource_version_caching.go @@ -0,0 +1,96 @@ +package workflowstore + +import ( + "context" + "fmt" + "sync" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" +) + +// TODO - optimization maybe? we can move this to predicate check, before we add it to the queue? 
+type resourceVersionMetrics struct {
+	workflowStaleCount   prometheus.Counter
+	workflowEvictedCount prometheus.Counter
+}
+
+// Simple function that converts the namespace and name to a string
+func resourceVersionKey(namespace, name string) string {
+	return fmt.Sprintf("%s/%s", namespace, name)
+}
+
+// A specialized store that stores an in-memory cache of all the workflows that are currently executing and their last observed version numbers
+// If the version numbers between the last update and the next Get have not been updated then the Get returns a nil (ignores the workflow)
+// Propeller round will then just ignore the workflow
+type resourceVersionCaching struct {
+	w                               FlyteWorkflow
+	metrics                         *resourceVersionMetrics
+	lastUpdatedResourceVersionCache sync.Map
+}
+
+func (r *resourceVersionCaching) updateRevisionCache(_ context.Context, namespace, name, resourceVersion string, isTerminated bool) {
+	if isTerminated {
+		r.metrics.workflowEvictedCount.Inc()
+		r.lastUpdatedResourceVersionCache.Delete(resourceVersionKey(namespace, name))
+	} else {
+		r.lastUpdatedResourceVersionCache.Store(resourceVersionKey(namespace, name), resourceVersion)
+	}
+}
+
+func (r *resourceVersionCaching) isResourceVersionSameAsPrevious(namespace, name, resourceVersion string) bool {
+	if v, ok := r.lastUpdatedResourceVersionCache.Load(resourceVersionKey(namespace, name)); ok {
+		strV := v.(string)
+		if strV == resourceVersion {
+			r.metrics.workflowStaleCount.Inc()
+			return true
+		}
+	}
+	return false
+}
+
+func (r *resourceVersionCaching) Get(ctx context.Context, namespace, name string) (*v1alpha1.FlyteWorkflow, error) {
+	w, err := r.w.Get(ctx, namespace, name)
+	if err != nil {
+		return nil, err
+	}
+	if w != nil {
+		if r.isResourceVersionSameAsPrevious(namespace, name, w.ResourceVersion) {
+			return nil, errStaleWorkflowError
+		}
+	}
+	return w, nil
+}
+
+func (r *resourceVersionCaching) UpdateStatus(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) 
error { + err := r.w.UpdateStatus(ctx, workflow, priorityClass) + if err != nil { + return err + } + + r.updateRevisionCache(ctx, workflow.Namespace, workflow.Name, workflow.ResourceVersion, workflow.Status.IsTerminated()) + return nil +} + +func (r *resourceVersionCaching) Update(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + err := r.w.Update(ctx, workflow, priorityClass) + if err != nil { + return err + } + + r.updateRevisionCache(ctx, workflow.Namespace, workflow.Name, workflow.ResourceVersion, workflow.Status.IsTerminated()) + return nil +} + +func NewResourceVersionCachingStore(ctx context.Context, scope promutils.Scope, workflowStore FlyteWorkflow) FlyteWorkflow { + + return &resourceVersionCaching{ + w: workflowStore, + metrics: &resourceVersionMetrics{ + workflowStaleCount: scope.MustNewCounter("wf_stale", "Found stale workflow in cache"), + workflowEvictedCount: scope.MustNewCounter("wf_evict", "removed workflow from resource version cache"), + }, + lastUpdatedResourceVersionCache: sync.Map{}, + } +} diff --git a/pkg/controller/workflowstore/resource_version_caching_test.go b/pkg/controller/workflowstore/resource_version_caching_test.go new file mode 100644 index 000000000..b09847e0f --- /dev/null +++ b/pkg/controller/workflowstore/resource_version_caching_test.go @@ -0,0 +1,153 @@ +package workflowstore + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/fake" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + kubeerrors "k8s.io/apimachinery/pkg/api/errors" +) + +func TestResourceVersionCaching_Get_NotInCache(t *testing.T) { + ctx := context.TODO() + mockClient := fake.NewSimpleClientset().FlyteworkflowV1alpha1() + + scope := promutils.NewTestScope() + l := &mockWFNamespaceLister{} + wfStore := NewResourceVersionCachingStore(ctx, scope, 
NewPassthroughWorkflowStore(ctx, scope, mockClient, &mockWFLister{V: l})) + + t.Run("notFound", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, kubeerrors.NewNotFound(v1alpha1.Resource(v1alpha1.FlyteWorkflowKind), "name") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.True(t, IsNotFound(err)) + assert.Nil(t, w) + }) + + t.Run("alreadyExists?", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, kubeerrors.NewAlreadyExists(v1alpha1.Resource(v1alpha1.FlyteWorkflowKind), "name") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.Nil(t, w) + }) + + t.Run("unknownError", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, fmt.Errorf("error") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.Nil(t, w) + }) + + t.Run("success", func(t *testing.T) { + expW := &v1alpha1.FlyteWorkflow{} + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return expW, nil + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.NoError(t, err) + assert.Equal(t, expW, w) + }) +} + +func TestResourceVersionCaching_Get_UpdateAndRead(t *testing.T) { + ctx := context.TODO() + + mockClient := fake.NewSimpleClientset().FlyteworkflowV1alpha1() + + namespace := "ns" + name := "name" + resourceVersion := "r1" + + wf := dummyWf(namespace, name) + wf.ResourceVersion = resourceVersion + + t.Run("Stale", func(t *testing.T) { + + scope := promutils.NewTestScope() + l := &mockWFNamespaceLister{} + wfStore := NewResourceVersionCachingStore(ctx, scope, NewPassthroughWorkflowStore(ctx, scope, mockClient, &mockWFLister{V: l})) + // Insert a new workflow with R1 + err := wfStore.Update(ctx, wf, PriorityClassCritical) + assert.NoError(t, err) + + // Return the same workflow + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + + return wf, nil + } + + w, 
err := wfStore.Get(ctx, namespace, name) + assert.Error(t, err) + assert.False(t, IsNotFound(err)) + assert.True(t, IsWorkflowStale(err)) + assert.Nil(t, w) + }) + + t.Run("Updated", func(t *testing.T) { + scope := promutils.NewTestScope() + l := &mockWFNamespaceLister{} + wfStore := NewResourceVersionCachingStore(ctx, scope, NewPassthroughWorkflowStore(ctx, scope, mockClient, &mockWFLister{V: l})) + // Insert a new workflow with R1 + err := wfStore.Update(ctx, wf, PriorityClassCritical) + assert.NoError(t, err) + + // Update the workflow version + wf2 := wf.DeepCopy() + wf2.ResourceVersion = "r2" + + // Return updated workflow + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return wf2, nil + } + + w, err := wfStore.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.NotNil(t, w) + assert.Equal(t, "r2", w.ResourceVersion) + }) +} + +func TestResourceVersionCaching_UpdateTerminated(t *testing.T) { + ctx := context.TODO() + + mockClient := fake.NewSimpleClientset().FlyteworkflowV1alpha1() + + namespace := "ns" + name := "name" + resourceVersion := "r1" + + wf := dummyWf(namespace, name) + wf.ResourceVersion = resourceVersion + + scope := promutils.NewTestScope() + l := &mockWFNamespaceLister{} + wfStore := NewResourceVersionCachingStore(ctx, scope, NewPassthroughWorkflowStore(ctx, scope, mockClient, &mockWFLister{V: l})) + // Insert a new workflow with R1 + err := wfStore.Update(ctx, wf, PriorityClassCritical) + assert.NoError(t, err) + + rvStore := wfStore.(*resourceVersionCaching) + v, ok := rvStore.lastUpdatedResourceVersionCache.Load(resourceVersionKey(namespace, name)) + assert.True(t, ok) + assert.Equal(t, resourceVersion, v.(string)) + + wf2 := wf.DeepCopy() + wf2.Status.Phase = v1alpha1.WorkflowPhaseAborted + err = wfStore.Update(ctx, wf2, PriorityClassCritical) + assert.NoError(t, err) + + v, ok = rvStore.lastUpdatedResourceVersionCache.Load(resourceVersionKey(namespace, name)) + assert.False(t, ok) + assert.Nil(t, v) + +} diff 
--git a/pkg/controller/workqueue.go b/pkg/controller/workqueue.go new file mode 100644 index 000000000..051333a8c --- /dev/null +++ b/pkg/controller/workqueue.go @@ -0,0 +1,47 @@ +package controller + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/controller/config" + + "golang.org/x/time/rate" + + "github.com/lyft/flytestdlib/logger" + "k8s.io/client-go/util/workqueue" +) + +func NewWorkQueue(ctx context.Context, cfg config.WorkqueueConfig, name string) (workqueue.RateLimitingInterface, error) { + // TODO introduce bounds checks + logger.Infof(ctx, "WorkQueue type [%v] configured", cfg.Type) + switch cfg.Type { + case config.WorkqueueTypeBucketRateLimiter: + logger.Infof(ctx, "Using Bucket Ratelimited Workqueue, Rate [%v] Capacity [%v]", cfg.Rate, cfg.Capacity) + return workqueue.NewNamedRateLimitingQueue( + // 10 qps, 100 bucket size. This is only for retry speed and its only the overall factor (not per item) + &workqueue.BucketRateLimiter{ + Limiter: rate.NewLimiter(rate.Limit(cfg.Rate), cfg.Capacity), + }, name), nil + case config.WorkqueueTypeExponentialFailureRateLimiter: + logger.Infof(ctx, "Using Exponential failure backoff Ratelimited Workqueue, Base Delay [%v], max Delay [%v]", cfg.BaseDelay, cfg.MaxDelay) + return workqueue.NewNamedRateLimitingQueue( + workqueue.NewItemExponentialFailureRateLimiter(cfg.BaseDelay.Duration, cfg.MaxDelay.Duration), + name), nil + case config.WorkqueueTypeMaxOfRateLimiter: + logger.Infof(ctx, "Using Max-of Ratelimited Workqueue, Bucket {Rate [%v] Capacity [%v]} | FailureBackoff {Base Delay [%v], max Delay [%v]}", cfg.Rate, cfg.Capacity, cfg.BaseDelay, cfg.MaxDelay) + return workqueue.NewNamedRateLimitingQueue( + workqueue.NewMaxOfRateLimiter( + &workqueue.BucketRateLimiter{ + Limiter: rate.NewLimiter(rate.Limit(cfg.Rate), cfg.Capacity), + }, + workqueue.NewItemExponentialFailureRateLimiter(cfg.BaseDelay.Duration, + cfg.MaxDelay.Duration), + ), name), nil + + case config.WorkqueueTypeDefault: + fallthrough + 
default: + logger.Infof(ctx, "Using Default Workqueue") + return workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), name), nil + } +} diff --git a/pkg/controller/workqueue_test.go b/pkg/controller/workqueue_test.go new file mode 100644 index 000000000..888c93b1f --- /dev/null +++ b/pkg/controller/workqueue_test.go @@ -0,0 +1,54 @@ +package controller + +import ( + "context" + "testing" + "time" + + config2 "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flytestdlib/config" + "github.com/stretchr/testify/assert" +) + +func TestNewWorkQueue(t *testing.T) { + ctx := context.TODO() + + t.Run("emptyConfig", func(t *testing.T) { + cfg := config2.WorkqueueConfig{} + w, err := NewWorkQueue(ctx, cfg, "q_test1") + assert.NoError(t, err) + assert.NotNil(t, w) + }) + + t.Run("simpleConfig", func(t *testing.T) { + cfg := config2.WorkqueueConfig{ + Type: config2.WorkqueueTypeDefault, + } + w, err := NewWorkQueue(ctx, cfg, "q_test2") + assert.NoError(t, err) + assert.NotNil(t, w) + }) + + t.Run("bucket", func(t *testing.T) { + cfg := config2.WorkqueueConfig{ + Type: config2.WorkqueueTypeBucketRateLimiter, + Capacity: 5, + Rate: 1, + } + w, err := NewWorkQueue(ctx, cfg, "q_test3") + assert.NoError(t, err) + assert.NotNil(t, w) + }) + + t.Run("expfailure", func(t *testing.T) { + cfg := config2.WorkqueueConfig{ + Type: config2.WorkqueueTypeExponentialFailureRateLimiter, + MaxDelay: config.Duration{Duration: time.Second * 10}, + BaseDelay: config.Duration{Duration: time.Second * 1}, + } + w, err := NewWorkQueue(ctx, cfg, "q_test4") + assert.NoError(t, err) + assert.NotNil(t, w) + }) +} diff --git a/pkg/signals/signal.go b/pkg/signals/signal.go new file mode 100644 index 000000000..2fe649c13 --- /dev/null +++ b/pkg/signals/signal.go @@ -0,0 +1,46 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// This file was derived from https://raw.githubusercontent.com/kubernetes/sample-controller/00e875e461860a3b584a2e891485e5b331473ec5/pkg/signals/signal.go + +package signals + +import ( + "context" + "os" + "os/signal" +) + +var onlyOneSignalHandler = make(chan struct{}) + +// SetupSignalHandler registers for SIGTERM and SIGINT. A stop channel is returned +// which is closed on one of these signals. If a second signal is caught, the program +// is terminated with exit code 1. +func SetupSignalHandler(ctx context.Context) context.Context { + close(onlyOneSignalHandler) // panics when called twice + + childCtx, cancel := context.WithCancel(ctx) + c := make(chan os.Signal, 2) + signal.Notify(c, shutdownSignals...) + go func() { + <-c + cancel() + <-c + os.Exit(1) // second signal. Exit directly. + }() + + return childCtx +} diff --git a/pkg/signals/signal_posix.go b/pkg/signals/signal_posix.go new file mode 100644 index 000000000..0c4cd6007 --- /dev/null +++ b/pkg/signals/signal_posix.go @@ -0,0 +1,28 @@ +// +build !windows + +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +// This file was derived from https://raw.githubusercontent.com/kubernetes/sample-controller/00e875e461860a3b584a2e891485e5b331473ec5/pkg/signals/signal_posix.go + +package signals + +import ( + "os" + "syscall" +) + +var shutdownSignals = []os.Signal{os.Interrupt, syscall.SIGTERM} diff --git a/pkg/signals/signal_windows.go b/pkg/signals/signal_windows.go new file mode 100644 index 000000000..344002881 --- /dev/null +++ b/pkg/signals/signal_windows.go @@ -0,0 +1,25 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +// This file was derived from https://raw.githubusercontent.com/kubernetes/sample-controller/00e875e461860a3b584a2e891485e5b331473ec5/pkg/signals/signal_windows.go + +package signals + +import ( + "os" +) + +var shutdownSignals = []os.Signal{os.Interrupt} diff --git a/pkg/utils/assert/literals.go b/pkg/utils/assert/literals.go new file mode 100644 index 000000000..caf915c13 --- /dev/null +++ b/pkg/utils/assert/literals.go @@ -0,0 +1,74 @@ +package assert + +import ( + "reflect" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func EqualPrimitive(t *testing.T, p1 *core.Primitive, p2 *core.Primitive) { + if p1 != nil { + assert.NotNil(t, p2) + } + assert.Equal(t, reflect.TypeOf(p1.Value), reflect.TypeOf(p2.Value)) + switch p1.Value.(type) { + case *core.Primitive_Integer: + assert.Equal(t, p1.GetInteger(), p2.GetInteger()) + case *core.Primitive_StringValue: + assert.Equal(t, p1.GetStringValue(), p2.GetStringValue()) + default: + assert.FailNow(t, "Not yet implemented for types %v", reflect.TypeOf(p1.Value)) + } +} + +func EqualScalar(t *testing.T, p1 *core.Scalar, p2 *core.Scalar) { + if p1 != nil { + assert.NotNil(t, p2) + } + assert.Equal(t, reflect.TypeOf(p1.Value), reflect.TypeOf(p2.Value)) + switch p1.Value.(type) { + case *core.Scalar_Primitive: + EqualPrimitive(t, p1.GetPrimitive(), p2.GetPrimitive()) + default: + assert.FailNow(t, "Not yet implemented for types %v", reflect.TypeOf(p1.Value)) + } +} + +func EqualLiterals(t *testing.T, l1 *core.Literal, l2 *core.Literal) { + if l1 != nil { + assert.NotNil(t, l2) + } else { + assert.FailNow(t, "expected value is nil") + } + assert.Equal(t, reflect.TypeOf(l1.Value), reflect.TypeOf(l2.Value)) + switch l1.Value.(type) { + case *core.Literal_Scalar: + EqualScalar(t, l1.GetScalar(), l2.GetScalar()) + case *core.Literal_Map: + EqualLiteralMap(t, l1.GetMap(), l2.GetMap()) + default: + assert.FailNow(t, "Not supported test type") + } +} + +func 
EqualLiteralMap(t *testing.T, l1 *core.LiteralMap, l2 *core.LiteralMap) { + if assert.NotNil(t, l1, "l1 is nil") && assert.NotNil(t, l2, "l2 is nil") { + assert.Equal(t, len(l1.Literals), len(l2.Literals)) + for k, v := range l1.Literals { + actual, ok := l2.Literals[k] + assert.True(t, ok) + EqualLiterals(t, v, actual) + } + } +} + +func EqualLiteralCollection(t *testing.T, l1 *core.LiteralCollection, l2 *core.LiteralCollection) { + if assert.NotNil(t, l2) { + assert.Equal(t, len(l1.Literals), len(l2.Literals)) + for i, v := range l1.Literals { + EqualLiterals(t, v, l2.Literals[i]) + } + } +} diff --git a/pkg/utils/bindings.go b/pkg/utils/bindings.go new file mode 100644 index 000000000..4a1bdeef2 --- /dev/null +++ b/pkg/utils/bindings.go @@ -0,0 +1,85 @@ +package utils + +import "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + +func MakeBindingDataPromise(fromNode, fromVar string) *core.BindingData { + return &core.BindingData{ + Value: &core.BindingData_Promise{ + Promise: &core.OutputReference{ + Var: fromVar, + NodeId: fromNode, + }, + }, + } +} + +func MakeBindingPromise(fromNode, fromVar, toVar string) *core.Binding { + return &core.Binding{ + Var: toVar, + Binding: MakeBindingDataPromise(fromNode, fromVar), + } +} + +func MakeBindingDataCollection(bindings ...*core.BindingData) *core.BindingData { + return &core.BindingData{ + Value: &core.BindingData_Collection{ + Collection: &core.BindingDataCollection{ + Bindings: bindings, + }, + }, + } +} + +type Pair struct { + K string + V *core.BindingData +} + +func NewPair(k string, v *core.BindingData) Pair { + return Pair{K: k, V: v} +} + +func MakeBindingDataMap(pairs ...Pair) *core.BindingData { + bindingsMap := map[string]*core.BindingData{} + for _, p := range pairs { + bindingsMap[p.K] = p.V + } + return &core.BindingData{ + Value: &core.BindingData_Map{ + Map: &core.BindingDataMap{ + Bindings: bindingsMap, + }, + }, + } +} + +func MakePrimitiveBindingData(v interface{}) (*core.BindingData, error) { + 
p, err := MakePrimitive(v) + if err != nil { + return nil, err + } + return &core.BindingData{ + Value: &core.BindingData_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: p, + }, + }, + }, + }, nil +} + +func MustMakePrimitiveBindingData(v interface{}) *core.BindingData { + p, err := MakePrimitiveBindingData(v) + if err != nil { + panic(err) + } + return p +} + +func MakeBinding(variable string, b *core.BindingData) *core.Binding { + return &core.Binding{ + Var: variable, + Binding: b, + } +} diff --git a/pkg/utils/bindings_test.go b/pkg/utils/bindings_test.go new file mode 100644 index 000000000..c6cb5fcc1 --- /dev/null +++ b/pkg/utils/bindings_test.go @@ -0,0 +1,156 @@ +package utils + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +const primitiveString = "hello" + +func TestMakePrimitiveBinding(t *testing.T) { + { + v := 1.0 + xb, err := MakePrimitiveBindingData(v) + x := MakeBinding("x", xb) + assert.NoError(t, err) + assert.Equal(t, "x", x.GetVar()) + p := x.GetBinding() + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v, p.GetScalar().GetPrimitive().GetFloatValue()) + } + { + v := struct { + }{} + _, err := MakePrimitiveBindingData(v) + assert.Error(t, err) + } +} + +func TestMustMakePrimitiveBinding(t *testing.T) { + { + v := 1.0 + x := MakeBinding("x", MustMakePrimitiveBindingData(v)) + assert.Equal(t, "x", x.GetVar()) + p := x.GetBinding() + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v, p.GetScalar().GetPrimitive().GetFloatValue()) + } + { + v := struct { + }{} + assert.Panics(t, func() { + MustMakePrimitiveBindingData(v) + }) + } +} + +func TestMakeBindingDataCollection(t *testing.T) { + v1 := int64(1) + v2 := primitiveString + c := MakeBindingDataCollection( + 
MustMakePrimitiveBindingData(v1), + MustMakePrimitiveBindingData(v2), + ) + + c2 := MakeBindingDataCollection( + MustMakePrimitiveBindingData(v1), + c, + ) + + assert.NotNil(t, c.GetCollection()) + assert.Equal(t, 2, len(c.GetCollection().Bindings)) + { + p := c.GetCollection().GetBindings()[0] + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v1, p.GetScalar().GetPrimitive().GetInteger()) + } + { + p := c.GetCollection().GetBindings()[1] + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_StringValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v2, p.GetScalar().GetPrimitive().GetStringValue()) + } + + assert.NotNil(t, c2.GetCollection()) + assert.Equal(t, 2, len(c2.GetCollection().Bindings)) + { + p := c2.GetCollection().GetBindings()[0] + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v1, p.GetScalar().GetPrimitive().GetInteger()) + } + { + p := c2.GetCollection().GetBindings()[1] + assert.NotNil(t, p.GetCollection()) + assert.Equal(t, c.GetCollection(), p.GetCollection()) + } +} + +func TestMakeBindingDataMap(t *testing.T) { + v1 := int64(1) + v2 := primitiveString + c := MakeBindingDataCollection( + MustMakePrimitiveBindingData(v1), + MustMakePrimitiveBindingData(v2), + ) + + m := MakeBindingDataMap( + NewPair("x", MustMakePrimitiveBindingData(v1)), + NewPair("y", c), + ) + + m2 := MakeBindingDataMap( + NewPair("x", MustMakePrimitiveBindingData(v1)), + NewPair("y", m), + ) + assert.NotNil(t, m.GetMap()) + assert.Equal(t, 2, len(m.GetMap().GetBindings())) + { + p := m.GetMap().GetBindings()["x"] + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v1, p.GetScalar().GetPrimitive().GetInteger()) 
+ } + { + p := m.GetMap().GetBindings()["y"] + assert.NotNil(t, p.GetCollection()) + assert.Equal(t, c.GetCollection(), p.GetCollection()) + } + + assert.NotNil(t, m2.GetMap()) + assert.Equal(t, 2, len(m2.GetMap().GetBindings())) + { + p := m2.GetMap().GetBindings()["x"] + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v1, p.GetScalar().GetPrimitive().GetInteger()) + } + { + p := m2.GetMap().GetBindings()["y"] + assert.NotNil(t, p.GetMap()) + assert.Equal(t, m.GetMap(), p.GetMap()) + } + +} + +func TestMakeBindingPromise(t *testing.T) { + p := MakeBindingPromise("n1", "x", "y") + assert.NotNil(t, p) + assert.Equal(t, "y", p.GetVar()) + assert.NotNil(t, p.GetBinding().GetPromise()) + assert.Equal(t, "n1", p.GetBinding().GetPromise().GetNodeId()) + assert.Equal(t, "x", p.GetBinding().GetPromise().GetVar()) +} + +func TestMakeBindingDataPromise(t *testing.T) { + p := MakeBindingDataPromise("n1", "x") + assert.NotNil(t, p) + assert.NotNil(t, p.GetPromise()) + assert.Equal(t, "n1", p.GetPromise().GetNodeId()) + assert.Equal(t, "x", p.GetPromise().GetVar()) +} diff --git a/pkg/utils/encoder.go b/pkg/utils/encoder.go new file mode 100644 index 000000000..8d92d92f7 --- /dev/null +++ b/pkg/utils/encoder.go @@ -0,0 +1,55 @@ +package utils + +import ( + "encoding/base32" + "fmt" + "hash/fnv" + "strings" +) + +const specialEncoderKey = "abcdefghijklmnopqrstuvwxyz123456" + +var base32Encoder = base32.NewEncoding(specialEncoderKey).WithPadding(base32.NoPadding) + +// Creates a new UniqueID that is based on the inputID and of a specified length, if the given id is longer than the +// maxLength. 
+func FixedLengthUniqueID(inputID string, maxLength int) (string, error) { + if len(inputID) <= maxLength { + return inputID, nil + } + + hasher := fnv.New32a() + _, err := hasher.Write([]byte(inputID)) + if err != nil { + return "", err + } + b := hasher.Sum(nil) + // expected length after this step is 8 chars (1 + 7 chars from base32Encoder.EncodeToString(b)) + finalStr := "f" + base32Encoder.EncodeToString(b) + if len(finalStr) > maxLength { + return finalStr, fmt.Errorf("max Length is too small, cannot create an encoded string that is so small") + } + return finalStr, nil +} + +// Creates a new uniqueID using the parts concatenated using `-` and ensures that the uniqueID is not longer than the +// maxLength. In case a simple concatenation yields a longer string, a new hashed ID is created which is always +// around 8 characters in length +func FixedLengthUniqueIDForParts(maxLength int, parts ...string) (string, error) { + b := strings.Builder{} + for i, p := range parts { + if i > 0 { + _, err := b.WriteRune('-') + if err != nil { + return "", err + } + } + + _, err := b.WriteString(p) + if err != nil { + return "", err + } + } + + return FixedLengthUniqueID(b.String(), maxLength) +} diff --git a/pkg/utils/encoder_test.go b/pkg/utils/encoder_test.go new file mode 100644 index 000000000..d4ae8aae7 --- /dev/null +++ b/pkg/utils/encoder_test.go @@ -0,0 +1,61 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFixedLengthUniqueID(t *testing.T) { + tests := []struct { + name string + input string + maxLength int + output string + expectError bool + }{ + {"smallerThanMax", "x", 5, "x", false}, + {"veryLowLimit", "xx", 1, "flfryc2i", true}, + {"highLimit", "xxxxxx", 5, "fufiti6i", true}, + {"higherLimit", "xxxxx", 10, "xxxxx", false}, + {"largeID", "xxxxxxxxxxx", 10, "fggddjly", false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + i, err := FixedLengthUniqueID(test.input, test.maxLength) + 
 if test.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, i, test.output) + }) + } +} + +func TestFixedLengthUniqueIDForParts(t *testing.T) { + tests := []struct { + name string + parts []string + maxLength int + output string + expectError bool + }{ + {"smallerThanMax", []string{"x", "y", "z"}, 10, "x-y-z", false}, + {"veryLowLimit", []string{"x", "y"}, 1, "fz2jizji", true}, + {"fittingID", []string{"x"}, 2, "x", false}, + {"highLimit", []string{"x", "y", "z"}, 4, "fxzsoqrq", true}, + {"largeID", []string{"x", "y", "z", "m", "n"}, 8, "fsigbmty", false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + i, err := FixedLengthUniqueIDForParts(test.maxLength, test.parts...) + if test.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, i, test.output) + }) + } +} diff --git a/pkg/utils/event_helpers.go b/pkg/utils/event_helpers.go new file mode 100644 index 000000000..61af9f01d --- /dev/null +++ b/pkg/utils/event_helpers.go @@ -0,0 +1,32 @@ +package utils + +import ( + "context" + + "github.com/lyft/flyteidl/clients/go/events" + eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + "github.com/lyft/flytestdlib/logger" +) + +// Construct task event recorder to pass down to plugin. This is just a wrapper around the normal +// taskEventRecorder that can encapsulate logic to validate and handle errors. 
+func NewPluginTaskEventRecorder(taskEventRecorder events.TaskEventRecorder) events.TaskEventRecorder { + return &pluginTaskEventRecorder{ + taskEventRecorder: taskEventRecorder, + } +} + +type pluginTaskEventRecorder struct { + taskEventRecorder events.TaskEventRecorder +} + +func (r pluginTaskEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent) error { + err := r.taskEventRecorder.RecordTaskEvent(ctx, event) + if err != nil && eventsErr.IsAlreadyExists(err) { + logger.Infof(ctx, "Task event phase: %s, taskId %s, retry attempt %d - already exists", + event.Phase.String(), event.GetTaskId(), event.RetryAttempt) + return nil + } + return err +} diff --git a/pkg/utils/failing_datastore.go b/pkg/utils/failing_datastore.go new file mode 100644 index 000000000..5dbff0665 --- /dev/null +++ b/pkg/utils/failing_datastore.go @@ -0,0 +1,32 @@ +package utils + +import ( + "context" + "fmt" + "io" + + "github.com/lyft/flytestdlib/storage" +) + +type FailingRawStore struct { +} + +func (FailingRawStore) CopyRaw(ctx context.Context, source, destination storage.DataReference, opts storage.Options) error { + return fmt.Errorf("failed to copy raw") +} + +func (FailingRawStore) GetBaseContainerFQN(ctx context.Context) storage.DataReference { + return "" +} + +func (FailingRawStore) Head(ctx context.Context, reference storage.DataReference) (storage.Metadata, error) { + return nil, fmt.Errorf("failed metadata fetch") +} + +func (FailingRawStore) ReadRaw(ctx context.Context, reference storage.DataReference) (io.ReadCloser, error) { + return nil, fmt.Errorf("failed read raw") +} + +func (FailingRawStore) WriteRaw(ctx context.Context, reference storage.DataReference, size int64, opts storage.Options, raw io.Reader) error { + return fmt.Errorf("failed write raw") +} diff --git a/pkg/utils/failing_datastore_test.go b/pkg/utils/failing_datastore_test.go new file mode 100644 index 000000000..9adb3e885 --- /dev/null +++ b/pkg/utils/failing_datastore_test.go @@ 
-0,0 +1,26 @@ +package utils + +import ( + "bytes" + "context" + "testing" + + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" +) + +func TestFailingRawStore(t *testing.T) { + ctx := context.TODO() + f := FailingRawStore{} + _, err := f.Head(ctx, "") + assert.Error(t, err) + + c := f.GetBaseContainerFQN(ctx) + assert.Equal(t, storage.DataReference(""), c) + + _, err = f.ReadRaw(ctx, "") + assert.Error(t, err) + + err = f.WriteRaw(ctx, "", 0, storage.Options{}, bytes.NewReader(nil)) + assert.Error(t, err) +} diff --git a/pkg/utils/helpers.go b/pkg/utils/helpers.go new file mode 100644 index 000000000..2889164fe --- /dev/null +++ b/pkg/utils/helpers.go @@ -0,0 +1,12 @@ +package utils + +func CopyMap(o map[string]string) (r map[string]string) { + if o == nil { + return nil + } + r = make(map[string]string, len(o)) + for k, v := range o { + r[k] = v + } + return +} diff --git a/pkg/utils/helpers_test.go b/pkg/utils/helpers_test.go new file mode 100644 index 000000000..a9f693ae2 --- /dev/null +++ b/pkg/utils/helpers_test.go @@ -0,0 +1,19 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCopyMap(t *testing.T) { + m := map[string]string{ + "k1": "v1", + "k2": "v2", + } + co := CopyMap(m) + assert.NotNil(t, co) + assert.Equal(t, m, co) + + assert.Nil(t, CopyMap(nil)) +} diff --git a/pkg/utils/k8s.go b/pkg/utils/k8s.go new file mode 100644 index 000000000..5d45ac868 --- /dev/null +++ b/pkg/utils/k8s.go @@ -0,0 +1,105 @@ +package utils + +import ( + "github.com/golang/protobuf/ptypes" + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/pkg/errors" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +var NotTheOwnerError = errors.Errorf("FlytePropeller is not the owner") + +// ResourceNvidiaGPU is the name 
of the Nvidia GPU resource. +const ResourceNvidiaGPU = "nvidia.com/gpu" + +func ToK8sEnvVar(env []*core.KeyValuePair) []v1.EnvVar { + envVars := make([]v1.EnvVar, 0, len(env)) + for _, kv := range env { + envVars = append(envVars, v1.EnvVar{Name: kv.Key, Value: kv.Value}) + } + return envVars +} + +// TODO we should modify the container resources to contain a map of enum values? +// Also we should probably create tolerations / taints, but we could do that as a post process +func ToK8sResourceList(resources []*core.Resources_ResourceEntry) (v1.ResourceList, error) { + k8sResources := make(v1.ResourceList, len(resources)) + for _, r := range resources { + rVal := r.Value + v, err := resource.ParseQuantity(rVal) + if err != nil { + return nil, errors.Wrap(err, "Failed to parse resource as a valid quantity.") + } + switch r.Name { + case core.Resources_CPU: + if !v.IsZero() { + k8sResources[v1.ResourceCPU] = v + } + case core.Resources_MEMORY: + if !v.IsZero() { + k8sResources[v1.ResourceMemory] = v + } + case core.Resources_STORAGE: + if !v.IsZero() { + k8sResources[v1.ResourceStorage] = v + } + case core.Resources_GPU: + if !v.IsZero() { + k8sResources[ResourceNvidiaGPU] = v + } + } + } + return k8sResources, nil +} + +func ToK8sResourceRequirements(resources *core.Resources) (*v1.ResourceRequirements, error) { + res := &v1.ResourceRequirements{} + if resources == nil { + return res, nil + } + req, err := ToK8sResourceList(resources.Requests) + if err != nil { + return res, err + } + lim, err := ToK8sResourceList(resources.Limits) + if err != nil { + return res, err + } + res.Limits = lim + res.Requests = req + return res, nil +} + +func GetWorkflowIDFromObject(obj metav1.Object) (v1alpha1.WorkflowID, error) { + controller := metav1.GetControllerOf(obj) + if controller == nil { + return "", NotTheOwnerError + } + if controller.Kind == v1alpha1.FlyteWorkflowKind { + return obj.GetNamespace() + "/" + controller.Name, nil + } + return "", NotTheOwnerError +} + +func 
GetWorkflowIDFromOwner(reference *metav1.OwnerReference, namespace string) (v1alpha1.WorkflowID, error) { + if reference == nil { + return "", NotTheOwnerError + } + if reference.Kind == v1alpha1.FlyteWorkflowKind { + return namespace + "/" + reference.Name, nil + } + return "", NotTheOwnerError +} +func GetProtoTime(t *metav1.Time) *timestamp.Timestamp { + if t != nil { + pTime, err := ptypes.TimestampProto(t.Time) + if err == nil { + return pTime + } + } + return ptypes.TimestampNow() +} diff --git a/pkg/utils/k8s_test.go b/pkg/utils/k8s_test.go new file mode 100644 index 000000000..806d4dd49 --- /dev/null +++ b/pkg/utils/k8s_test.go @@ -0,0 +1,194 @@ +package utils + +import ( + "testing" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" + v12 "k8s.io/api/batch/v1" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v13 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestToK8sEnvVar(t *testing.T) { + e := ToK8sEnvVar([]*core.KeyValuePair{ + {Key: "k1", Value: "v1"}, + {Key: "k2", Value: "v2"}, + }) + + assert.NotEmpty(t, e) + assert.Equal(t, []v1.EnvVar{ + {Name: "k1", Value: "v1"}, + {Name: "k2", Value: "v2"}, + }, e) + + e = ToK8sEnvVar(nil) + assert.Empty(t, e) +} + +func TestToK8sResourceList(t *testing.T) { + { + r, err := ToK8sResourceList([]*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_GPU, Value: "1"}, + {Name: core.Resources_MEMORY, Value: "1024Mi"}, + {Name: core.Resources_STORAGE, Value: "1024Mi"}, + }) + + assert.NoError(t, err) + assert.NotEmpty(t, r) + assert.NotNil(t, r[v1.ResourceCPU]) + assert.Equal(t, resource.MustParse("250m"), r[v1.ResourceCPU]) + assert.Equal(t, resource.MustParse("1"), r[ResourceNvidiaGPU]) + assert.Equal(t, resource.MustParse("1024Mi"), 
r[v1.ResourceMemory]) + assert.Equal(t, resource.MustParse("1024Mi"), r[v1.ResourceStorage]) + } + { + r, err := ToK8sResourceList([]*core.Resources_ResourceEntry{}) + assert.NoError(t, err) + assert.Empty(t, r) + } + { + _, err := ToK8sResourceList([]*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250x"}, + }) + assert.Error(t, err) + } + +} + +func TestToK8sResourceRequirements(t *testing.T) { + + { + r, err := ToK8sResourceRequirements(nil) + assert.NoError(t, err) + assert.NotNil(t, r) + assert.Empty(t, r.Limits) + assert.Empty(t, r.Requests) + } + { + r, err := ToK8sResourceRequirements(&core.Resources{ + Requests: nil, + Limits: nil, + }) + assert.NoError(t, err) + assert.NotNil(t, r) + assert.Empty(t, r.Limits) + assert.Empty(t, r.Requests) + } + { + r, err := ToK8sResourceRequirements(&core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, r) + assert.Equal(t, resource.MustParse("250m"), r.Requests[v1.ResourceCPU]) + assert.Equal(t, resource.MustParse("1024m"), r.Limits[v1.ResourceCPU]) + } + { + _, err := ToK8sResourceRequirements(&core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "blah"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + }) + assert.Error(t, err) + } + { + _, err := ToK8sResourceRequirements(&core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "blah"}, + }, + }) + assert.Error(t, err) + } +} + +func TestGetWorkflowIDFromObject(t *testing.T) { + { + b := true + j := &v12.Job{ + ObjectMeta: v13.ObjectMeta{ + Namespace: "ns", + OwnerReferences: []v13.OwnerReference{ + { + APIVersion: "test", + 
Kind: v1alpha1.FlyteWorkflowKind, + Name: "my-id", + UID: "blah", + BlockOwnerDeletion: &b, + Controller: &b, + }, + }, + }, + } + w, err := GetWorkflowIDFromObject(j) + assert.NoError(t, err) + assert.Equal(t, "ns/my-id", w) + } + { + b := true + j := &v12.Job{ + ObjectMeta: v13.ObjectMeta{ + Namespace: "ns", + OwnerReferences: []v13.OwnerReference{ + { + APIVersion: "test", + Kind: "some-other", + Name: "my-id", + UID: "blah", + BlockOwnerDeletion: &b, + Controller: &b, + }, + }, + }, + } + _, err := GetWorkflowIDFromObject(j) + assert.Error(t, err) + } + +} + +func TestGetProtoTime(t *testing.T) { + assert.NotNil(t, GetProtoTime(nil)) + n := time.Now() + nproto, err := ptypes.TimestampProto(n) + assert.NoError(t, err) + assert.Equal(t, nproto, GetProtoTime(&metav1.Time{Time: n})) +} + +func TestGetWorkflowIDFromOwner(t *testing.T) { + tests := []struct { + name string + reference *metav1.OwnerReference + namespace string + expectedOwner string + expectedErr error + }{ + {"nilReference", nil, "", "", NotTheOwnerError}, + {"badReference", &metav1.OwnerReference{Kind: "x"}, "", "", NotTheOwnerError}, + {"wfReference", &metav1.OwnerReference{Kind: v1alpha1.FlyteWorkflowKind, Name: "x"}, "ns", "ns/x", nil}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + o, e := GetWorkflowIDFromOwner(test.reference, test.namespace) + assert.Equal(t, test.expectedOwner, o) + assert.Equal(t, test.expectedErr, e) + }) + } +} diff --git a/pkg/utils/literals.go b/pkg/utils/literals.go new file mode 100644 index 000000000..d882a3f14 --- /dev/null +++ b/pkg/utils/literals.go @@ -0,0 +1,276 @@ +package utils + +import ( + "reflect" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/pkg/errors" +) + +func MakePrimitive(v interface{}) (*core.Primitive, error) { + switch p := v.(type) { + case int: + return &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: int64(p), + }, + }, nil + case 
int64: + return &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: p, + }, + }, nil + case float64: + return &core.Primitive{ + Value: &core.Primitive_FloatValue{ + FloatValue: p, + }, + }, nil + case time.Time: + t, err := ptypes.TimestampProto(p) + if err != nil { + return nil, err + } + return &core.Primitive{ + Value: &core.Primitive_Datetime{ + Datetime: t, + }, + }, nil + case time.Duration: + d := ptypes.DurationProto(p) + return &core.Primitive{ + Value: &core.Primitive_Duration{ + Duration: d, + }, + }, nil + case string: + return &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: p, + }, + }, nil + case bool: + return &core.Primitive{ + Value: &core.Primitive_Boolean{ + Boolean: p, + }, + }, nil + } + return nil, errors.Errorf("Failed to convert to a known primitive type. Input Type [%v] not supported", reflect.TypeOf(v).String()) +} + +func MustMakePrimitive(v interface{}) *core.Primitive { + f, err := MakePrimitive(v) + if err != nil { + panic(err) + } + return f +} + +func MakePrimitiveLiteral(v interface{}) (*core.Literal, error) { + p, err := MakePrimitive(v) + if err != nil { + return nil, err + } + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: p, + }, + }, + }, + }, nil +} + +func MustMakePrimitiveLiteral(v interface{}) *core.Literal { + p, err := MakePrimitiveLiteral(v) + if err != nil { + panic(err) + } + return p +} + +func MakeLiteralForMap(v map[string]interface{}) (*core.Literal, error) { + m, err := MakeLiteralMap(v) + if err != nil { + return nil, err + } + + return &core.Literal{ + Value: &core.Literal_Map{ + Map: m, + }, + }, nil +} + +func MakeLiteralForCollection(v []interface{}) (*core.Literal, error) { + literals := make([]*core.Literal, 0, len(v)) + for _, val := range v { + l, err := MakeLiteral(val) + if err != nil { + return nil, err + } + + literals = append(literals, l) + } + + return &core.Literal{ + Value: 
&core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: literals, + }, + }, + }, nil +} + +func MakeBinaryLiteral(v []byte) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: v, + }, + }, + }, + }, + } +} + +func MakeLiteral(v interface{}) (*core.Literal, error) { + if v == nil { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_NoneType{ + NoneType: nil, + }, + }, + }, + }, nil + } + switch o := v.(type) { + case *core.Literal: + return o, nil + case []interface{}: + return MakeLiteralForCollection(o) + case map[string]interface{}: + return MakeLiteralForMap(o) + case []byte: + return MakeBinaryLiteral(v.([]byte)), nil + default: + return MakePrimitiveLiteral(o) + } +} + +func MustMakeDefaultLiteralForType(typ *core.LiteralType) *core.Literal { + if res, err := MakeDefaultLiteralForType(typ); err != nil { + panic(err) + } else { + return res + } +} + +func MakeDefaultLiteralForType(typ *core.LiteralType) (*core.Literal, error) { + switch t := typ.GetType().(type) { + case *core.LiteralType_Simple: + switch t.Simple { + case core.SimpleType_NONE: + return MakeLiteral(nil) + case core.SimpleType_INTEGER: + return MakeLiteral(int(0)) + case core.SimpleType_FLOAT: + return MakeLiteral(float64(0)) + case core.SimpleType_STRING: + return MakeLiteral("") + case core.SimpleType_BOOLEAN: + return MakeLiteral(false) + case core.SimpleType_DATETIME: + return MakeLiteral(time.Now()) + case core.SimpleType_DURATION: + return MakeLiteral(time.Second) + case core.SimpleType_BINARY: + return MakeLiteral([]byte{}) + //case core.SimpleType_WAITABLE: + //case core.SimpleType_ERROR: + } + return nil, errors.Errorf("Not yet implemented. Default creation is not yet implemented. 
") + + case *core.LiteralType_Blob: + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Blob{ + Blob: &core.Blob{ + Metadata: &core.BlobMetadata{ + Type: t.Blob, + }, + Uri: "/tmp/somepath", + }, + }, + }, + }, + }, nil + case *core.LiteralType_CollectionType: + single, err := MakeDefaultLiteralForType(t.CollectionType) + if err != nil { + return nil, err + } + + return &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{single}, + }, + }, + }, nil + case *core.LiteralType_MapValueType: + single, err := MakeDefaultLiteralForType(t.MapValueType) + if err != nil { + return nil, err + } + + return &core.Literal{ + Value: &core.Literal_Map{ + Map: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "itemKey": single, + }, + }, + }, + }, nil + //case *core.LiteralType_Schema: + } + + return nil, errors.Errorf("Failed to convert to a known Literal. Input Type [%v] not supported", typ.String()) +} + +func MustMakeLiteral(v interface{}) *core.Literal { + p, err := MakeLiteral(v) + if err != nil { + panic(err) + } + + return p +} + +func MakeLiteralMap(v map[string]interface{}) (*core.LiteralMap, error) { + + literals := make(map[string]*core.Literal, len(v)) + for key, val := range v { + l, err := MakeLiteral(val) + if err != nil { + return nil, err + } + + literals[key] = l + } + + return &core.LiteralMap{ + Literals: literals, + }, nil +} diff --git a/pkg/utils/literals_test.go b/pkg/utils/literals_test.go new file mode 100644 index 000000000..05dd11ff7 --- /dev/null +++ b/pkg/utils/literals_test.go @@ -0,0 +1,200 @@ +package utils + +import ( + "reflect" + "testing" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestMakePrimitive(t *testing.T) { + { + v := 1 + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, 
"*core.Primitive_Integer", reflect.TypeOf(p.Value).String()) + assert.Equal(t, int64(v), p.GetInteger()) + } + { + v := int64(1) + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.Value).String()) + assert.Equal(t, v, p.GetInteger()) + } + { + v := 1.0 + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.Value).String()) + assert.Equal(t, v, p.GetFloatValue()) + } + { + v := "blah" + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_StringValue", reflect.TypeOf(p.Value).String()) + assert.Equal(t, v, p.GetStringValue()) + } + { + v := true + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_Boolean", reflect.TypeOf(p.Value).String()) + assert.Equal(t, v, p.GetBoolean()) + } + { + v := time.Now() + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_Datetime", reflect.TypeOf(p.Value).String()) + j, err := ptypes.TimestampProto(v) + assert.NoError(t, err) + assert.Equal(t, j, p.GetDatetime()) + } + { + v := time.Second * 10 + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_Duration", reflect.TypeOf(p.Value).String()) + assert.Equal(t, ptypes.DurationProto(v), p.GetDuration()) + } + { + v := struct { + }{} + _, err := MakePrimitive(v) + assert.Error(t, err) + } +} + +func TestMustMakePrimitive(t *testing.T) { + { + v := struct { + }{} + assert.Panics(t, func() { + MustMakePrimitive(v) + }) + } + { + v := time.Second * 10 + p := MustMakePrimitive(v) + assert.Equal(t, "*core.Primitive_Duration", reflect.TypeOf(p.Value).String()) + assert.Equal(t, ptypes.DurationProto(v), p.GetDuration()) + } +} + +func TestMakePrimitiveLiteral(t *testing.T) { + { + v := 1.0 + p, err := MakePrimitiveLiteral(v) + assert.NoError(t, err) + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, 
"*core.Primitive_FloatValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v, p.GetScalar().GetPrimitive().GetFloatValue()) + } + { + v := struct { + }{} + _, err := MakePrimitiveLiteral(v) + assert.Error(t, err) + } +} + +func TestMustMakePrimitiveLiteral(t *testing.T) { + t.Run("Panic", func(t *testing.T) { + v := struct { + }{} + assert.Panics(t, func() { + MustMakePrimitiveLiteral(v) + }) + }) + t.Run("FloatValue", func(t *testing.T) { + v := 1.0 + p := MustMakePrimitiveLiteral(v) + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v, p.GetScalar().GetPrimitive().GetFloatValue()) + }) +} + +func TestMakeLiteral(t *testing.T) { + t.Run("Primitive", func(t *testing.T) { + lit, err := MakeLiteral("test_string") + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_StringValue", reflect.TypeOf(lit.GetScalar().GetPrimitive().Value).String()) + }) + + t.Run("Array", func(t *testing.T) { + lit, err := MakeLiteral([]interface{}{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, "*core.Literal_Collection", reflect.TypeOf(lit.GetValue()).String()) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(lit.GetCollection().Literals[0].GetScalar().GetPrimitive().Value).String()) + }) + + t.Run("Map", func(t *testing.T) { + lit, err := MakeLiteral(map[string]interface{}{ + "key1": []interface{}{1, 2, 3}, + "key2": []interface{}{5}, + }) + assert.NoError(t, err) + assert.Equal(t, "*core.Literal_Map", reflect.TypeOf(lit.GetValue()).String()) + assert.Equal(t, "*core.Literal_Collection", reflect.TypeOf(lit.GetMap().Literals["key1"].GetValue()).String()) + }) + + t.Run("Binary", func(t *testing.T) { + s := MakeBinaryLiteral([]byte{'h'}) + assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) + }) + + t.Run("NoneType", func(t *testing.T) { + p, err := MakeLiteral(nil) + assert.NoError(t, err) + assert.NotNil(t, 
p.GetScalar()) + assert.Equal(t, "*core.Scalar_NoneType", reflect.TypeOf(p.GetScalar().Value).String()) + }) +} + +func TestMustMakeLiteral(t *testing.T) { + v := "hello" + l := MustMakeLiteral(v) + assert.NotNil(t, l.GetScalar()) + assert.Equal(t, v, l.GetScalar().GetPrimitive().GetStringValue()) +} + +func TestMakeBinaryLiteral(t *testing.T) { + s := MakeBinaryLiteral([]byte{'h'}) + assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) +} + +func TestMakeDefaultLiteralForType(t *testing.T) { + + tests := [][]interface{}{ + {"Integer", core.SimpleType_INTEGER, "*core.Primitive_Integer"}, + {"Float", core.SimpleType_FLOAT, "*core.Primitive_FloatValue"}, + {"String", core.SimpleType_STRING, "*core.Primitive_StringValue"}, + {"Boolean", core.SimpleType_BOOLEAN, "*core.Primitive_Boolean"}, + {"Duration", core.SimpleType_DURATION, "*core.Primitive_Duration"}, + {"Datetime", core.SimpleType_DATETIME, "*core.Primitive_Datetime"}, + } + + for i := range tests { + name := tests[i][0].(string) + ty := tests[i][1].(core.SimpleType) + tyName := tests[i][2].(string) + + t.Run(name, func(t *testing.T) { + l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{Simple: ty}}) + assert.NoError(t, err) + assert.Equal(t, tyName, reflect.TypeOf(l.GetScalar().GetPrimitive().Value).String()) + }) + } + + t.Run("Binary", func(t *testing.T) { + s, err := MakeLiteral([]byte{'h'}) + assert.NoError(t, err) + assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) + }) +} diff --git a/pkg/visualize/nodeq.go b/pkg/visualize/nodeq.go new file mode 100644 index 000000000..5b479a079 --- /dev/null +++ b/pkg/visualize/nodeq.go @@ -0,0 +1,35 @@ +package visualize + +import "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +type NodeQ []v1alpha1.NodeID + +func (s *NodeQ) Enqueue(items ...v1alpha1.NodeID) { + *s = append(*s, items...) 
+} + +func (s NodeQ) HasNext() bool { + return len(s) > 0 +} + +func (s NodeQ) Remaining() int { + return len(s) +} + +func (s *NodeQ) Peek() v1alpha1.NodeID { + if s.HasNext() { + return (*s)[0] + } + + return "" +} + +func (s *NodeQ) Deque() v1alpha1.NodeID { + item := s.Peek() + *s = (*s)[1:] + return item +} + +func NewNodeNameQ(items ...v1alpha1.NodeID) NodeQ { + return NodeQ(items) +} diff --git a/pkg/visualize/sort.go b/pkg/visualize/sort.go new file mode 100644 index 000000000..7b4f62559 --- /dev/null +++ b/pkg/visualize/sort.go @@ -0,0 +1,72 @@ +package visualize + +import ( + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/pkg/errors" +) + +type VisitStatus int8 + +const ( + NotVisited VisitStatus = iota + Visited + Completed +) + +type NodeVisitor map[v1alpha1.NodeID]VisitStatus + +func NewNodeVisitor(nodes []v1alpha1.NodeID) NodeVisitor { + v := make(NodeVisitor, len(nodes)) + for _, n := range nodes { + v[n] = NotVisited + } + return v +} + +func tsortHelper(g v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, visited NodeVisitor, reverseSortedNodes *[]v1alpha1.ExecutableNode) error { + if visited[currentNode.GetID()] == NotVisited { + visited[currentNode.GetID()] = Visited + defer func() { + visited[currentNode.GetID()] = Completed + }() + nodes, err := g.FromNode(currentNode.GetID()) + if err != nil { + return err + } + for _, childID := range nodes { + child, ok := g.GetNode(childID) + if !ok { + return errors.Errorf("Unable to find Node [%s] in Workflow [%s]", childID, g.GetID()) + } + if err := tsortHelper(g, child, visited, reverseSortedNodes); err != nil { + return err + } + } + + *reverseSortedNodes = append(*reverseSortedNodes, currentNode) + return nil + } + // Node was successfully visited previously + if visited[currentNode.GetID()] == Completed { + return nil + } + // Node was partially visited and we are in the subgraph and that reached back to the parent + return errors.Errorf("Cycle 
detected. Node [%v]", currentNode.GetID()) +} + +func reverseSlice(sl []v1alpha1.ExecutableNode) []v1alpha1.ExecutableNode { + for i := len(sl)/2 - 1; i >= 0; i-- { + opp := len(sl) - 1 - i + sl[i], sl[opp] = sl[opp], sl[i] + } + return sl +} + +func TopologicalSort(g v1alpha1.ExecutableWorkflow) ([]v1alpha1.ExecutableNode, error) { + reverseSortedNodes := make([]v1alpha1.ExecutableNode, 0, 25) + visited := NewNodeVisitor(g.GetNodes()) + if err := tsortHelper(g, g.StartNode(), visited, &reverseSortedNodes); err != nil { + return nil, err + } + return reverseSlice(reverseSortedNodes), nil +} diff --git a/pkg/visualize/visualize.go b/pkg/visualize/visualize.go new file mode 100644 index 000000000..358f6e5cb --- /dev/null +++ b/pkg/visualize/visualize.go @@ -0,0 +1,244 @@ +package visualize + +import ( + "fmt" + "reflect" + "strings" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "k8s.io/apimachinery/pkg/util/sets" +) + +const executionEdgeLabel = "execution" + +type edgeStyle = string + +const ( + styleSolid edgeStyle = "solid" + styleDashed edgeStyle = "dashed" +) + +const staticNodeID = "static" + +func flatten(binding *core.BindingData, flatMap map[common.NodeID]sets.String) { + switch binding.GetValue().(type) { + case *core.BindingData_Collection: + for _, v := range binding.GetCollection().GetBindings() { + flatten(v, flatMap) + } + case *core.BindingData_Map: + for _, v := range binding.GetMap().GetBindings() { + flatten(v, flatMap) + } + case *core.BindingData_Promise: + if _, ok := flatMap[binding.GetPromise().NodeId]; !ok { + flatMap[binding.GetPromise().NodeId] = sets.String{} + } + + flatMap[binding.GetPromise().NodeId].Insert(binding.GetPromise().GetVar()) + case *core.BindingData_Scalar: + if _, ok := flatMap[staticNodeID]; !ok { + flatMap[staticNodeID] = sets.NewString() + } + } +} + +// Returns GraphViz 
https://www.graphviz.org/ representation of the current state of the state machine. +func WorkflowToGraphViz(g *v1alpha1.FlyteWorkflow) string { + res := fmt.Sprintf("digraph G {rankdir=TB;workflow[label=\"Workflow Id: %v\"];node[style=filled];", + g.ID) + + nodeFinder := func(nodeId common.NodeID) *v1alpha1.NodeSpec { + for _, n := range g.Nodes { + if n.ID == nodeId { + return n + } + } + + return nil + } + + nodeLabel := func(nodeId common.NodeID) string { + node := nodeFinder(nodeId) + return fmt.Sprintf("%v(%v)", node.ID, node.Kind) + } + + edgeLabel := func(nodeFromId, nodeToId common.NodeID) string { + flatMap := make(map[common.NodeID]sets.String) + nodeFrom := nodeFinder(nodeFromId) + nodeTo := nodeFinder(nodeToId) + for _, binding := range nodeTo.GetInputBindings() { + flatten(binding.GetBinding(), flatMap) + } + + if vars, found := flatMap[nodeFrom.ID]; found { + return strings.Join(vars.List(), ",") + } else if vars, found := flatMap[""]; found && nodeFromId == common.StartNodeID { + return strings.Join(vars.List(), ",") + } else { + return executionEdgeLabel + } + } + + style := func(edgeLabel string) string { + if edgeLabel == executionEdgeLabel { + return styleDashed + } + + return styleSolid + } + + start := nodeFinder(common.StartNodeID) + res += fmt.Sprintf("\"%v\" [shape=Msquare];", nodeLabel(start.ID)) + visitedNodes := sets.NewString(start.ID) + createdEdges := sets.NewString() + + for nodesToVisit := NewNodeNameQ(start.ID); nodesToVisit.HasNext(); { + node := nodesToVisit.Deque() + nodes, found := g.GetConnections().DownstreamEdges[node] + if found { + nodesToVisit.Enqueue(nodes...) 
+ + for _, child := range nodes { + label := edgeLabel(node, child) + edge := fmt.Sprintf("\"%v\" -> \"%v\" [label=\"%v\",style=\"%v\"];", + nodeLabel(node), + nodeLabel(child), + label, + style(label), + ) + + if !createdEdges.Has(edge) { + res += edge + createdEdges.Insert(edge) + } + } + } + + // add static bindings' links + flatMap := make(common.StringAdjacencyList) + n := nodeFinder(node) + for _, binding := range n.GetInputBindings() { + flatten(binding.GetBinding(), flatMap) + } + + if vars, found := flatMap[staticNodeID]; found { + res += fmt.Sprintf("\"static\" -> \"%v\" [label=\"%v\"];", + nodeLabel(node), + strings.Join(vars.List(), ","), + ) + } + + visitedNodes.Insert(node) + } + + res += "}" + + return res +} + +func ToGraphViz(g *core.CompiledWorkflow) string { + res := fmt.Sprintf("digraph G {rankdir=TB;workflow[label=\"Workflow Id: %v\"];node[style=filled];", + g.Template.GetId()) + + nodeFinder := func(nodeId common.NodeID) *core.Node { + for _, n := range g.Template.Nodes { + if n.Id == nodeId { + return n + } + } + + return nil + } + + nodeKind := func(nodeId common.NodeID) string { + node := nodeFinder(nodeId) + if nodeId == common.StartNodeID { + return "start" + } else if nodeId == common.EndNodeID { + return "end" + } else { + return reflect.TypeOf(node.GetTarget()).Name() + } + } + + nodeLabel := func(nodeId common.NodeID) string { + node := nodeFinder(nodeId) + return fmt.Sprintf("%v(%v)", node.GetId(), nodeKind(nodeId)) + } + + edgeLabel := func(nodeFromId, nodeToId common.NodeID) string { + flatMap := make(map[common.NodeID]sets.String) + nodeFrom := nodeFinder(nodeFromId) + nodeTo := nodeFinder(nodeToId) + for _, binding := range nodeTo.GetInputs() { + flatten(binding.GetBinding(), flatMap) + } + + if vars, found := flatMap[nodeFrom.GetId()]; found { + return strings.Join(vars.List(), ",") + } else if vars, found := flatMap[""]; found && nodeFromId == common.StartNodeID { + return strings.Join(vars.List(), ",") + } else { + return 
executionEdgeLabel + } + } + + style := func(edgeLabel string) string { + if edgeLabel == executionEdgeLabel { + return styleDashed + } + + return styleSolid + } + + start := nodeFinder(common.StartNodeID) + res += fmt.Sprintf("\"%v\" [shape=Msquare];", nodeLabel(start.GetId())) + visitedNodes := sets.NewString(start.GetId()) + createdEdges := sets.NewString() + + for nodesToVisit := NewNodeNameQ(start.GetId()); nodesToVisit.HasNext(); { + node := nodesToVisit.Deque() + nodes, found := g.GetConnections().GetDownstream()[node] + if found { + nodesToVisit.Enqueue(nodes.Ids...) + + for _, child := range nodes.Ids { + label := edgeLabel(node, child) + edge := fmt.Sprintf("\"%v\" -> \"%v\" [label=\"%v\",style=\"%v\"];", + nodeLabel(node), + nodeLabel(child), + label, + style(label), + ) + + if !createdEdges.Has(edge) { + res += edge + createdEdges.Insert(edge) + } + } + } + + // add static bindings' links + flatMap := make(common.StringAdjacencyList) + n := nodeFinder(node) + for _, binding := range n.GetInputs() { + flatten(binding.GetBinding(), flatMap) + } + + if vars, found := flatMap[staticNodeID]; found { + res += fmt.Sprintf("\"static\" -> \"%v\" [label=\"%v\"];", + nodeLabel(node), + strings.Join(vars.List(), ","), + ) + } + + visitedNodes.Insert(node) + } + + res += "}" + + return res +} diff --git a/raw_examples/README.md b/raw_examples/README.md new file mode 100644 index 000000000..dad5d2fbc --- /dev/null +++ b/raw_examples/README.md @@ -0,0 +1,3 @@ +The intention of these examples is to test basic functionality of propeller. 
+Usually users should be using the flyteidl interface to interact with propeller +through flytectl diff --git a/raw_examples/example-condition.yaml b/raw_examples/example-condition.yaml new file mode 100644 index 000000000..3af1f14b1 --- /dev/null +++ b/raw_examples/example-condition.yaml @@ -0,0 +1,104 @@ +apiVersion: flyte.lyft.com/v1alpha1 +kind: FlyteWorkflow +metadata: + name: test-branch + namespace: default +tasks: + foo: + id: foo + category: 0 + type: container + metadata: + runtime: + type: 0 + version: "1.18.0" + flavor: python + discoverable: true + interface: + inputs: + - name: x + type: + simple: INTEGER + outputs: + - name: "y" + type: + simple: INTEGER + container: + image: alpine + command: ["echo", "Hello", "{{$input}}", "{{$output}}"] +spec: + id: test-branch + nodes: + start: + id: start + kind: start + input_bindings: + - var: x + binding: + scalar: + primitive: + integer: 5 + - var: "y" + binding: + scalar: + primitive: + integer: 10 + foo1: + id: foo1 + kind: task + input_bindings: + - var: x + binding: + promise: + node_id: start + var: "y" + resources: + requests: + cpu: 250m + limits: + cpu: 250m + task_ref: foo + foo2: + id: foo2 + kind: task + input_bindings: + - var: x + binding: + promise: + node_id: start + var: x + resources: + requests: + cpu: 250m + limits: + cpu: 250m + task_ref: foo + foobranch: + id: foobranch + kind: branch + input_bindings: + - var: x + binding: + promise: + node_id: start + var: x + branch_node: + if: + condition: + comparison: + left_value: + var: x + operator: GT + right_value: + primitive: + integer: 5 + then: foo1 + else: foo2 + connections: + start: + - foobranch + - foo1 + - foo2 + foobranch: + - foo1 + - foo2 diff --git a/raw_examples/example-inputs.yaml b/raw_examples/example-inputs.yaml new file mode 100644 index 000000000..b3a641cc7 --- /dev/null +++ b/raw_examples/example-inputs.yaml @@ -0,0 +1,61 @@ +apiVersion: flyte.lyft.com/v1alpha1 +kind: FlyteWorkflow +metadata: + name: test-wf-inputs + 
namespace: default +tasks: + foo: + id: foo + category: 0 + type: container + metadata: + name: foo + runtime: + # Enums are ints + type: 0 + version: "1.18.0" + flavor: python + discoverable: true + description: "Test Task" + interface: + inputs: + - name: x + type: + simple: INTEGER + outputs: + - name: "y" + type: + simple: INTEGER + container: + image: alpine + command: ["echo", "Hello", "{{$input}}", "{{$output}}"] +spec: + id: test-wf + nodes: + start: + id: start + kind: start + input_bindings: + - var: x + binding: + scalar: + primitive: + integer: 5 + foo1: + id: foo1 + kind: task + resources: + requests: + cpu: 250m + limits: + cpu: 250m + task_ref: foo + input_bindings: + - var: x + binding: + promise: + node_id: start + var: x + connections: + start: + - foo1 diff --git a/raw_examples/example-noinputs.yaml b/raw_examples/example-noinputs.yaml new file mode 100644 index 000000000..719b582dc --- /dev/null +++ b/raw_examples/example-noinputs.yaml @@ -0,0 +1,41 @@ +apiVersion: flyte.lyft.com/v1alpha1 +kind: FlyteWorkflow +metadata: + name: test-fg + namespace: default +tasks: + foo: + id: foo + # Enums are ints + category: 0 + type: container + metadata: + name: foo + runtime: + # Enums are ints + type: 0 + version: "1.18.0" + flavor: python + discoverable: true + description: "Test Task" + container: + image: alpine + command: ["ls", "${{inputs}}", "${{outputs}}"] +spec: + id: test-wf + nodes: + start: + id: start + kind: start + foo1: + id: foo1 + kind: task + resources: + requests: + cpu: 250m + limits: + cpu: 250m + task_ref: foo + connections: + start: + - foo1