diff --git a/.circleci/config.yml b/.circleci/config.yml
index 1964d79609..90298dbca9 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -70,7 +70,7 @@ executors:
parameters:
toolchain_version:
type: string
- default: '{{ checksum ".circleci/config.yml" }}-{{ checksum "~/.arch" }}-{{ checksum "rust-toolchain.toml" }}-{{ checksum "~/.daily_version" }}'
+ default: '{{ checksum ".circleci/config.yml" }}-v2-{{ checksum "~/.arch" }}-{{ checksum "rust-toolchain.toml" }}-{{ checksum "~/.daily_version" }}'
xtask_version:
type: string
default: '{{ checksum ".circleci/config.yml" }}-{{ checksum "~/.arch" }}-{{ checksum "rust-toolchain.toml" }}-{{ checksum "~/.xtask_version" }}'
@@ -91,7 +91,7 @@ parameters:
# forks of the project to run their own tests on their own CircleCI deployments with no
# additional configuration.
common_job_environment: &common_job_environment
- CARGO_NET_GIT_FETCH_WITH_CLI: true
+ CARGO_NET_GIT_FETCH_WITH_CLI: "true"
RUST_BACKTRACE: full
CARGO_INCREMENTAL: 0
commands:
@@ -107,7 +107,7 @@ commands:
- restore_cache:
keys:
- "<< pipeline.parameters.toolchain_version >>"
- - install_debian_packages:
+ - install_packages:
platform: << parameters.platform >>
- install_protoc:
platform: << parameters.platform >>
@@ -205,7 +205,7 @@ commands:
echo "${CIRCLE_PROJECT_REPONAME}-${COMMON_ANCESTOR_REF}" > ~/.merge_version
# Linux specific step to install packages that are needed
- install_debian_packages:
+ install_packages:
parameters:
platform:
type: executor
@@ -222,11 +222,10 @@ commands:
name: Update and install dependencies
command: |
if [[ ! -d "$HOME/.deb" ]]; then
- mkdir ~/.deb
+ mkdir $HOME/.deb
sudo apt-get --download-only -o Dir::Cache="$HOME/.deb" -o Dir::Cache::archives="$HOME/.deb" install libssl-dev libdw-dev cmake
fi
- sudo dpkg -i ~/.deb/*.deb
-
+ sudo dpkg -i $HOME/.deb/*.deb
install_protoc:
parameters:
platform:
@@ -762,4 +761,4 @@ workflows:
branches:
ignore: /.*/
tags:
- only: /v.*/
+ only: /v.*/
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 23ac088772..fd79245f9b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,165 @@ All notable changes to Router will be documented in this file.
This project adheres to [Semantic Versioning v2.0.0](https://semver.org/spec/v2.0.0.html).
+# [1.22.0] - 2023-06-21
+
+## 🚀 Features
+
+### Federated Subscriptions ([PR #3285](https://github.com/apollographql/router/pull/3285))
+
+> ⚠️ **This is an [Enterprise feature](https://www.apollographql.com/blog/platform/evaluating-apollo-router-understanding-free-and-open-vs-commercial-features/) of the Apollo Router.** It requires an organization with a [GraphOS Enterprise plan](https://www.apollographql.com/pricing/).
+>
+> If your organization _doesn't_ currently have an Enterprise plan, you can test out this functionality by signing up for a free [Enterprise trial](https://www.apollographql.com/docs/graphos/org/plans/#enterprise-trials).
+
+
+#### High-Level Overview
+
+##### What are Federated Subscriptions?
+
+This PR adds GraphQL subscription support to the Router for use with Federation. Clients can now use GraphQL subscriptions with the Router to receive realtime updates from a supergraph. With these changes, `subscription` operations are now a first-class supported feature of the Router and Federation, alongside queries and mutations.
+
+```mermaid
+flowchart LR;
+ client(Client);
+ subgraph "Your infrastructure";
+ router(["Apollo Router"]);
+ subgraphA[Products subgraph];
+ subgraphB[Reviews subgraph];
+ router---|Subscribes over WebSocket|subgraphA;
+ router-.-|Can query for entity fields|subgraphB;
+ end;
+ client---|Subscribes over HTTP|router;
+ class client secondary;
+```
+
+##### Client to Router Communication
+
+- Apollo has designed and implemented a new open protocol for handling subscriptions called [multipart subscriptions](https://github.com/apollographql/router/blob/dev/dev-docs/multipart-subscriptions-protocol.md)
+- With this new protocol, clients can manage subscriptions with the Router over tried and true HTTP; WebSockets, SSE (server-sent events), etc. are not needed
+- All Apollo clients ([Apollo Client web](https://www.apollographql.com/docs/react/data/subscriptions), [Apollo Kotlin](https://www.apollographql.com/docs/kotlin/essentials/subscriptions), [Apollo iOS](https://www.apollographql.com/docs/ios/fetching/subscriptions)) have been updated to support multipart subscriptions, and can be used out of the box with little to no extra configuration
+- Subscription communication between clients and the Router must use the multipart subscription protocol, meaning only subscriptions over HTTP are supported at this time
+
+##### Router to Subgraph Communication
+
+- The Router communicates with subscription enabled subgraphs using WebSockets
+- By default, the router sends subscription requests to subgraphs using the [graphql-transport-ws protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md) which is implemented in the [graphql-ws](https://github.com/enisdenjo/graphql-ws) library. You can also configure it to use the [graphql-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md) which is implemented in the [subscriptions-transport-ws library](https://github.com/apollographql/subscriptions-transport-ws).
+- Subscription ready subgraphs can be introduced to Federation and the Router as is - no additional configuration is needed on the subgraph side
+
+##### Subscription Execution
+
+When the Router receives a GraphQL subscription request, the generated query plan will contain an initial subscription request to the subgraph that contributed the requested subscription root field.
+
+For example, as a result of a client sending this subscription request to the Router:
+
+```graphql
+subscription {
+ reviewAdded {
+ id
+ body
+ product {
+ id
+ name
+ createdBy {
+ name
+ }
+ }
+ }
+}
+```
+
+The router will send this request to the `reviews` subgraph:
+
+```graphql
+subscription {
+ reviewAdded {
+ id
+ body
+ product {
+ id
+ }
+ }
+}
+```
+
+When the `reviews` subgraph receives new data from its underlying source event stream, that data is sent back to the Router. Once received, the Router continues following the determined query plan to fetch any additional required data from other subgraphs:
+
+Example query sent to the `products` subgraph:
+
+```graphql
+query ($representations: [_Any!]!) {
+ _entities(representations: $representations) {
+ ... on Product {
+ name
+ createdBy {
+ __typename
+ email
+ }
+ }
+ }
+}
+```
+
+Example query sent to the `users` subgraph:
+
+```graphql
+query ($representations: [_Any!]!) {
+ _entities(representations: $representations) {
+ ... on User {
+ name
+ }
+ }
+}
+```
+
+When the Router finishes running the entire query plan, the data is merged back together and returned to the requesting client over HTTP (using the multipart subscriptions protocol).
+
+#### Configuration
+
+Here is a configuration example:
+
+```yaml title="router.yaml"
+subscription:
+ mode:
+ passthrough:
+ all: # The router uses these subscription settings UNLESS overridden per-subgraph
+ path: /subscriptions # The path to use for subgraph subscription endpoints (Default: /ws)
+ subgraphs: # Overrides subscription settings for individual subgraphs
+ reviews: # Overrides settings for the 'reviews' subgraph
+ path: /ws # Overrides '/subscriptions' defined above
+ protocol: graphql_transport_ws # The WebSocket-based protocol to use for subscription communication (Default: graphql_ws)
+```
+
+#### Usage Reporting
+
+Subscription use is tracked in the Router as follows:
+
+- **Subscription registration:** The initial subscription operation sent by a client to the Router that's responsible for starting a new subscription
+- **Subscription notification:** The resolution of the client subscription’s selection set in response to a subscription enabled subgraph source event
+
+Subscription registration and notification (with operation traces and statistics) are sent to Apollo Studio for observability.
+
+#### Advanced Features
+
+This PR includes the following configurable performance optimizations.
+
+#### Deduplication
+
+- If the Router detects that a client is using the same subscription as another client (i.e. a subscription with the same HTTP headers and selection set), it will avoid starting a new subscription with the requested subgraph. The Router will reuse the same open subscription instead, and will send the same source events to the new client.
+- This helps reduce the number of WebSockets that need to be opened between the Router and subscription enabled subgraphs, thereby drastically reducing Router to subgraph network traffic and overall latency
+- For example, if 100 clients are subscribed to the same subscription there will be 100 open HTTP connections from the clients to the Router, but only 1 open WebSocket connection from the Router to the subgraph
+- Subscription deduplication between the Router and subgraphs is enabled by default (but can be disabled via the Router config file)
+
+#### Callback Mode
+
+- Instead of sending subscription data between a Router and subgraph over an open WebSocket, the Router can be configured to send the subgraph a callback URL that will then be used to receive all source stream events
+- Subscription enabled subgraphs send source stream events (subscription updates) back to the callback URL by making HTTP POST requests
+- Refer to the [callback mode documentation](https://github.com/apollographql/router/blob/dev/dev-docs/callback_protocol.md) for more details, including an explanation of the callback URL request/response payload format
+- This feature is still experimental and needs to be enabled explicitly in the Router config file
+
+By [@bnjjj](https://github.com/bnjjj) and [@o0Ignition0o](https://github.com/o0ignition0o) in https://github.com/apollographql/router/pull/3285
+
+
+
# [1.21.0] - 2023-06-20
## 🚀 Features
diff --git a/Cargo.lock b/Cargo.lock
index 478f2c9d5f..6cece4c6ea 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -281,7 +281,7 @@ dependencies = [
[[package]]
name = "apollo-router"
-version = "1.21.0"
+version = "1.22.0"
dependencies = [
"access-json",
"anyhow",
@@ -317,6 +317,7 @@ dependencies = [
"graphql_client 0.11.0",
"heck 0.4.1",
"hex",
+ "hmac",
"http",
"http-body",
"http-serde",
@@ -393,6 +394,7 @@ dependencies = [
"tokio",
"tokio-rustls 0.23.4",
"tokio-stream",
+ "tokio-tungstenite",
"tokio-util",
"tonic 0.8.3",
"tonic-build",
@@ -420,7 +422,7 @@ dependencies = [
[[package]]
name = "apollo-router-benchmarks"
-version = "1.21.0"
+version = "1.22.0"
dependencies = [
"apollo-parser 0.4.1",
"apollo-router",
@@ -436,7 +438,7 @@ dependencies = [
[[package]]
name = "apollo-router-scaffold"
-version = "1.21.0"
+version = "1.22.0"
dependencies = [
"anyhow",
"cargo-scaffold",
@@ -645,6 +647,7 @@ checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39"
dependencies = [
"async-trait",
"axum-core",
+ "base64 0.21.2",
"bitflags",
"bytes",
"futures-util",
@@ -663,8 +666,10 @@ dependencies = [
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
+ "sha1 0.10.5",
"sync_wrapper",
"tokio",
+ "tokio-tungstenite",
"tower",
"tower-layer",
"tower-service",
@@ -6171,6 +6176,22 @@ dependencies = [
"tokio-stream",
]
+[[package]]
+name = "tokio-tungstenite"
+version = "0.18.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd"
+dependencies = [
+ "futures-util",
+ "log",
+ "rustls 0.20.8",
+ "rustls-native-certs",
+ "tokio",
+ "tokio-rustls 0.23.4",
+ "tungstenite",
+ "webpki",
+]
+
[[package]]
name = "tokio-util"
version = "0.7.8"
@@ -6510,6 +6531,27 @@ dependencies = [
"syn 1.0.109",
]
+[[package]]
+name = "tungstenite"
+version = "0.18.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788"
+dependencies = [
+ "base64 0.13.1",
+ "byteorder",
+ "bytes",
+ "http",
+ "httparse",
+ "log",
+ "rand 0.8.5",
+ "rustls 0.20.8",
+ "sha1 0.10.5",
+ "thiserror",
+ "url",
+ "utf-8",
+ "webpki",
+]
+
[[package]]
name = "typed-builder"
version = "0.9.1"
@@ -6721,6 +6763,12 @@ dependencies = [
"url",
]
+[[package]]
+name = "utf-8"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
+
[[package]]
name = "utf8parse"
version = "0.2.1"
diff --git a/apollo-router-benchmarks/Cargo.toml b/apollo-router-benchmarks/Cargo.toml
index 8226283491..2ad0ad2885 100644
--- a/apollo-router-benchmarks/Cargo.toml
+++ b/apollo-router-benchmarks/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "apollo-router-benchmarks"
-version = "1.21.0"
+version = "1.22.0"
authors = ["Apollo Graph, Inc. "]
edition = "2021"
license = "Elastic-2.0"
diff --git a/apollo-router-scaffold/Cargo.toml b/apollo-router-scaffold/Cargo.toml
index 8a4f8a3d6d..bc8572a6ee 100644
--- a/apollo-router-scaffold/Cargo.toml
+++ b/apollo-router-scaffold/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "apollo-router-scaffold"
-version = "1.21.0"
+version = "1.22.0"
authors = ["Apollo Graph, Inc. "]
edition = "2021"
license = "Elastic-2.0"
diff --git a/apollo-router-scaffold/templates/base/Cargo.toml b/apollo-router-scaffold/templates/base/Cargo.toml
index 9bdd9df4d7..83cdc5bafb 100644
--- a/apollo-router-scaffold/templates/base/Cargo.toml
+++ b/apollo-router-scaffold/templates/base/Cargo.toml
@@ -22,7 +22,7 @@ apollo-router = { path ="{{integration_test}}apollo-router" }
apollo-router = { git="https://github.com/apollographql/router.git", branch="{{branch}}" }
{{else}}
# Note if you update these dependencies then also update xtask/Cargo.toml
-apollo-router = "1.21.0"
+apollo-router = "1.22.0"
{{/if}}
{{/if}}
async-trait = "0.1.52"
diff --git a/apollo-router-scaffold/templates/base/xtask/Cargo.toml b/apollo-router-scaffold/templates/base/xtask/Cargo.toml
index 0690738259..29204ce3f8 100644
--- a/apollo-router-scaffold/templates/base/xtask/Cargo.toml
+++ b/apollo-router-scaffold/templates/base/xtask/Cargo.toml
@@ -13,7 +13,7 @@ apollo-router-scaffold = { path ="{{integration_test}}apollo-router-scaffold" }
{{#if branch}}
apollo-router-scaffold = { git="https://github.com/apollographql/router.git", branch="{{branch}}" }
{{else}}
-apollo-router-scaffold = { git = "https://github.com/apollographql/router.git", tag = "v1.21.0" }
+apollo-router-scaffold = { git = "https://github.com/apollographql/router.git", tag = "v1.22.0" }
{{/if}}
{{/if}}
anyhow = "1.0.58"
diff --git a/apollo-router/Cargo.toml b/apollo-router/Cargo.toml
index 5c2a00efd9..ff69bb4cb1 100644
--- a/apollo-router/Cargo.toml
+++ b/apollo-router/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "apollo-router"
-version = "1.21.0"
+version = "1.22.0"
authors = ["Apollo Graph, Inc. "]
repository = "https://github.com/apollographql/router/"
documentation = "https://docs.rs/apollo-router"
@@ -215,8 +215,10 @@ urlencoding = "2.1.2"
uuid = { version = "1.3.3", features = ["serde", "v4"] }
yaml-rust = "0.4.5"
wsl = "0.1.0"
+tokio-tungstenite = { version = "0.18.0", features = ["rustls-tls-native-roots"] }
tokio-rustls = "0.23.4"
http-serde = "1.1.2"
+hmac = "0.12.1"
parking_lot = "0.12.1"
memchr = "2.5.0"
brotli = "3.3.4"
@@ -233,6 +235,7 @@ uname = "0.1.1"
tikv-jemallocator = "0.5"
[dev-dependencies]
+axum = { version = "0.6.6", features = ["headers", "json", "original-uri", "ws"] }
ecdsa = { version = "0.15.1", features = ["signing", "pem", "pkcs8"] }
fred = { version = "6.3.0", features = ["enable-rustls", "no-client-setname"] }
futures-test = "0.3.28"
diff --git a/apollo-router/feature_discussions.json b/apollo-router/feature_discussions.json
index c8e2fb9c43..5fc7b56899 100644
--- a/apollo-router/feature_discussions.json
+++ b/apollo-router/feature_discussions.json
@@ -8,4 +8,4 @@
"preview": {
"preview_operation_limits": "https://github.com/apollographql/router/discussions/3040"
}
-}
+}
\ No newline at end of file
diff --git a/apollo-router/src/axum_factory/tests.rs b/apollo-router/src/axum_factory/tests.rs
index 2d01d2779a..017dd70eb1 100644
--- a/apollo-router/src/axum_factory/tests.rs
+++ b/apollo-router/src/axum_factory/tests.rs
@@ -1162,7 +1162,7 @@ async fn it_errors_on_bad_accept_header() -> Result<(), ApolloRouterError> {
);
assert_eq!(
response.text().await.unwrap(),
- r#"{"errors":[{"message":"'accept' header must be one of: \\\"*/*\\\", \"application/json\", \"application/graphql-response+json\" or \"multipart/mixed;boundary=\\\"graphql\\\";deferSpec=20220824\"","extensions":{"code":"INVALID_ACCEPT_HEADER"}}]}"#
+ r#"{"errors":[{"message":"'accept' header must be one of: \\\"*/*\\\", \"application/json\", \"application/graphql-response+json\", \"multipart/mixed;boundary=\\\"graphql\\\";subscriptionSpec=1.0\" or \"multipart/mixed;boundary=\\\"graphql\\\";deferSpec=20220824\"","extensions":{"code":"INVALID_ACCEPT_HEADER"}}]}"#
);
server.shutdown().await
diff --git a/apollo-router/src/configuration/experimental.rs b/apollo-router/src/configuration/experimental.rs
index 06784b4483..fc8011ba83 100644
--- a/apollo-router/src/configuration/experimental.rs
+++ b/apollo-router/src/configuration/experimental.rs
@@ -123,6 +123,9 @@ mod tests {
"sub": {
"experimental_trace_id": "ok"
}
+ },
+ "preview_subscription": {
+
}
});
@@ -133,5 +136,9 @@ mod tests {
"experimental_trace_id".to_string()
]
);
+ assert_eq!(
+ get_configurations(&val, "preview"),
+ vec!["preview_subscription".to_string(),]
+ );
}
}
diff --git a/apollo-router/src/configuration/mod.rs b/apollo-router/src/configuration/mod.rs
index 0d66cd7dca..2fba20f48e 100644
--- a/apollo-router/src/configuration/mod.rs
+++ b/apollo-router/src/configuration/mod.rs
@@ -18,6 +18,8 @@ use std::net::SocketAddr;
use std::num::NonZeroUsize;
use std::str::FromStr;
use std::sync::Arc;
+#[cfg(not(test))]
+use std::time::Duration;
use derivative::Derivative;
use displaydoc::Display;
@@ -50,9 +52,21 @@ pub(crate) use self::schema::generate_upgrade;
use self::subgraph::SubgraphConfiguration;
use crate::cache::DEFAULT_CACHE_CAPACITY;
use crate::configuration::schema::Mode;
+use crate::graphql;
+use crate::notification::Notify;
use crate::plugin::plugins;
+#[cfg(not(test))]
+use crate::plugins::subscription::SubscriptionConfig;
+#[cfg(not(test))]
+use crate::plugins::subscription::APOLLO_SUBSCRIPTION_PLUGIN;
+#[cfg(not(test))]
+use crate::plugins::subscription::APOLLO_SUBSCRIPTION_PLUGIN_NAME;
use crate::ApolloRouterError;
+// TODO: Talk it through with the teams
+#[cfg(not(test))]
+static HEARTBEAT_TIMEOUT_DURATION_SECONDS: u64 = 15;
+
static SUPERGRAPH_ENDPOINT_REGEX: Lazy = Lazy::new(|| {
Regex::new(r"(?P.*/)(?P.+)\*$")
.expect("this regex to check the path is valid")
@@ -149,6 +163,9 @@ pub struct Configuration {
#[serde(default)]
#[serde(flatten)]
pub(crate) apollo_plugins: ApolloPlugins,
+
+ #[serde(default, skip_serializing, skip_deserializing)]
+ pub(crate) notify: Notify,
}
impl<'de> serde::Deserialize<'de> for Configuration {
@@ -217,10 +234,24 @@ impl Configuration {
plugins: Map,
apollo_plugins: Map,
tls: Option,
+ notify: Option>,
apq: Option,
operation_limits: Option,
chaos: Option,
) -> Result {
+ #[cfg(not(test))]
+ let notify_queue_cap = match apollo_plugins.get(APOLLO_SUBSCRIPTION_PLUGIN_NAME) {
+ Some(plugin_conf) => {
+ let conf = serde_json::from_value::(plugin_conf.clone())
+ .map_err(|err| ConfigurationError::PluginConfiguration {
+ plugin: APOLLO_SUBSCRIPTION_PLUGIN.to_string(),
+ error: format!("{err:?}"),
+ })?;
+ conf.queue_capacity
+ }
+ None => None,
+ };
+
let conf = Self {
validated_yaml: Default::default(),
supergraph: supergraph.unwrap_or_default(),
@@ -238,6 +269,11 @@ impl Configuration {
plugins: apollo_plugins,
},
tls: tls.unwrap_or_default(),
+ #[cfg(test)]
+ notify: notify.unwrap_or_default(),
+ #[cfg(not(test))]
+ notify: notify.map(|n| n.set_queue_size(notify_queue_cap))
+ .unwrap_or_else(|| Notify::builder().and_queue_size(notify_queue_cap).ttl(Duration::from_secs(HEARTBEAT_TIMEOUT_DURATION_SECONDS)).heartbeat_error_message(graphql::Response::builder().errors(vec![graphql::Error::builder().message("the connection has been closed because it hasn't heartbeat for a while").extension_code("SUBSCRIPTION_HEARTBEAT_ERROR").build()]).build()).build()),
};
conf.validate()
@@ -287,6 +323,7 @@ impl Configuration {
plugins: Map,
apollo_plugins: Map,
tls: Option,
+ notify: Option>,
apq: Option,
operation_limits: Option,
chaos: Option,
@@ -307,6 +344,7 @@ impl Configuration {
plugins: apollo_plugins,
},
tls: tls.unwrap_or_default(),
+ notify: notify.unwrap_or_default(),
apq: apq.unwrap_or_default(),
};
diff --git a/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap b/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap
index f02fda7160..e919b8cedf 100644
--- a/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap
+++ b/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap
@@ -1227,6 +1227,149 @@ expression: "&schema"
},
"additionalProperties": false
},
+ "subscription": {
+ "description": "Subscriptions configuration",
+ "type": "object",
+ "properties": {
+ "enable_deduplication": {
+ "description": "Enable the deduplication of subscription (for example if we detect the exact same request to subgraph we won't open a new websocket to the subgraph in passthrough mode) (default: true)",
+ "default": true,
+ "type": "boolean"
+ },
+ "max_opened_subscriptions": {
+ "description": "This is a limit to only have maximum X opened subscriptions at the same time. By default if it's not set there is no limit.",
+ "default": null,
+ "type": "integer",
+ "format": "uint",
+ "minimum": 0.0,
+ "nullable": true
+ },
+ "mode": {
+ "description": "Select a subscription mode (callback or passthrough)",
+ "default": {
+ "preview_callback": null,
+ "passthrough": null
+ },
+ "type": "object",
+ "properties": {
+ "passthrough": {
+ "description": "Enable passthrough mode for subgraph(s)",
+ "type": "object",
+ "properties": {
+ "all": {
+ "description": "Configuration for all subgraphs",
+ "default": null,
+ "type": "object",
+ "properties": {
+ "path": {
+ "description": "Path on which WebSockets are listening",
+ "default": null,
+ "type": "string",
+ "nullable": true
+ },
+ "protocol": {
+ "description": "Which WebSocket GraphQL protocol to use for this subgraph possible values are: 'graphql_ws' | 'graphql_transport_ws' (default: graphql_ws)",
+ "default": "graphql_ws",
+ "type": "string",
+ "enum": [
+ "graphql_ws",
+ "graphql_transport_ws"
+ ]
+ }
+ },
+ "additionalProperties": false,
+ "nullable": true
+ },
+ "subgraphs": {
+ "description": "Configuration for specific subgraphs",
+ "default": {},
+ "type": "object",
+ "additionalProperties": {
+ "description": "WebSocket configuration for a specific subgraph",
+ "type": "object",
+ "properties": {
+ "path": {
+ "description": "Path on which WebSockets are listening",
+ "default": null,
+ "type": "string",
+ "nullable": true
+ },
+ "protocol": {
+ "description": "Which WebSocket GraphQL protocol to use for this subgraph possible values are: 'graphql_ws' | 'graphql_transport_ws' (default: graphql_ws)",
+ "default": "graphql_ws",
+ "type": "string",
+ "enum": [
+ "graphql_ws",
+ "graphql_transport_ws"
+ ]
+ }
+ },
+ "additionalProperties": false
+ }
+ }
+ },
+ "additionalProperties": false,
+ "nullable": true
+ },
+ "preview_callback": {
+ "description": "Enable callback mode for subgraph(s)",
+ "type": "object",
+ "required": [
+ "public_url"
+ ],
+ "properties": {
+ "listen": {
+ "description": "Listen address on which the callback must listen (default: 127.0.0.1:4000)",
+ "writeOnly": true,
+ "anyOf": [
+ {
+ "description": "Socket address.",
+ "type": "string"
+ },
+ {
+ "description": "Unix socket.",
+ "type": "string"
+ }
+ ],
+ "nullable": true
+ },
+ "path": {
+ "description": "Specify on which path you want to listen for callbacks (default: /callback)",
+ "writeOnly": true,
+ "type": "string",
+ "nullable": true
+ },
+ "public_url": {
+ "description": "URL used to access this router instance",
+ "type": "string"
+ },
+ "subgraphs": {
+ "description": "Specify on which subgraph we enable the callback mode for subscription If empty it applies to all subgraphs (passthrough mode takes precedence)",
+ "default": [],
+ "type": "array",
+ "items": {
+ "type": "string"
+ },
+ "uniqueItems": true
+ }
+ },
+ "additionalProperties": false,
+ "nullable": true
+ }
+ },
+ "additionalProperties": false
+ },
+ "queue_capacity": {
+ "description": "It represent the capacity of the in memory queue to know how many events we can keep in a buffer",
+ "default": null,
+ "type": "integer",
+ "format": "uint",
+ "minimum": 0.0,
+ "nullable": true
+ }
+ },
+ "additionalProperties": false
+ },
"supergraph": {
"description": "Configuration for the supergraph",
"default": {
diff --git a/apollo-router/src/error.rs b/apollo-router/src/error.rs
index 9d3e1fe1d1..bfad581e22 100644
--- a/apollo-router/src/error.rs
+++ b/apollo-router/src/error.rs
@@ -87,6 +87,16 @@ pub(crate) enum FetchError {
/// The reason the fetch failed.
reason: String,
},
+ /// Websocket fetch failed from '{service}': {reason}
+ ///
+ /// note that this relates to a transport error and not a GraphQL error
+ SubrequestWsError {
+ /// The service failed.
+ service: String,
+
+ /// The reason the fetch failed.
+ reason: String,
+ },
/// subquery requires field '{field}' but it was not found in the current response
ExecutionFieldNotFound {
@@ -135,6 +145,7 @@ impl FetchError {
}
FetchError::SubrequestMalformedResponse { service, .. }
| FetchError::SubrequestUnexpectedPatchResponse { service }
+ | FetchError::SubrequestWsError { service, .. }
| FetchError::CompressionError { service, .. } => {
extensions
.entry("service")
@@ -181,6 +192,7 @@ impl ErrorExtension for FetchError {
"SUBREQUEST_UNEXPECTED_PATCH_RESPONSE"
}
FetchError::SubrequestHttpError { .. } => "SUBREQUEST_HTTP_ERROR",
+ FetchError::SubrequestWsError { .. } => "SUBREQUEST_WEBSOCKET_ERROR",
FetchError::ExecutionFieldNotFound { .. } => "EXECUTION_FIELD_NOT_FOUND",
FetchError::ExecutionPathNotFound { .. } => "EXECUTION_PATH_NOT_FOUND",
FetchError::CompressionError { .. } => "COMPRESSION_ERROR",
diff --git a/apollo-router/src/lib.rs b/apollo-router/src/lib.rs
index 437e7135db..296dcc8a4d 100644
--- a/apollo-router/src/lib.rs
+++ b/apollo-router/src/lib.rs
@@ -59,8 +59,10 @@ mod http_ext;
mod http_server_factory;
mod introspection;
pub mod layers;
+pub(crate) mod notification;
mod orbiter;
mod plugins;
+pub(crate) mod protocols;
mod query_planner;
mod request;
mod response;
@@ -78,6 +80,7 @@ pub use crate::configuration::ListenAddr;
pub use crate::context::Context;
pub use crate::executable::main;
pub use crate::executable::Executable;
+pub use crate::notification::Notify;
pub use crate::router::ApolloRouterError;
pub use crate::router::ConfigurationSource;
pub use crate::router::LicenseSource;
diff --git a/apollo-router/src/notification.rs b/apollo-router/src/notification.rs
new file mode 100644
index 0000000000..17240c359b
--- /dev/null
+++ b/apollo-router/src/notification.rs
@@ -0,0 +1,979 @@
+//! Internal pub/sub facility for subscription
+
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::task::Context;
+use std::task::Poll;
+use std::time::Duration;
+use std::time::Instant;
+
+use futures::channel::mpsc;
+use futures::channel::mpsc::SendError;
+use futures::channel::oneshot;
+use futures::channel::oneshot::Canceled;
+use futures::Sink;
+use futures::SinkExt;
+use futures::Stream;
+use futures::StreamExt;
+use pin_project_lite::pin_project;
+use thiserror::Error;
+use tokio::sync::broadcast;
+use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
+use tokio_stream::wrappers::BroadcastStream;
+use tokio_stream::wrappers::IntervalStream;
+
+use crate::graphql;
+
+static NOTIFY_CHANNEL_SIZE: usize = 1024;
+static DEFAULT_MSG_CHANNEL_SIZE: usize = 128;
+
+#[derive(Error, Debug)]
+pub(crate) enum NotifyError<V> {
+    #[error("cannot send data to pubsub")]
+    SendError(#[from] SendError),
+    #[error("cannot send data to response stream")]
+    BroadcastSendError(#[from] broadcast::error::SendError<V>),
+    #[error("cannot send data to pubsub because channel has been closed")]
+    Canceled(#[from] Canceled),
+    #[error("this topic doesn't exist")]
+    UnknownTopic,
+}
+
+type ResponseSender<V> =
+    oneshot::Sender<Option<(broadcast::Sender<Option<V>>, broadcast::Receiver<Option<V>>)>>;
+
+type ResponseSenderWithCreated<V> = oneshot::Sender<(
+    broadcast::Sender<Option<V>>,
+    broadcast::Receiver<Option<V>>,
+    bool,
+)>;
+
+enum Notification<K, V> {
+    CreateOrSubscribe {
+        topic: K,
+        // Sender connected to the original source stream
+        msg_sender: broadcast::Sender<Option<V>>,
+        // To know if it has been created or re-used
+        response_sender: ResponseSenderWithCreated<V>,
+        heartbeat_enabled: bool,
+    },
+    Subscribe {
+        topic: K,
+        // Oneshot channel to fetch the receiver
+        response_sender: ResponseSender<V>,
+    },
+    SubscribeIfExist {
+        topic: K,
+        // Oneshot channel to fetch the receiver
+        response_sender: ResponseSender<V>,
+    },
+    Unsubscribe {
+        topic: K,
+    },
+    ForceDelete {
+        topic: K,
+    },
+    Exist {
+        topic: K,
+        response_sender: oneshot::Sender<bool>,
+    },
+    InvalidIds {
+        topics: Vec<K>,
+        response_sender: oneshot::Sender<(Vec<K>, Vec<K>)>,
+    },
+    #[cfg(test)]
+    TryDelete {
+        topic: K,
+    },
+    #[cfg(test)]
+    Broadcast {
+        data: V,
+    },
+    #[cfg(test)]
+    Debug {
+        // Returns the number of subscriptions and subscribers
+        response_sender: oneshot::Sender<usize>,
+    },
+}
+
+/// In memory pub/sub implementation
+#[derive(Clone)]
+pub struct Notify<K, V> {
+    sender: mpsc::Sender<Notification<K, V>>,
+    /// Size (number of events) of the channel to receive message
+    pub(crate) queue_size: Option<usize>,
+}
+
+#[buildstructor::buildstructor]
+impl<K, V> Notify<K, V>
+where
+    K: Send + Hash + Eq + Clone + 'static,
+    V: Send + Sync + Clone + 'static,
+{
+    #[builder]
+    pub(crate) fn new(
+        ttl: Option<Duration>,
+        heartbeat_error_message: Option<V>,
+        queue_size: Option<usize>,
+    ) -> Notify<K, V> {
+        let (sender, receiver) = mpsc::channel(NOTIFY_CHANNEL_SIZE);
+        tokio::task::spawn(task(receiver, ttl, heartbeat_error_message));
+        Notify { sender, queue_size }
+    }
+
+ #[doc(hidden)]
+ /// NOOP notifier for tests
+ pub fn for_tests() -> Self {
+ let (sender, _receiver) = mpsc::channel(NOTIFY_CHANNEL_SIZE);
+ Notify {
+ sender,
+ queue_size: None,
+ }
+ }
+}
+impl<K, V> Notify<K, V>
+where
+    K: Send + Hash + Eq + Clone + 'static,
+    V: Send + Clone + 'static,
+{
+    #[cfg(not(test))]
+    pub(crate) fn set_queue_size(mut self, queue_size: Option<usize>) -> Self {
+        self.queue_size = queue_size;
+        self
+    }
+
+    // boolean in the tuple means `created`
+    pub(crate) async fn create_or_subscribe(
+        &mut self,
+        topic: K,
+        heartbeat_enabled: bool,
+    ) -> Result<(Handle<K, V>, bool), NotifyError<V>> {
+ let (sender, _receiver) =
+ broadcast::channel(self.queue_size.unwrap_or(DEFAULT_MSG_CHANNEL_SIZE));
+
+ let (tx, rx) = oneshot::channel();
+ self.sender
+ .send(Notification::CreateOrSubscribe {
+ topic: topic.clone(),
+ msg_sender: sender,
+ response_sender: tx,
+ heartbeat_enabled,
+ })
+ .await?;
+
+ let (msg_sender, msg_receiver, created) = rx.await?;
+ let handle = Handle::new(
+ topic,
+ self.sender.clone(),
+ msg_sender,
+ BroadcastStream::from(msg_receiver),
+ );
+
+ Ok((handle, created))
+ }
+
+    pub(crate) async fn subscribe(&mut self, topic: K) -> Result<Handle<K, V>, NotifyError<V>> {
+ let (sender, receiver) = oneshot::channel();
+
+ self.sender
+ .send(Notification::Subscribe {
+ topic: topic.clone(),
+ response_sender: sender,
+ })
+ .await?;
+
+ let Some((msg_sender, msg_receiver)) = receiver.await? else {
+ return Err(NotifyError::UnknownTopic);
+ };
+ let handle = Handle::new(
+ topic,
+ self.sender.clone(),
+ msg_sender,
+ BroadcastStream::from(msg_receiver),
+ );
+
+ Ok(handle)
+ }
+
+    pub(crate) async fn subscribe_if_exist(
+        &mut self,
+        topic: K,
+    ) -> Result<Option<Handle<K, V>>, NotifyError<V>> {
+ let (sender, receiver) = oneshot::channel();
+
+ self.sender
+ .send(Notification::SubscribeIfExist {
+ topic: topic.clone(),
+ response_sender: sender,
+ })
+ .await?;
+
+ let Some((msg_sender, msg_receiver)) = receiver.await? else {
+ return Ok(None);
+ };
+ let handle = Handle::new(
+ topic,
+ self.sender.clone(),
+ msg_sender,
+ BroadcastStream::from(msg_receiver),
+ );
+
+ Ok(handle.into())
+ }
+
+    pub(crate) async fn exist(&mut self, topic: K) -> Result<bool, NotifyError<V>> {
+ // Channel to check if the topic still exists or not
+ let (response_tx, response_rx) = oneshot::channel();
+
+ self.sender
+ .send(Notification::Exist {
+ topic,
+ response_sender: response_tx,
+ })
+ .await?;
+
+ let resp = response_rx.await?;
+
+ Ok(resp)
+ }
+
+    pub(crate) async fn invalid_ids(
+        &mut self,
+        topics: Vec<K>,
+    ) -> Result<(Vec<K>, Vec<K>), NotifyError<V>> {
+ // Channel to check if the topic still exists or not
+ let (response_tx, response_rx) = oneshot::channel();
+
+ self.sender
+ .send(Notification::InvalidIds {
+ topics,
+ response_sender: response_tx,
+ })
+ .await?;
+
+ let resp = response_rx.await?;
+
+ Ok(resp)
+ }
+
+ /// Delete the topic even if several subscribers are still listening
+    pub(crate) async fn force_delete(&mut self, topic: K) -> Result<(), NotifyError<V>> {
+ // if disconnected, we don't care (the task was stopped)
+ self.sender
+ .send(Notification::ForceDelete { topic })
+ .await
+ .map_err(std::convert::Into::into)
+ }
+
+ /// Delete the topic if and only if one or zero subscriber is still listening
+ /// This function is not async to allow it to be used in a Drop impl
+ #[cfg(test)]
+    pub(crate) fn try_delete(&mut self, topic: K) -> Result<(), NotifyError<V>> {
+ // if disconnected, we don't care (the task was stopped)
+ self.sender
+ .try_send(Notification::TryDelete { topic })
+ .map_err(|try_send_error| try_send_error.into_send_error().into())
+ }
+
+ #[cfg(test)]
+    pub(crate) async fn broadcast(&mut self, data: V) -> Result<(), NotifyError<V>> {
+ self.sender
+ .send(Notification::Broadcast { data })
+ .await
+ .map_err(std::convert::Into::into)
+ }
+
+ #[cfg(test)]
+    pub(crate) async fn debug(&mut self) -> Result<usize, NotifyError<V>> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.sender
+ .send(Notification::Debug {
+ response_sender: response_tx,
+ })
+ .await?;
+
+ Ok(response_rx.await.unwrap())
+ }
+}
+
+#[cfg(test)]
+impl<K, V> Default for Notify<K, V>
+where
+    K: Send + Hash + Eq + Clone + 'static,
+    V: Send + Sync + Clone + 'static,
+{
+    /// Useless notify mainly for test
+    fn default() -> Self {
+        Self::for_tests()
+    }
+}
+
+impl<K, V> Debug for Notify<K, V> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Notify").finish()
+    }
+}
+
+struct HandleGuard<K, V>
+where
+    K: Clone,
+{
+    topic: K,
+    pubsub_sender: mpsc::Sender<Notification<K, V>>,
+}
+
+impl<K, V> Clone for HandleGuard<K, V>
+where
+    K: Clone,
+{
+    fn clone(&self) -> Self {
+        Self {
+            topic: self.topic.clone(),
+            pubsub_sender: self.pubsub_sender.clone(),
+        }
+    }
+}
+
+impl<K, V> Drop for HandleGuard<K, V>
+where
+    K: Clone,
+{
+ fn drop(&mut self) {
+ let err = self.pubsub_sender.try_send(Notification::Unsubscribe {
+ topic: self.topic.clone(),
+ });
+ if let Err(err) = err {
+ tracing::trace!("cannot unsubscribe {err:?}");
+ }
+ }
+}
+
+pin_project! {
+pub struct Handle<K, V>
+where
+    K: Clone,
+{
+    handle_guard: HandleGuard<K, V>,
+    #[pin]
+    msg_sender: broadcast::Sender<Option<V>>,
+    #[pin]
+    msg_receiver: BroadcastStream<Option<V>>,
+}
+}
+
+impl<K, V> Clone for Handle<K, V>
+where
+    K: Clone,
+    V: Clone + Send + 'static,
+{
+ fn clone(&self) -> Self {
+ Self {
+ handle_guard: self.handle_guard.clone(),
+ msg_receiver: BroadcastStream::new(self.msg_sender.subscribe()),
+ msg_sender: self.msg_sender.clone(),
+ }
+ }
+}
+
+impl<K, V> Handle<K, V>
+where
+    K: Clone,
+{
+    fn new(
+        topic: K,
+        pubsub_sender: mpsc::Sender<Notification<K, V>>,
+        msg_sender: broadcast::Sender<Option<V>>,
+        msg_receiver: BroadcastStream<Option<V>>,
+    ) -> Self {
+ Self {
+ handle_guard: HandleGuard {
+ topic,
+ pubsub_sender,
+ },
+ msg_sender,
+ msg_receiver,
+ }
+ }
+
+    pub(crate) fn into_stream(self) -> HandleStream<K, V> {
+ HandleStream {
+ handle_guard: self.handle_guard,
+ msg_receiver: self.msg_receiver,
+ }
+ }
+
+    pub(crate) fn into_sink(self) -> HandleSink<K, V> {
+ HandleSink {
+ handle_guard: self.handle_guard,
+ msg_sender: self.msg_sender,
+ }
+ }
+
+ /// Return a sink and a stream
+    pub fn split(self) -> (HandleSink<K, V>, HandleStream<K, V>) {
+ (
+ HandleSink {
+ handle_guard: self.handle_guard.clone(),
+ msg_sender: self.msg_sender,
+ },
+ HandleStream {
+ handle_guard: self.handle_guard,
+ msg_receiver: self.msg_receiver,
+ },
+ )
+ }
+}
+
+pin_project! {
+pub struct HandleStream<K, V>
+where
+    K: Clone,
+{
+    handle_guard: HandleGuard<K, V>,
+    #[pin]
+    msg_receiver: BroadcastStream<Option<V>>,
+}
+}
+
+impl<K, V> Stream for HandleStream<K, V>
+where
+    K: Clone,
+    V: Clone + 'static + Send,
+{
+    type Item = V;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let mut this = self.as_mut().project();
+
+ match Pin::new(&mut this.msg_receiver).poll_next(cx) {
+ Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => {
+ tracing::info!(monotonic_counter.apollo_router_skipped_event_count = 1u64,);
+ self.poll_next(cx)
+ }
+ Poll::Ready(None) => Poll::Ready(None),
+ Poll::Ready(Some(Ok(Some(val)))) => Poll::Ready(Some(val)),
+ Poll::Ready(Some(Ok(None))) => Poll::Ready(None),
+ Poll::Pending => Poll::Pending,
+ }
+ }
+}
+
+pin_project! {
+pub struct HandleSink<K, V>
+where
+    K: Clone,
+{
+    handle_guard: HandleGuard<K, V>,
+    #[pin]
+    msg_sender: broadcast::Sender<Option<V>>,
+}
+}
+
+impl<K, V> HandleSink<K, V>
+where
+    K: Clone,
+    V: Clone + 'static + Send,
+{
+    /// Send data to the subscribed topic
+    pub(crate) fn send_sync(&mut self, data: V) -> Result<(), NotifyError<V>> {
+ self.msg_sender.send(data.into()).map_err(|err| {
+ NotifyError::BroadcastSendError(broadcast::error::SendError(err.0.unwrap()))
+ })?;
+
+ Ok(())
+ }
+}
+
+impl<K, V> Sink<V> for HandleSink<K, V>
+where
+    K: Clone,
+    V: Clone + 'static + Send,
+{
+    type Error = graphql::Error;
+
+    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: V) -> Result<(), Self::Error> {
+ self.msg_sender.send(Some(item)).map_err(|_err| {
+ graphql::Error::builder()
+ .message("cannot send payload through pubsub")
+ .extension_code("NOTIFICATION_HANDLE_SEND_ERROR")
+ .build()
+ })?;
+ Ok(())
+ }
+
+    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+    fn poll_close(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<Result<(), Self::Error>> {
+ let topic = self.handle_guard.topic.clone();
+ let _ = self
+ .handle_guard
+ .pubsub_sender
+ .try_send(Notification::ForceDelete { topic });
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<K, V> Handle<K, V> where K: Clone {}
+
+async fn task<K, V>(
+    mut receiver: mpsc::Receiver<Notification<K, V>>,
+    ttl: Option<Duration>,
+    heartbeat_error_message: Option<V>,
+) where
+    K: Send + Hash + Eq + Clone + 'static,
+    V: Send + Clone + 'static,
+{
+    let mut pubsub: PubSub<K, V> = PubSub::new(ttl);
+
+    let mut ttl_fut: Box<dyn Stream<Item = tokio::time::Instant> + Send + Unpin> = match ttl {
+        Some(ttl) => Box::new(IntervalStream::new(tokio::time::interval(ttl))),
+        None => Box::new(tokio_stream::pending()),
+    };
+
+ loop {
+ tokio::select! {
+ _ = ttl_fut.next() => {
+ let heartbeat_error_message = heartbeat_error_message.clone();
+ pubsub.kill_dead_topics(heartbeat_error_message).await;
+ tracing::info!(
+ value.apollo_router_opened_subscriptions = pubsub.subscriptions.len() as u64,
+ );
+ }
+ message = receiver.next() => {
+ match message {
+ Some(message) => {
+ match message {
+ Notification::Unsubscribe { topic } => pubsub.unsubscribe(topic),
+ Notification::ForceDelete { topic } => pubsub.force_delete(topic),
+ Notification::CreateOrSubscribe { topic, msg_sender, response_sender, heartbeat_enabled } => {
+ pubsub.subscribe_or_create(topic, msg_sender, response_sender, heartbeat_enabled);
+ }
+ Notification::Subscribe {
+ topic,
+ response_sender,
+ } => {
+ pubsub.subscribe(topic, response_sender);
+ }
+ Notification::SubscribeIfExist {
+ topic,
+ response_sender,
+ } => {
+ if pubsub.is_used(&topic) {
+ pubsub.subscribe(topic, response_sender);
+ } else {
+ pubsub.force_delete(topic);
+ let _ = response_sender.send(None);
+ }
+ }
+ Notification::InvalidIds {
+ topics,
+ response_sender,
+ } => {
+ let invalid_topics = pubsub.invalid_topics(topics);
+ let _ = response_sender.send(invalid_topics);
+ }
+ Notification::Exist {
+ topic,
+ response_sender,
+ } => {
+ let exist = pubsub.exist(&topic);
+ let _ = response_sender.send(exist);
+ if exist {
+ pubsub.touch(&topic);
+ }
+ }
+ #[cfg(test)]
+ Notification::TryDelete { topic } => pubsub.try_delete(topic),
+ #[cfg(test)]
+ Notification::Broadcast { data } => {
+ pubsub.broadcast(data).await;
+ }
+ #[cfg(test)]
+ Notification::Debug { response_sender } => {
+ let _ = response_sender.send(pubsub.subscriptions.len());
+ }
+ }
+ },
+ None => break,
+ }
+ }
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Subscription<V> {
+    msg_sender: broadcast::Sender<Option<V>>,
+    heartbeat_enabled: bool,
+    updated_at: Instant,
+}
+
+impl<V> Subscription<V> {
+    fn new(msg_sender: broadcast::Sender<Option<V>>, heartbeat_enabled: bool) -> Self {
+ Self {
+ msg_sender,
+ heartbeat_enabled,
+ updated_at: Instant::now(),
+ }
+ }
+ // Update the updated_at value
+ fn touch(&mut self) {
+ self.updated_at = Instant::now();
+ }
+}
+
+struct PubSub<K, V>
+where
+    K: Hash + Eq,
+{
+    subscriptions: HashMap<K, Subscription<V>>,
+    ttl: Option<Duration>,
+}
+
+impl<K, V> Default for PubSub<K, V>
+where
+    K: Hash + Eq,
+{
+ fn default() -> Self {
+ Self {
+ // subscribers: HashMap::new(),
+ subscriptions: HashMap::new(),
+ ttl: None,
+ }
+ }
+}
+
+impl<K, V> PubSub<K, V>
+where
+    K: Hash + Eq + Clone,
+    V: Clone + 'static,
+{
+    fn new(ttl: Option<Duration>) -> Self {
+ Self {
+ subscriptions: HashMap::new(),
+ ttl,
+ }
+ }
+
+    fn create_topic(
+        &mut self,
+        topic: K,
+        sender: broadcast::Sender<Option<V>>,
+        heartbeat_enabled: bool,
+    ) {
+ self.subscriptions
+ .insert(topic, Subscription::new(sender, heartbeat_enabled));
+ }
+
+    fn subscribe(&mut self, topic: K, sender: ResponseSender<V>) {
+ match self.subscriptions.get_mut(&topic) {
+ Some(subscription) => {
+ let _ = sender.send(Some((
+ subscription.msg_sender.clone(),
+ subscription.msg_sender.subscribe(),
+ )));
+ }
+ None => {
+ let _ = sender.send(None);
+ }
+ }
+ }
+
+    fn subscribe_or_create(
+        &mut self,
+        topic: K,
+        msg_sender: broadcast::Sender<Option<V>>,
+        sender: ResponseSenderWithCreated<V>,
+        heartbeat_enabled: bool,
+    ) {
+ match self.subscriptions.get(&topic) {
+ Some(subscription) => {
+ let _ = sender.send((
+ subscription.msg_sender.clone(),
+ subscription.msg_sender.subscribe(),
+ false,
+ ));
+ }
+ None => {
+ self.create_topic(topic, msg_sender.clone(), heartbeat_enabled);
+
+ let _ = sender.send((msg_sender.clone(), msg_sender.subscribe(), true));
+ }
+ }
+ }
+
+ fn unsubscribe(&mut self, topic: K) {
+ let mut topic_to_delete = false;
+ match self.subscriptions.get(&topic) {
+ Some(subscription) => {
+ topic_to_delete = subscription.msg_sender.receiver_count() == 0;
+ }
+ None => tracing::trace!("Cannot find the subscription to unsubscribe"),
+ }
+ if topic_to_delete {
+ self.subscriptions.remove(&topic);
+ };
+ }
+
+ /// Check if the topic is used by anyone else than the current handle
+ fn is_used(&self, topic: &K) -> bool {
+ self.subscriptions
+ .get(topic)
+ .map(|s| s.msg_sender.receiver_count() > 0)
+ .unwrap_or_default()
+ }
+
+ /// Update the heartbeat
+ fn touch(&mut self, topic: &K) {
+ if let Some(sub) = self.subscriptions.get_mut(topic) {
+ sub.touch();
+ }
+ }
+
+ /// Check if the topic exists
+ fn exist(&self, topic: &K) -> bool {
+ self.subscriptions.contains_key(topic)
+ }
+
+ /// Given a list of topics, returns the list of valid and invalid topics
+ /// Heartbeat the given valid topics
+    fn invalid_topics(&mut self, topics: Vec<K>) -> (Vec<K>, Vec<K>) {
+ topics.into_iter().fold(
+ (Vec::new(), Vec::new()),
+ |(mut valid_ids, mut invalid_ids), e| {
+ match self.subscriptions.get_mut(&e) {
+ Some(sub) => {
+ sub.touch();
+ valid_ids.push(e);
+ }
+ None => {
+ invalid_ids.push(e);
+ }
+ }
+
+ (valid_ids, invalid_ids)
+ },
+ )
+ }
+
+ /// clean all topics which didn't heartbeat
+    async fn kill_dead_topics(&mut self, heartbeat_error_message: Option<V>) {
+ if let Some(ttl) = self.ttl {
+ let drained = self.subscriptions.drain();
+ let (remaining_subs, closed_subs) = drained.into_iter().fold(
+ (HashMap::new(), HashMap::new()),
+ |(mut acc, mut acc_error), (topic, sub)| {
+ if (!sub.heartbeat_enabled || sub.updated_at.elapsed() <= ttl)
+ && sub.msg_sender.receiver_count() > 0
+ {
+ acc.insert(topic, sub);
+ } else {
+ acc_error.insert(topic, sub);
+ }
+
+ (acc, acc_error)
+ },
+ );
+ self.subscriptions = remaining_subs;
+
+ // Send error message to all killed connections
+ for (_subscriber_id, subscription) in closed_subs {
+ if let Some(heartbeat_error_message) = &heartbeat_error_message {
+ let _ = subscription
+ .msg_sender
+ .send(heartbeat_error_message.clone().into());
+ let _ = subscription.msg_sender.send(None);
+ }
+ }
+ }
+ }
+
+ #[cfg(test)]
+ fn try_delete(&mut self, topic: K) {
+ if let Some(sub) = self.subscriptions.get(&topic) {
+ if sub.msg_sender.receiver_count() > 1 {
+ return;
+ }
+ }
+
+ self.force_delete(topic);
+ }
+
+ fn force_delete(&mut self, topic: K) {
+ tracing::trace!("deleting subscription");
+ let sub = self.subscriptions.remove(&topic);
+ if let Some(sub) = sub {
+ let _ = sub.msg_sender.send(None);
+ }
+ }
+
+ #[cfg(test)]
+ async fn broadcast(&mut self, value: V) -> Option<()>
+ where
+ V: Clone,
+ {
+ let mut fut = vec![];
+ for (sub_id, sub) in &self.subscriptions {
+ let cloned_value = value.clone();
+ let sub_id = sub_id.clone();
+ fut.push(
+ sub.msg_sender
+ .send(cloned_value.into())
+ .is_err()
+ .then_some(sub_id),
+ );
+ }
+ // clean closed sender
+        let sub_to_clean: Vec<K> = fut.into_iter().flatten().collect();
+ self.subscriptions
+ .retain(|k, s| s.msg_sender.receiver_count() > 0 && !sub_to_clean.contains(k));
+
+ Some(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+
+ use uuid::Uuid;
+
+ use super::*;
+
+ #[tokio::test]
+ async fn subscribe() {
+ let mut notify = Notify::builder().build();
+ let topic_1 = Uuid::new_v4();
+ let topic_2 = Uuid::new_v4();
+
+ let (handle1, created) = notify.create_or_subscribe(topic_1, false).await.unwrap();
+ assert!(created);
+ let (_handle2, created) = notify.create_or_subscribe(topic_2, false).await.unwrap();
+ assert!(created);
+
+ let handle_1_bis = notify.subscribe(topic_1).await.unwrap();
+ let handle_1_other = notify.subscribe(topic_1).await.unwrap();
+ let mut cloned_notify = notify.clone();
+
+ let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
+ handle
+ .send_sync(serde_json_bytes::json!({"test": "ok"}))
+ .unwrap();
+ drop(handle);
+ drop(handle1);
+ let mut handle_1_bis = handle_1_bis.into_stream();
+ let new_msg = handle_1_bis.next().await.unwrap();
+ assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
+ let mut handle_1_other = handle_1_other.into_stream();
+ let new_msg = handle_1_other.next().await.unwrap();
+ assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
+
+ assert!(notify.exist(topic_1).await.unwrap());
+ assert!(notify.exist(topic_2).await.unwrap());
+
+ drop(_handle2);
+ drop(handle_1_bis);
+ drop(handle_1_other);
+
+ let subscriptions_nb = notify.debug().await.unwrap();
+ assert_eq!(subscriptions_nb, 0);
+ }
+
+ #[tokio::test]
+ async fn it_subscribe_and_delete() {
+ let mut notify = Notify::builder().build();
+ let topic_1 = Uuid::new_v4();
+ let topic_2 = Uuid::new_v4();
+
+ let (handle1, created) = notify.create_or_subscribe(topic_1, true).await.unwrap();
+ assert!(created);
+ let (_handle2, created) = notify.create_or_subscribe(topic_2, true).await.unwrap();
+ assert!(created);
+
+ let mut _handle_1_bis = notify.subscribe(topic_1).await.unwrap();
+ let mut _handle_1_other = notify.subscribe(topic_1).await.unwrap();
+ let mut cloned_notify = notify.clone();
+ let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
+ handle
+ .send_sync(serde_json_bytes::json!({"test": "ok"}))
+ .unwrap();
+ drop(handle);
+ assert!(notify.exist(topic_1).await.unwrap());
+ drop(_handle_1_bis);
+ drop(_handle_1_other);
+
+ notify.try_delete(topic_1).unwrap();
+
+ let subscriptions_nb = notify.debug().await.unwrap();
+ assert_eq!(subscriptions_nb, 1);
+
+ assert!(!notify.exist(topic_1).await.unwrap());
+
+ notify.force_delete(topic_1).await.unwrap();
+
+ let mut handle1 = handle1.into_stream();
+ let new_msg = handle1.next().await.unwrap();
+ assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
+ assert!(handle1.next().await.is_none());
+ assert!(notify.exist(topic_2).await.unwrap());
+ notify.try_delete(topic_2).unwrap();
+
+ let subscriptions_nb = notify.debug().await.unwrap();
+ assert_eq!(subscriptions_nb, 0);
+ }
+
+ #[tokio::test]
+ async fn it_test_ttl() {
+ let mut notify = Notify::builder()
+ .ttl(Duration::from_millis(100))
+ .heartbeat_error_message(serde_json_bytes::json!({"error": "connection_closed"}))
+ .build();
+ let topic_1 = Uuid::new_v4();
+ let topic_2 = Uuid::new_v4();
+
+ let (handle1, created) = notify.create_or_subscribe(topic_1, true).await.unwrap();
+ assert!(created);
+ let (_handle2, created) = notify.create_or_subscribe(topic_2, true).await.unwrap();
+ assert!(created);
+
+ let handle_1_bis = notify.subscribe(topic_1).await.unwrap();
+ let handle_1_other = notify.subscribe(topic_1).await.unwrap();
+ let mut cloned_notify = notify.clone();
+ tokio::spawn(async move {
+ let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
+ handle
+ .send_sync(serde_json_bytes::json!({"test": "ok"}))
+ .unwrap();
+ });
+ drop(handle1);
+
+ let mut handle_1_bis = handle_1_bis.into_stream();
+ let new_msg = handle_1_bis.next().await.unwrap();
+ assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
+ let mut handle_1_other = handle_1_other.into_stream();
+ let new_msg = handle_1_other.next().await.unwrap();
+ assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
+
+ tokio::time::sleep(Duration::from_millis(200)).await;
+ let res = handle_1_bis.next().await.unwrap();
+ assert_eq!(res, serde_json_bytes::json!({"error": "connection_closed"}));
+
+ assert!(handle_1_bis.next().await.is_none());
+
+ assert!(!notify.exist(topic_1).await.unwrap());
+ assert!(!notify.exist(topic_2).await.unwrap());
+
+ let subscriptions_nb = notify.debug().await.unwrap();
+ assert_eq!(subscriptions_nb, 0);
+ }
+}
diff --git a/apollo-router/src/plugin/mod.rs b/apollo-router/src/plugin/mod.rs
index ad74aa4970..fcbd0d8c05 100644
--- a/apollo-router/src/plugin/mod.rs
+++ b/apollo-router/src/plugin/mod.rs
@@ -37,7 +37,9 @@ use tower::BoxError;
use tower::Service;
use tower::ServiceBuilder;
+use crate::graphql;
use crate::layers::ServiceBuilderExt;
+use crate::notification::Notify;
use crate::router_factory::Endpoint;
use crate::services::execution;
use crate::services::router;
@@ -45,8 +47,11 @@ use crate::services::subgraph;
use crate::services::supergraph;
use crate::ListenAddr;
-type InstanceFactory =
-    fn(&serde_json::Value, Arc<String>) -> BoxFuture<'static, Result<Box<dyn DynPlugin>, BoxError>>;
+type InstanceFactory = fn(
+    &serde_json::Value,
+    Arc<String>,
+    Notify<String, graphql::Response>,
+) -> BoxFuture<'static, Result<Box<dyn DynPlugin>, BoxError>>;
type SchemaFactory = fn(&mut SchemaGenerator) -> schemars::schema::Schema;
@@ -61,34 +66,103 @@ pub struct PluginInit {
pub config: T,
/// Router Supergraph Schema (schema definition language)
pub supergraph_sdl: Arc<String>,
+
+    pub(crate) notify: Notify<String, graphql::Response>,
}
impl<T> PluginInit<T>
where
T: for<'de> Deserialize<'de>,
{
+ #[deprecated = "use PluginInit::builder() instead"]
/// Create a new PluginInit for the supplied config and SDL.
pub fn new(config: T, supergraph_sdl: Arc<String>) -> Self {
- PluginInit {
- config,
- supergraph_sdl,
- }
+ Self::builder()
+ .config(config)
+ .supergraph_sdl(supergraph_sdl)
+ .notify(Notify::builder().build())
+ .build()
}
/// Try to create a new PluginInit for the supplied JSON and SDL.
///
/// This will fail if the supplied JSON cannot be deserialized into the configuration
/// struct.
+ #[deprecated = "use PluginInit::try_builder() instead"]
pub fn try_new(
config: serde_json::Value,
supergraph_sdl: Arc<String>,
+    ) -> Result<Self, BoxError> {
+ Self::try_builder()
+ .config(config)
+ .supergraph_sdl(supergraph_sdl)
+ .notify(Notify::builder().build())
+ .build()
+ }
+
+ #[cfg(test)]
+    pub(crate) fn fake_new(config: T, supergraph_sdl: Arc<String>) -> Self {
+ PluginInit {
+ config,
+ supergraph_sdl,
+ notify: Notify::for_tests(),
+ }
+ }
+}
+
+#[buildstructor::buildstructor]
+impl<T> PluginInit<T>
+where
+    T: for<'de> Deserialize<'de>,
+{
+ /// Create a new PluginInit builder
+ #[builder(entry = "builder", exit = "build", visibility = "pub")]
+ /// Build a new PluginInit for the supplied configuration and SDL.
+ ///
+ /// You can reuse a notify instance, or Build your own.
+    pub(crate) fn new_builder(
+        config: T,
+        supergraph_sdl: Arc<String>,
+        notify: Notify<String, graphql::Response>,
+    ) -> Self {
+ PluginInit {
+ config,
+ supergraph_sdl,
+ notify,
+ }
+ }
+
+ #[builder(entry = "try_builder", exit = "build", visibility = "pub")]
+ /// Try to build a new PluginInit for the supplied json configuration and SDL.
+ ///
+ /// You can reuse a notify instance, or Build your own.
+ /// invoking build() will fail if the JSON doesn't comply with the configuration format.
+    pub(crate) fn try_new_builder(
+        config: serde_json::Value,
+        supergraph_sdl: Arc<String>,
+        notify: Notify<String, graphql::Response>,
) -> Result<Self, BoxError> {
let config: T = serde_json::from_value(config)?;
Ok(PluginInit {
config,
supergraph_sdl,
+ notify,
})
}
+
+ /// Create a new PluginInit builder
+ #[builder(entry = "fake_builder", exit = "build", visibility = "pub")]
+    fn fake_new_builder(
+        config: T,
+        supergraph_sdl: Option<Arc<String>>,
+        notify: Option<Notify<String, graphql::Response>>,
+ ) -> Self {
+ PluginInit {
+ config,
+ supergraph_sdl: supergraph_sdl.unwrap_or_default(),
+ notify: notify.unwrap_or_else(Notify::for_tests),
+ }
+ }
}
/// Factories for plugin schema and configuration.
@@ -120,9 +194,13 @@ impl PluginFactory {
tracing::debug!(%plugin_factory_name, "creating plugin factory");
PluginFactory {
name: plugin_factory_name,
- instance_factory: |configuration, schema| {
+ instance_factory: |configuration, schema, notify| {
Box::pin(async move {
- let init = PluginInit::try_new(configuration.clone(), schema)?;
+ let init = PluginInit::try_builder()
+ .config(configuration.clone())
+ .supergraph_sdl(schema)
+ .notify(notify)
+ .build()?;
let plugin = P::new(init).await?;
Ok(Box::new(plugin) as Box)
})
@@ -136,8 +214,9 @@ impl PluginFactory {
&self,
configuration: &serde_json::Value,
supergraph_sdl: Arc,
+        notify: Notify<String, graphql::Response>,
) -> Result<Box<dyn DynPlugin>, BoxError> {
- (self.instance_factory)(configuration, supergraph_sdl).await
+ (self.instance_factory)(configuration, supergraph_sdl, notify).await
}
#[cfg(test)]
@@ -145,7 +224,7 @@ impl PluginFactory {
&self,
configuration: &serde_json::Value,
) -> Result<Box<dyn DynPlugin>, BoxError> {
- (self.instance_factory)(configuration, Default::default()).await
+ (self.instance_factory)(configuration, Default::default(), Default::default()).await
}
pub(crate) fn create_schema(&self, gen: &mut SchemaGenerator) -> schemars::schema::Schema {
diff --git a/apollo-router/src/plugin/test/mock/canned.rs b/apollo-router/src/plugin/test/mock/canned.rs
index 5230ed3cab..227a0ba806 100644
--- a/apollo-router/src/plugin/test/mock/canned.rs
+++ b/apollo-router/src/plugin/test/mock/canned.rs
@@ -36,6 +36,12 @@ pub(crate) fn accounts_subgraph() -> MockSubgraph {
]
}
}}
+ ),
+ (
+ json! {{
+ "query": "subscription{userWasCreated{name}}",
+ }},
+ json! {{}}
)
].into_iter().map(|(query, response)| (serde_json::from_value(query).unwrap(), serde_json::from_value(response).unwrap())).collect();
MockSubgraph::new(account_mocks)
diff --git a/apollo-router/src/plugin/test/mock/subgraph.rs b/apollo-router/src/plugin/test/mock/subgraph.rs
index 20362dbe0a..5d1a6da714 100644
--- a/apollo-router/src/plugin/test/mock/subgraph.rs
+++ b/apollo-router/src/plugin/test/mock/subgraph.rs
@@ -11,9 +11,11 @@ use http::StatusCode;
use tower::BoxError;
use tower::Service;
+use crate::graphql;
use crate::graphql::Request;
use crate::graphql::Response;
use crate::json_ext::Object;
+use crate::notification::Handle;
use crate::services::SubgraphRequest;
use crate::services::SubgraphResponse;
@@ -24,6 +26,7 @@ pub struct MockSubgraph {
// using an arc to improve efficiency when service is cloned
mocks: Arc,
extensions: Option,
+    subscription_stream: Option<Handle<String, graphql::Response>>,
}
impl MockSubgraph {
@@ -31,6 +34,7 @@ impl MockSubgraph {
Self {
mocks: Arc::new(mocks),
extensions: None,
+ subscription_stream: None,
}
}
@@ -42,6 +46,14 @@ impl MockSubgraph {
self.extensions = Some(extensions);
self
}
+
+    pub fn with_subscription_stream(
+        mut self,
+        subscription_stream: Handle<String, graphql::Response>,
+    ) -> Self {
+ self.subscription_stream = Some(subscription_stream);
+ self
+ }
}
/// Builder for `MockSubgraph`
@@ -49,6 +61,7 @@ impl MockSubgraph {
pub struct MockSubgraphBuilder {
mocks: MockResponses,
extensions: Option,
+    subscription_stream: Option<Handle<String, graphql::Response>>,
}
impl MockSubgraphBuilder {
pub fn with_extensions(mut self, extensions: Object) -> Self {
@@ -68,10 +81,19 @@ impl MockSubgraphBuilder {
self
}
+    pub fn with_subscription_stream(
+        mut self,
+        subscription_stream: Handle<String, graphql::Response>,
+    ) -> Self {
+ self.subscription_stream = Some(subscription_stream);
+ self
+ }
+
pub fn build(self) -> MockSubgraph {
MockSubgraph {
mocks: Arc::new(self.mocks),
extensions: self.extensions,
+ subscription_stream: self.subscription_stream,
}
}
}
@@ -87,8 +109,43 @@ impl Service for MockSubgraph {
Poll::Ready(Ok(()))
}
- fn call(&mut self, req: SubgraphRequest) -> Self::Future {
- let response = if let Some(response) = self.mocks.get(req.subgraph_request.body()) {
+ fn call(&mut self, mut req: SubgraphRequest) -> Self::Future {
+ let body = req.subgraph_request.body_mut();
+
+ if let Some(sub_stream) = &mut req.subscription_stream {
+ sub_stream
+ .try_send(
+ self.subscription_stream
+ .take()
+ .expect("must have a subscription stream set")
+ .into_stream(),
+ )
+ .unwrap();
+ }
+
+ // Redact the callback url and subscription_id because it generates a subscription uuid
+ if let Some(serde_json_bytes::Value::Object(subscription_ext)) =
+ body.extensions.get_mut("subscription")
+ {
+ if let Some(callback_url) = subscription_ext.get_mut("callback_url") {
+ let mut cb_url = url::Url::parse(
+ callback_url
+ .as_str()
+ .expect("callback_url extension must be a string"),
+ )
+ .expect("callback_url must be a valid URL");
+ cb_url.path_segments_mut().unwrap().pop();
+ cb_url.path_segments_mut().unwrap().push("subscription_id");
+
+ *callback_url = serde_json_bytes::Value::String(cb_url.to_string().into());
+ }
+ if let Some(subscription_id) = subscription_ext.get_mut("subscription_id") {
+ *subscription_id =
+ serde_json_bytes::Value::String("subscription_id".to_string().into());
+ }
+ }
+
+ let response = if let Some(response) = self.mocks.get(body) {
// Build an http Response
let http_response = http::Response::builder()
.status(StatusCode::OK)
@@ -99,7 +156,7 @@ impl Service for MockSubgraph {
let error = crate::error::Error::builder()
.message(format!(
"couldn't find mock for query {}",
- serde_json::to_string(&req.subgraph_request.body()).unwrap()
+ serde_json::to_string(body).unwrap()
))
.extension_code("FETCH_ERROR".to_string())
.extensions(self.extensions.clone().unwrap_or_default())
diff --git a/apollo-router/src/plugins/csrf.rs b/apollo-router/src/plugins/csrf.rs
index 807900ffa1..98a2fe1181 100644
--- a/apollo-router/src/plugins/csrf.rs
+++ b/apollo-router/src/plugins/csrf.rs
@@ -299,7 +299,7 @@ mod csrf_tests {
.unwrap())
});
- let service_stack = Csrf::new(PluginInit::new(config, Default::default()))
+ let service_stack = Csrf::new(PluginInit::fake_new(config, Default::default()))
.await
.unwrap()
.supergraph_service(mock_service.boxed());
@@ -316,7 +316,7 @@ mod csrf_tests {
}
async fn assert_rejected(config: CSRFConfig, request: supergraph::Request) {
- let service_stack = Csrf::new(PluginInit::new(config, Default::default()))
+ let service_stack = Csrf::new(PluginInit::fake_new(config, Default::default()))
.await
.unwrap()
.supergraph_service(MockSupergraphService::new().boxed());
diff --git a/apollo-router/src/plugins/forbid_mutations.rs b/apollo-router/src/plugins/forbid_mutations.rs
index 5298d7b90a..136b8de3ba 100644
--- a/apollo-router/src/plugins/forbid_mutations.rs
+++ b/apollo-router/src/plugins/forbid_mutations.rs
@@ -93,7 +93,7 @@ mod forbid_http_get_mutations_tests {
.times(1)
.returning(move |_| Ok(ExecutionResponse::fake_builder().build().unwrap()));
- let service_stack = ForbidMutations::new(PluginInit::new(
+ let service_stack = ForbidMutations::new(PluginInit::fake_new(
ForbidMutationsConfig(true),
Default::default(),
))
@@ -120,7 +120,7 @@ mod forbid_http_get_mutations_tests {
.build();
let expected_status = StatusCode::BAD_REQUEST;
- let service_stack = ForbidMutations::new(PluginInit::new(
+ let service_stack = ForbidMutations::new(PluginInit::fake_new(
ForbidMutationsConfig(true),
Default::default(),
))
@@ -144,7 +144,7 @@ mod forbid_http_get_mutations_tests {
.times(1)
.returning(move |_| Ok(ExecutionResponse::fake_builder().build().unwrap()));
- let service_stack = ForbidMutations::new(PluginInit::new(
+ let service_stack = ForbidMutations::new(PluginInit::fake_new(
ForbidMutationsConfig(false),
Default::default(),
))
diff --git a/apollo-router/src/plugins/headers.rs b/apollo-router/src/plugins/headers.rs
index d5b3fbb78b..7a6e56d4ac 100644
--- a/apollo-router/src/plugins/headers.rs
+++ b/apollo-router/src/plugins/headers.rs
@@ -762,6 +762,8 @@ mod test {
.expect("expecting valid request"),
operation_kind: OperationKind::Query,
context: ctx,
+ subscription_stream: None,
+ connection_closed_signal: None,
}
}
diff --git a/apollo-router/src/plugins/mod.rs b/apollo-router/src/plugins/mod.rs
index 83ee925887..2bd2270df1 100644
--- a/apollo-router/src/plugins/mod.rs
+++ b/apollo-router/src/plugins/mod.rs
@@ -20,7 +20,7 @@ macro_rules! schemar_fn {
};
}
-mod authentication;
+pub(crate) mod authentication;
mod authorization;
mod coprocessor;
#[cfg(test)]
@@ -32,5 +32,6 @@ mod headers;
mod include_subgraph_errors;
pub(crate) mod override_url;
pub(crate) mod rhai;
+pub(crate) mod subscription;
pub(crate) mod telemetry;
pub(crate) mod traffic_shaping;
diff --git a/apollo-router/src/plugins/override_url.rs b/apollo-router/src/plugins/override_url.rs
index 6e2894ced1..0c3b9a35b3 100644
--- a/apollo-router/src/plugins/override_url.rs
+++ b/apollo-router/src/plugins/override_url.rs
@@ -107,6 +107,7 @@ mod tests {
)
.unwrap(),
Default::default(),
+ Default::default(),
)
.await
.unwrap();
diff --git a/apollo-router/src/plugins/rhai/engine.rs b/apollo-router/src/plugins/rhai/engine.rs
index 000398efdc..737be696b8 100644
--- a/apollo-router/src/plugins/rhai/engine.rs
+++ b/apollo-router/src/plugins/rhai/engine.rs
@@ -35,6 +35,7 @@ use crate::graphql::Request;
use crate::graphql::Response;
use crate::http_ext;
use crate::plugins::authentication::APOLLO_AUTHENTICATION_JWT_CLAIMS;
+use crate::plugins::subscription::SUBSCRIPTION_WS_CUSTOM_CONNECTION_PARAMS;
use crate::Context;
const CANNOT_ACCESS_HEADERS_ON_A_DEFERRED_RESPONSE: &str =
@@ -1145,6 +1146,10 @@ impl Rhai {
"APOLLO_AUTHENTICATION_JWT_CLAIMS".into(),
APOLLO_AUTHENTICATION_JWT_CLAIMS.to_string().into(),
);
+ global_variables.insert(
+ "APOLLO_SUBSCRIPTION_WS_CUSTOM_CONNECTION_PARAMS".into(),
+ SUBSCRIPTION_WS_CUSTOM_CONNECTION_PARAMS.to_string().into(),
+ );
let shared_globals = Arc::new(global_variables);
diff --git a/apollo-router/src/plugins/subscription.rs b/apollo-router/src/plugins/subscription.rs
new file mode 100644
index 0000000000..ad44089614
--- /dev/null
+++ b/apollo-router/src/plugins/subscription.rs
@@ -0,0 +1,1240 @@
+use std::collections::HashMap;
+use std::collections::HashSet;
+use std::ops::ControlFlow;
+use std::task::Poll;
+
+use bytes::Buf;
+use futures::future::BoxFuture;
+use hmac::Hmac;
+use hmac::Mac;
+use http::Method;
+use http::StatusCode;
+use multimap::MultiMap;
+use once_cell::sync::OnceCell;
+use schemars::JsonSchema;
+use serde::Deserialize;
+use serde::Serialize;
+use sha2::Digest;
+use sha2::Sha256;
+use tower::BoxError;
+use tower::Service;
+use tower::ServiceBuilder;
+use tower::ServiceExt;
+use tracing_futures::Instrument;
+use uuid::Uuid;
+
+use crate::context::Context;
+use crate::graphql;
+use crate::graphql::Response;
+use crate::json_ext::Object;
+use crate::layers::ServiceBuilderExt;
+use crate::notification::Notify;
+use crate::plugin::Plugin;
+use crate::plugin::PluginInit;
+use crate::protocols::websocket::WebSocketProtocol;
+use crate::query_planner::OperationKind;
+use crate::register_plugin;
+use crate::services::router;
+use crate::services::subgraph;
+use crate::Endpoint;
+use crate::ListenAddr;
+
+type HmacSha256 = Hmac;
+pub(crate) const APOLLO_SUBSCRIPTION_PLUGIN: &str = "apollo.subscription";
+#[cfg(not(test))]
+pub(crate) const APOLLO_SUBSCRIPTION_PLUGIN_NAME: &str = "subscription";
+pub(crate) static SUBSCRIPTION_CALLBACK_HMAC_KEY: OnceCell = OnceCell::new();
+pub(crate) const SUBSCRIPTION_WS_CUSTOM_CONNECTION_PARAMS: &str =
+ "apollo.subscription.custom_connection_params";
+
+#[derive(Debug, Clone)]
+pub(crate) struct Subscription {
+ notify: Notify,
+ callback_hmac_key: Option,
+ pub(crate) config: SubscriptionConfig,
+}
+
+/// Subscriptions configuration
+#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
+#[serde(deny_unknown_fields, default)]
+pub(crate) struct SubscriptionConfig {
+ /// Select a subscription mode (callback or passthrough)
+ pub(crate) mode: SubscriptionModeConfig,
+ /// Enable subscription deduplication (for example, if we detect the exact same request to a subgraph, we won't open a new websocket to the subgraph in passthrough mode)
+ /// (default: true)
+ pub(crate) enable_deduplication: bool,
+ /// Limits the maximum number of subscriptions that can be open at the same time. If it's not set, there is no limit.
+ pub(crate) max_opened_subscriptions: Option,
+ /// It represents the capacity of the in-memory queue, i.e. how many events we can keep in a buffer
+ pub(crate) queue_capacity: Option,
+}
+
+impl Default for SubscriptionConfig {
+ fn default() -> Self {
+ Self {
+ mode: Default::default(),
+ enable_deduplication: true,
+ max_opened_subscriptions: None,
+ queue_capacity: None,
+ }
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default, JsonSchema)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct SubscriptionModeConfig {
+ #[serde(rename = "preview_callback")]
+ /// Enable callback mode for subgraph(s)
+ pub(crate) callback: Option,
+ /// Enable passthrough mode for subgraph(s)
+ pub(crate) passthrough: Option,
+}
+
+impl SubscriptionModeConfig {
+ pub(crate) fn get_subgraph_config(&self, service_name: &str) -> Option {
+ if let Some(passthrough_cfg) = &self.passthrough {
+ if let Some(subgraph_cfg) = passthrough_cfg.subgraphs.get(service_name) {
+ return SubscriptionMode::Passthrough(subgraph_cfg.clone()).into();
+ }
+ if let Some(all_cfg) = &passthrough_cfg.all {
+ return SubscriptionMode::Passthrough(all_cfg.clone()).into();
+ }
+ }
+
+ if let Some(callback_cfg) = &self.callback {
+ if callback_cfg.subgraphs.contains(service_name) || callback_cfg.subgraphs.is_empty() {
+ let callback_cfg = CallbackMode {
+ public_url: callback_cfg.public_url.clone(),
+ listen: callback_cfg.listen.clone(),
+ path: callback_cfg.path.clone(),
+ subgraphs: HashSet::new(), // We don't need it
+ };
+ return SubscriptionMode::Callback(callback_cfg).into();
+ }
+ }
+
+ None
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default, JsonSchema)]
+#[serde(deny_unknown_fields, default)]
+pub(crate) struct SubgraphPassthroughMode {
+ /// Configuration for all subgraphs
+ pub(crate) all: Option,
+ /// Configuration for specific subgraphs
+ pub(crate) subgraphs: HashMap,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub(crate) enum SubscriptionMode {
+ /// Using a callback url
+ Callback(CallbackMode),
+ /// Using websocket to directly connect to subgraph
+ Passthrough(WebSocketConfiguration),
+}
+
+/// Using a callback url
+#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct CallbackMode {
+ #[schemars(with = "String")]
+ /// URL used to access this router instance
+ pub(crate) public_url: url::Url,
+ // `skip_serializing` We don't need it in the context
+ /// Listen address on which the callback must listen (default: 127.0.0.1:4000)
+ #[serde(skip_serializing)]
+ listen: Option,
+ // `skip_serializing` We don't need it in the context
+ /// Specify on which path you want to listen for callbacks (default: /callback)
+ #[serde(skip_serializing)]
+ path: Option,
+
+ /// Specify on which subgraph we enable the callback mode for subscription
+ /// If empty it applies to all subgraphs (passthrough mode takes precedence)
+ #[serde(default)]
+ subgraphs: HashSet,
+}
+
+/// Using websocket to directly connect to subgraph
+#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize, JsonSchema)]
+#[serde(deny_unknown_fields, default)]
+pub(crate) struct PassthroughMode {
+ /// WebSocket configuration for specific subgraphs
+ subgraph: SubgraphPassthroughMode,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize, JsonSchema)]
+#[serde(deny_unknown_fields, default)]
+/// WebSocket configuration for a specific subgraph
+pub(crate) struct WebSocketConfiguration {
+ /// Path on which WebSockets are listening
+ pub(crate) path: Option,
+ /// Which WebSocket GraphQL protocol to use for this subgraph possible values are: 'graphql_ws' | 'graphql_transport_ws' (default: graphql_ws)
+ pub(crate) protocol: WebSocketProtocol,
+}
+
+fn default_path() -> String {
+ String::from("/callback")
+}
+
+fn default_listen_addr() -> ListenAddr {
+ ListenAddr::SocketAddr("127.0.0.1:4000".parse().expect("valid ListenAddr"))
+}
+
+#[async_trait::async_trait]
+impl Plugin for Subscription {
+ type Config = SubscriptionConfig;
+
+ async fn new(init: PluginInit) -> Result {
+ let mut callback_hmac_key = None;
+ if init.config.mode.callback.is_some() {
+ callback_hmac_key = Some(
+ SUBSCRIPTION_CALLBACK_HMAC_KEY
+ .get_or_init(|| Uuid::new_v4().to_string())
+ .clone(),
+ );
+ }
+
+ Ok(Subscription {
+ notify: init.notify,
+ callback_hmac_key,
+ config: init.config,
+ })
+ }
+
+ fn subgraph_service(
+ &self,
+ _subgraph_name: &str,
+ service: subgraph::BoxService,
+ ) -> subgraph::BoxService {
+ let enabled = self.config.mode.callback.is_some() || self.config.mode.passthrough.is_some();
+ ServiceBuilder::new()
+ .checkpoint(move |req: subgraph::Request| {
+ if req.operation_kind == OperationKind::Subscription && !enabled {
+ Ok(ControlFlow::Break(subgraph::Response::builder().context(req.context).error(graphql::Error::builder().message("cannot execute a subscription if it's not enabled in the configuration").extension_code("SUBSCRIPTION_DISABLED").build()).extensions(Object::default()).build()))
+ } else {
+ Ok(ControlFlow::Continue(req))
+ }
+ }).service(service)
+ .boxed()
+ }
+
+ fn web_endpoints(&self) -> MultiMap {
+ let mut map = MultiMap::new();
+
+ if let Some(CallbackMode { listen, path, .. }) = &self.config.mode.callback {
+ let path = path.clone().unwrap_or_else(default_path);
+ let path = path.trim_end_matches('/');
+ let callback_hmac_key = self
+ .callback_hmac_key
+ .clone()
+ .expect("cannot run subscription in callback mode without a hmac key");
+ let endpoint = Endpoint::from_router_service(
+ format!("{path}/:callback"),
+ CallbackService::new(self.notify.clone(), path.to_string(), callback_hmac_key)
+ .boxed(),
+ );
+ map.insert(listen.clone().unwrap_or_else(default_listen_addr), endpoint);
+ }
+
+ map
+ }
+}
+
+#[derive(Serialize, Deserialize, Clone, Debug)]
+#[serde(tag = "kind", rename = "lowercase")]
+pub(crate) enum CallbackPayload {
+ #[serde(rename = "subscription")]
+ Subscription(SubscriptionPayload),
+}
+
+impl CallbackPayload {
+ fn id(&self) -> &String {
+ match self {
+ CallbackPayload::Subscription(subscription_payload) => subscription_payload.id(),
+ }
+ }
+
+ fn verifier(&self) -> &String {
+ match self {
+ CallbackPayload::Subscription(subscription_payload) => subscription_payload.verifier(),
+ }
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize, JsonSchema)]
+#[serde(deny_unknown_fields, default)]
+/// Callback payload when a subscription id is incorrect
+pub(crate) struct InvalidIdsPayload {
+ /// List of invalid ids
+ pub(crate) invalid_ids: Vec,
+ pub(crate) id: String,
+ pub(crate) verifier: String,
+}
+
+#[derive(Serialize, Deserialize, Clone, Debug)]
+#[serde(tag = "action", rename = "lowercase")]
+pub(crate) enum SubscriptionPayload {
+ #[serde(rename = "check")]
+ Check { id: String, verifier: String },
+ #[serde(rename = "heartbeat")]
+ Heartbeat {
+ /// Id sent with the corresponding verifier
+ id: String,
+ /// List of ids to heartbeat
+ ids: Vec,
+ /// Verifier received with the corresponding id
+ verifier: String,
+ },
+ #[serde(rename = "next")]
+ Next {
+ id: String,
+ payload: Response,
+ verifier: String,
+ },
+ #[serde(rename = "complete")]
+ Complete {
+ id: String,
+ verifier: String,
+ errors: Option>,
+ },
+}
+
+impl SubscriptionPayload {
+ fn id(&self) -> &String {
+ match self {
+ SubscriptionPayload::Check { id, .. }
+ | SubscriptionPayload::Heartbeat { id, .. }
+ | SubscriptionPayload::Next { id, .. }
+ | SubscriptionPayload::Complete { id, .. } => id,
+ }
+ }
+
+ fn verifier(&self) -> &String {
+ match self {
+ SubscriptionPayload::Check { verifier, .. }
+ | SubscriptionPayload::Heartbeat { verifier, .. }
+ | SubscriptionPayload::Next { verifier, .. }
+ | SubscriptionPayload::Complete { verifier, .. } => verifier,
+ }
+ }
+}
+
+#[derive(Clone)]
+pub(crate) struct CallbackService {
+ notify: Notify,
+ path: String,
+ callback_hmac_key: String,
+}
+
+impl CallbackService {
+ pub(crate) fn new(
+ notify: Notify,
+ path: String,
+ callback_hmac_key: String,
+ ) -> Self {
+ Self {
+ notify,
+ path,
+ callback_hmac_key,
+ }
+ }
+}
+
+impl Service for CallbackService {
+ type Response = router::Response;
+ type Error = BoxError;
+ type Future = BoxFuture<'static, Result>;
+
+ fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll> {
+ Ok(()).into()
+ }
+
+ fn call(&mut self, req: router::Request) -> Self::Future {
+ let mut notify = self.notify.clone();
+ let path = self.path.clone();
+ let callback_hmac_key = self.callback_hmac_key.clone();
+ Box::pin(
+ async move {
+ let (parts, body) = req.router_request.into_parts();
+ let sub_id = parts
+ .uri
+ .path()
+ .trim_start_matches(&format!("{path}/"))
+ .to_string();
+
+ match parts.method {
+ Method::POST => {
+ let cb_body = hyper::body::to_bytes(body)
+ .await
+ .map_err(|e| format!("failed to get the request body: {e}"))
+ .and_then(|bytes| {
+ serde_json::from_reader::<_, CallbackPayload>(bytes.reader())
+ .map_err(|err| {
+ format!(
+ "failed to deserialize the request body into JSON: {err}"
+ )
+ })
+ });
+ let cb_body = match cb_body {
+ Ok(cb_body) => cb_body,
+ Err(err) => {
+ return Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::BAD_REQUEST)
+ .body(err.into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ });
+ }
+ };
+ let id = cb_body.id().clone();
+
+ // Hash verifier to sha256 to mitigate timing attack
+ // Check verifier
+ let verifier = cb_body.verifier();
+ let mut verifier_hasher = Sha256::new();
+ verifier_hasher.update(verifier.as_bytes());
+ let hashed_verifier = verifier_hasher.finalize();
+
+ let mut mac = HmacSha256::new_from_slice(callback_hmac_key.as_bytes())?;
+ mac.update(id.as_bytes());
+ let result = mac.finalize();
+ let expected_verifier = hex::encode(result.into_bytes());
+ let mut verifier_hasher = Sha256::new();
+ verifier_hasher.update(expected_verifier.as_bytes());
+ let expected_hashed_verifier = verifier_hasher.finalize();
+
+ if hashed_verifier != expected_hashed_verifier {
+ return Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::UNAUTHORIZED)
+ .body("verifier doesn't match".into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ });
+ }
+
+ if let Err(res) = ensure_id_consistency(&req.context, &sub_id, &id) {
+ return Ok(res);
+ }
+
+ match cb_body {
+ CallbackPayload::Subscription(SubscriptionPayload::Next {
+ mut payload,
+ ..
+ }) => {
+ let mut handle = match notify.subscribe_if_exist(id).await? {
+ Some(handle) => handle.into_sink(),
+ None => {
+ return Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::NOT_FOUND)
+ .body("suscription doesn't exist".into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ });
+ }
+ };
+ // Keep the subscription to the client open
+ payload.subscribed = Some(true);
+ handle.send_sync(payload)?;
+
+ Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::OK)
+ .body::("".into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ })
+ }
+ CallbackPayload::Subscription(SubscriptionPayload::Check {
+ ..
+ }) => {
+ if notify.exist(id).await? {
+ Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::NO_CONTENT)
+ .body::("".into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ })
+ } else {
+ Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::NOT_FOUND)
+ .body("suscription doesn't exist".into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ })
+ }
+ }
+ CallbackPayload::Subscription(SubscriptionPayload::Heartbeat {
+ ids,
+ id,
+ verifier,
+ }) => {
+ if !ids.contains(&id) {
+ return Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::UNAUTHORIZED)
+ .body("id used for the verifier is not part of ids array".into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ });
+ }
+
+ let (mut valid_ids, invalid_ids) = notify.invalid_ids(ids).await?;
+ if invalid_ids.is_empty() {
+ Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::NO_CONTENT)
+ .body::("".into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ })
+ } else if valid_ids.is_empty() {
+ Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::NOT_FOUND)
+ .body("suscriptions don't exist".into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ })
+ } else {
+ let (id, verifier) = if invalid_ids.contains(&id) {
+ (id, verifier)
+ } else {
+ let new_id = valid_ids.pop().expect("valid_ids is not empty, checked in the previous if block");
+ // Generate new verifier
+ let mut mac = HmacSha256::new_from_slice(
+ callback_hmac_key.as_bytes(),
+ )?;
+ mac.update(new_id.as_bytes());
+ let result = mac.finalize();
+ let verifier = hex::encode(result.into_bytes());
+
+ (new_id, verifier)
+ };
+ Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::NOT_FOUND)
+ .body(serde_json::to_string_pretty(&InvalidIdsPayload{
+ invalid_ids,
+ id,
+ verifier,
+ })?.into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ })
+ }
+ }
+ CallbackPayload::Subscription(SubscriptionPayload::Complete {
+ errors,
+ ..
+ }) => {
+ if let Some(errors) = errors {
+ let mut handle =
+ notify.subscribe(id.clone()).await?.into_sink();
+ handle.send_sync(
+ graphql::Response::builder().errors(errors).build(),
+ )?;
+ }
+ notify.force_delete(id).await?;
+ Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::ACCEPTED)
+ .body::("".into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ })
+ }
+ }
+ }
+ _ => Ok(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::METHOD_NOT_ALLOWED)
+ .body::("".into())
+ .map_err(BoxError::from)?,
+ context: req.context,
+ }),
+ }
+ }
+ .instrument(tracing::info_span!("subscription_callback")),
+ )
+ }
+}
+
+pub(crate) fn create_verifier(sub_id: &str) -> Result {
+ let callback_hmac_key = SUBSCRIPTION_CALLBACK_HMAC_KEY
+ .get()
+ .ok_or("subscription callback hmac key is not available")?;
+ let mut mac = HmacSha256::new_from_slice(callback_hmac_key.as_bytes())?;
+ mac.update(sub_id.as_bytes());
+ let result = mac.finalize();
+ let verifier = hex::encode(result.into_bytes());
+
+ Ok(verifier)
+}
+
+fn ensure_id_consistency(
+ context: &Context,
+ id_from_path: &str,
+ id_from_body: &str,
+) -> Result<(), router::Response> {
+ (id_from_path != id_from_body)
+ .then(|| {
+ Err(router::Response {
+ response: http::Response::builder()
+ .status(StatusCode::BAD_REQUEST)
+ .body::("id from url path and id from body are different".into())
+ .expect("this body is valid"),
+ context: context.clone(),
+ })
+ })
+ .unwrap_or_else(|| Ok(()))
+}
+
+#[cfg(test)]
+mod tests {
+ use std::str::FromStr;
+
+ use futures::StreamExt;
+ use serde_json::Value;
+ use tower::util::BoxService;
+ use tower::Service;
+ use tower::ServiceExt;
+
+ use super::*;
+ use crate::graphql::Request;
+ use crate::http_ext;
+ use crate::plugin::test::MockSubgraphService;
+ use crate::plugin::DynPlugin;
+ use crate::services::SubgraphRequest;
+ use crate::services::SubgraphResponse;
+ use crate::Notify;
+
+ #[tokio::test(flavor = "multi_thread")]
+ async fn it_test_callback_endpoint() {
+ let mut notify = Notify::builder().build();
+ let dyn_plugin: Box = crate::plugin::plugins()
+ .find(|factory| factory.name == APOLLO_SUBSCRIPTION_PLUGIN)
+ .expect("Plugin not found")
+ .create_instance(
+ &Value::from_str(
+ r#"{
+ "mode": {
+ "preview_callback": {
+ "public_url": "http://localhost:4000",
+ "path": "/subscription/callback",
+ "subgraphs": ["test"]
+ }
+ }
+ }"#,
+ )
+ .unwrap(),
+ Default::default(),
+ notify.clone(),
+ )
+ .await
+ .unwrap();
+
+ let http_req_prom = http::Request::get("http://localhost:4000/subscription/callback")
+ .body(Default::default())
+ .unwrap();
+ let mut web_endpoint = dyn_plugin
+ .web_endpoints()
+ .into_iter()
+ .next()
+ .unwrap()
+ .1
+ .into_iter()
+ .next()
+ .unwrap()
+ .into_router();
+ let resp = web_endpoint
+ .ready()
+ .await
+ .unwrap()
+ .call(http_req_prom)
+ .await
+ .unwrap();
+ assert_eq!(resp.status(), http::StatusCode::NOT_FOUND);
+ let new_sub_id = uuid::Uuid::new_v4().to_string();
+ let (handler, _created) = notify
+ .create_or_subscribe(new_sub_id.clone(), true)
+ .await
+ .unwrap();
+ let verifier = create_verifier(&new_sub_id).unwrap();
+ let http_req = http::Request::post(format!(
+ "http://localhost:4000/subscription/callback/{new_sub_id}"
+ ))
+ .body(hyper::Body::from(
+ serde_json::to_vec(&CallbackPayload::Subscription(SubscriptionPayload::Check {
+ id: new_sub_id.clone(),
+ verifier: verifier.clone(),
+ }))
+ .unwrap(),
+ ))
+ .unwrap();
+ let resp = web_endpoint.clone().oneshot(http_req).await.unwrap();
+ assert_eq!(resp.status(), http::StatusCode::NO_CONTENT);
+
+ let http_req = http::Request::post(format!(
+ "http://localhost:4000/subscription/callback/{new_sub_id}"
+ ))
+ .body(hyper::Body::from(
+ serde_json::to_vec(&CallbackPayload::Subscription(SubscriptionPayload::Next {
+ id: new_sub_id.clone(),
+ payload: graphql::Response::builder()
+ .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
+ .build(),
+ verifier: verifier.clone(),
+ }))
+ .unwrap(),
+ ))
+ .unwrap();
+ let resp = web_endpoint.clone().oneshot(http_req).await.unwrap();
+ assert_eq!(resp.status(), http::StatusCode::OK);
+ let mut handler = handler.into_stream();
+ let msg = handler.next().await.unwrap();
+
+ assert_eq!(
+ msg,
+ graphql::Response::builder()
+ .subscribed(true)
+ .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
+ .build()
+ );
+ drop(handler);
+
+ // Should answer NOT FOUND because I dropped the only existing handler and so no one is still listening to the sub
+ let http_req = http::Request::post(format!(
+ "http://localhost:4000/subscription/callback/{new_sub_id}"
+ ))
+ .body(hyper::Body::from(
+ serde_json::to_vec(&CallbackPayload::Subscription(SubscriptionPayload::Next {
+ id: new_sub_id.clone(),
+ payload: graphql::Response::builder()
+ .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
+ .build(),
+ verifier: verifier.clone(),
+ }))
+ .unwrap(),
+ ))
+ .unwrap();
+ let resp = web_endpoint.clone().oneshot(http_req).await.unwrap();
+ assert_eq!(resp.status(), http::StatusCode::NOT_FOUND);
+
+ // Should answer NOT FOUND because I dropped the only existing handler and so no one is still listening to the sub
+ let http_req = http::Request::post(format!(
+ "http://localhost:4000/subscription/callback/{new_sub_id}"
+ ))
+ .body(hyper::Body::from(
+ serde_json::to_vec(&CallbackPayload::Subscription(
+ SubscriptionPayload::Heartbeat {
+ id: new_sub_id.clone(),
+ ids: vec![new_sub_id, "FAKE_SUB_ID".to_string()],
+ verifier: verifier.clone(),
+ },
+ ))
+ .unwrap(),
+ ))
+ .unwrap();
+ let resp = web_endpoint.oneshot(http_req).await.unwrap();
+ assert_eq!(resp.status(), http::StatusCode::NOT_FOUND);
+ }
+
+ #[tokio::test(flavor = "multi_thread")]
+ async fn it_test_callback_endpoint_with_bad_verifier() {
+ let mut notify = Notify::builder().build();
+ let dyn_plugin: Box = crate::plugin::plugins()
+ .find(|factory| factory.name == APOLLO_SUBSCRIPTION_PLUGIN)
+ .expect("Plugin not found")
+ .create_instance(
+ &Value::from_str(
+ r#"{
+ "mode": {
+ "preview_callback": {
+ "public_url": "http://localhost:4000",
+ "path": "/subscription/callback",
+ "subgraphs": ["test"]
+ }
+ }
+ }"#,
+ )
+ .unwrap(),
+ Default::default(),
+ notify.clone(),
+ )
+ .await
+ .unwrap();
+
+ let http_req_prom = http::Request::get("http://localhost:4000/subscription/callback")
+ .body(Default::default())
+ .unwrap();
+ let mut web_endpoint = dyn_plugin
+ .web_endpoints()
+ .into_iter()
+ .next()
+ .unwrap()
+ .1
+ .into_iter()
+ .next()
+ .unwrap()
+ .into_router();
+ let resp = web_endpoint
+ .ready()
+ .await
+ .unwrap()
+ .call(http_req_prom)
+ .await
+ .unwrap();
+ assert_eq!(resp.status(), http::StatusCode::NOT_FOUND);
+ let new_sub_id = uuid::Uuid::new_v4().to_string();
+ let (_handler, _created) = notify
+ .create_or_subscribe(new_sub_id.clone(), true)
+ .await
+ .unwrap();
+ let verifier = String::from("XXX");
+ let http_req = http::Request::post(format!(
+ "http://localhost:4000/subscription/callback/{new_sub_id}"
+ ))
+ .body(hyper::Body::from(
+ serde_json::to_vec(&CallbackPayload::Subscription(SubscriptionPayload::Check {
+ id: new_sub_id.clone(),
+ verifier: verifier.clone(),
+ }))
+ .unwrap(),
+ ))
+ .unwrap();
+ let resp = web_endpoint.clone().oneshot(http_req).await.unwrap();
+ assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);
+
+ let http_req = http::Request::post(format!(
+ "http://localhost:4000/subscription/callback/{new_sub_id}"
+ ))
+ .body(hyper::Body::from(
+ serde_json::to_vec(&CallbackPayload::Subscription(SubscriptionPayload::Next {
+ id: new_sub_id.clone(),
+ payload: graphql::Response::builder()
+ .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
+ .build(),
+ verifier: verifier.clone(),
+ }))
+ .unwrap(),
+ ))
+ .unwrap();
+ let resp = web_endpoint.clone().oneshot(http_req).await.unwrap();
+ assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);
+ }
+
+ #[tokio::test(flavor = "multi_thread")]
+ async fn it_test_callback_endpoint_with_complete_subscription() {
+ let mut notify = Notify::builder().build();
+ let dyn_plugin: Box = crate::plugin::plugins()
+ .find(|factory| factory.name == APOLLO_SUBSCRIPTION_PLUGIN)
+ .expect("Plugin not found")
+ .create_instance(
+ &Value::from_str(
+ r#"{
+ "mode": {
+ "preview_callback": {
+ "public_url": "http://localhost:4000",
+ "path": "/subscription/callback",
+ "subgraphs": ["test"]
+ }
+ }
+ }"#,
+ )
+ .unwrap(),
+ Default::default(),
+ notify.clone(),
+ )
+ .await
+ .unwrap();
+
+ let http_req_prom = http::Request::get("http://localhost:4000/subscription/callback")
+ .body(Default::default())
+ .unwrap();
+ let mut web_endpoint = dyn_plugin
+ .web_endpoints()
+ .into_iter()
+ .next()
+ .unwrap()
+ .1
+ .into_iter()
+ .next()
+ .unwrap()
+ .into_router();
+ let resp = web_endpoint
+ .ready()
+ .await
+ .unwrap()
+ .call(http_req_prom)
+ .await
+ .unwrap();
+ assert_eq!(resp.status(), http::StatusCode::NOT_FOUND);
+ let new_sub_id = uuid::Uuid::new_v4().to_string();
+ let (handler, _created) = notify
+ .create_or_subscribe(new_sub_id.clone(), true)
+ .await
+ .unwrap();
+ let verifier = create_verifier(&new_sub_id).unwrap();
+
+ let http_req = http::Request::post(format!(
+ "http://localhost:4000/subscription/callback/{new_sub_id}"
+ ))
+ .body(hyper::Body::from(
+ serde_json::to_vec(&CallbackPayload::Subscription(SubscriptionPayload::Check {
+ id: new_sub_id.clone(),
+ verifier: verifier.clone(),
+ }))
+ .unwrap(),
+ ))
+ .unwrap();
+ let resp = web_endpoint.clone().oneshot(http_req).await.unwrap();
+ assert_eq!(resp.status(), http::StatusCode::NO_CONTENT);
+
+ let http_req = http::Request::post(format!(
+ "http://localhost:4000/subscription/callback/{new_sub_id}"
+ ))
+ .body(hyper::Body::from(
+ serde_json::to_vec(&CallbackPayload::Subscription(SubscriptionPayload::Next {
+ id: new_sub_id.clone(),
+ payload: graphql::Response::builder()
+ .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
+ .build(),
+ verifier: verifier.clone(),
+ }))
+ .unwrap(),
+ ))
+ .unwrap();
+ let resp = web_endpoint.clone().oneshot(http_req).await.unwrap();
+ assert_eq!(resp.status(), http::StatusCode::OK);
+ let mut handler = handler.into_stream();
+ let msg = handler.next().await.unwrap();
+
+ assert_eq!(
+ msg,
+ graphql::Response::builder()
+ .subscribed(true)
+ .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
+ .build()
+ );
+
+ let http_req = http::Request::post(format!(
+ "http://localhost:4000/subscription/callback/{new_sub_id}"
+ ))
+ .body(hyper::Body::from(
+ serde_json::to_vec(&CallbackPayload::Subscription(
+ SubscriptionPayload::Complete {
+ id: new_sub_id.clone(),
+ errors: Some(vec![graphql::Error::builder()
+ .message("cannot complete the subscription")
+ .extension_code("SUBSCRIPTION_ERROR")
+ .build()]),
+ verifier: verifier.clone(),
+ },
+ ))
+ .unwrap(),
+ ))
+ .unwrap();
+ let resp = web_endpoint.clone().oneshot(http_req).await.unwrap();
+ assert_eq!(resp.status(), http::StatusCode::ACCEPTED);
+ let msg = handler.next().await.unwrap();
+
+ assert_eq!(
+ msg,
+ graphql::Response::builder()
+ .errors(vec![graphql::Error::builder()
+ .message("cannot complete the subscription")
+ .extension_code("SUBSCRIPTION_ERROR")
+ .build()])
+ .build()
+ );
+
+ // Should answer NOT FOUND because we completed the sub
+ let http_req = http::Request::post(format!(
+ "http://localhost:4000/subscription/callback/{new_sub_id}"
+ ))
+ .body(hyper::Body::from(
+ serde_json::to_vec(&CallbackPayload::Subscription(SubscriptionPayload::Next {
+ id: new_sub_id.clone(),
+ payload: graphql::Response::builder()
+ .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
+ .build(),
+ verifier,
+ }))
+ .unwrap(),
+ ))
+ .unwrap();
+ let resp = web_endpoint.oneshot(http_req).await.unwrap();
+ assert_eq!(resp.status(), http::StatusCode::NOT_FOUND);
+ }
+
+ #[tokio::test(flavor = "multi_thread")]
+ async fn it_test_subgraph_service_with_subscription_disabled() {
+ let dyn_plugin: Box = crate::plugin::plugins()
+ .find(|factory| factory.name == APOLLO_SUBSCRIPTION_PLUGIN)
+ .expect("Plugin not found")
+ .create_instance(
+ &Value::from_str(r#"{}"#).unwrap(),
+ Default::default(),
+ Default::default(),
+ )
+ .await
+ .unwrap();
+
+ let mut mock_subgraph_service = MockSubgraphService::new();
+ mock_subgraph_service
+ .expect_call()
+ .times(0)
+ .returning(move |req: SubgraphRequest| {
+ Ok(SubgraphResponse::fake_builder()
+ .context(req.context)
+ .build())
+ });
+
+ let mut subgraph_service =
+ dyn_plugin.subgraph_service("my_subgraph_name", BoxService::new(mock_subgraph_service));
+ let subgraph_req = SubgraphRequest::fake_builder()
+ .subgraph_request(
+ http_ext::Request::fake_builder()
+ .body(
+ Request::fake_builder()
+ .query(String::from(
+ "subscription {\n userWasCreated {\n username\n }\n}",
+ ))
+ .build(),
+ )
+ .build()
+ .unwrap(),
+ )
+ .operation_kind(OperationKind::Subscription)
+ .build();
+ let subgraph_response = subgraph_service
+ .ready()
+ .await
+ .unwrap()
+ .call(subgraph_req)
+ .await
+ .unwrap();
+
+ assert_eq!(subgraph_response.response.body(), &graphql::Response::builder().data(serde_json_bytes::Value::Null).error(graphql::Error::builder().message("cannot execute a subscription if it's not enabled in the configuration").extension_code("SUBSCRIPTION_DISABLED").build()).extensions(Object::default()).build());
+ }
+
+ #[test]
+ fn it_test_subscription_config() {
+ let config_with_callback: SubscriptionConfig = serde_json::from_value(serde_json::json!({
+ "mode": {
+ "preview_callback": {
+ "public_url": "http://localhost:4000",
+ "path": "/subscription/callback",
+ "subgraphs": ["test"]
+ }
+ }
+ }))
+ .unwrap();
+
+ let subgraph_cfg = config_with_callback.mode.get_subgraph_config("test");
+ assert_eq!(
+ subgraph_cfg,
+ Some(SubscriptionMode::Callback(
+ serde_json::from_value::(serde_json::json!({
+ "public_url": "http://localhost:4000",
+ "path": "/subscription/callback",
+ "subgraphs": []
+ }))
+ .unwrap()
+ ))
+ );
+
+ let config_with_callback_default: SubscriptionConfig =
+ serde_json::from_value(serde_json::json!({
+ "mode": {
+ "preview_callback": {
+ "public_url": "http://localhost:4000",
+ "path": "/subscription/callback",
+ }
+ }
+ }))
+ .unwrap();
+
+ let subgraph_cfg = config_with_callback_default
+ .mode
+ .get_subgraph_config("test");
+ assert_eq!(
+ subgraph_cfg,
+ Some(SubscriptionMode::Callback(
+ serde_json::from_value::(serde_json::json!({
+ "public_url": "http://localhost:4000",
+ "path": "/subscription/callback",
+ "subgraphs": []
+ }))
+ .unwrap()
+ ))
+ );
+
+ let config_with_passthrough: SubscriptionConfig =
+ serde_json::from_value(serde_json::json!({
+ "mode": {
+ "passthrough": {
+ "subgraphs": {
+ "test": {
+ "path": "/ws",
+ }
+ }
+ }
+ }
+ }))
+ .unwrap();
+
+ let subgraph_cfg = config_with_passthrough.mode.get_subgraph_config("test");
+ assert_eq!(
+ subgraph_cfg,
+ Some(SubscriptionMode::Passthrough(
+ serde_json::from_value::(serde_json::json!({
+ "path": "/ws",
+ }))
+ .unwrap()
+ ))
+ );
+
+ let config_with_passthrough_override: SubscriptionConfig =
+ serde_json::from_value(serde_json::json!({
+ "mode": {
+ "passthrough": {
+ "all": {
+ "path": "/wss",
+ "protocol": "graphql_transport_ws"
+ },
+ "subgraphs": {
+ "test": {
+ "path": "/ws",
+ }
+ }
+ }
+ }
+ }))
+ .unwrap();
+
+ let subgraph_cfg = config_with_passthrough_override
+ .mode
+ .get_subgraph_config("test");
+ assert_eq!(
+ subgraph_cfg,
+ Some(SubscriptionMode::Passthrough(
+ serde_json::from_value::(serde_json::json!({
+ "path": "/ws",
+ "protocol": "graphql_ws"
+ }))
+ .unwrap()
+ ))
+ );
+
+ let config_with_passthrough_all: SubscriptionConfig =
+ serde_json::from_value(serde_json::json!({
+ "mode": {
+ "passthrough": {
+ "all": {
+ "path": "/wss",
+ "protocol": "graphql_transport_ws"
+ },
+ "subgraphs": {
+ "foo": {
+ "path": "/ws",
+ }
+ }
+ }
+ }
+ }))
+ .unwrap();
+
+ let subgraph_cfg = config_with_passthrough_all.mode.get_subgraph_config("test");
+ assert_eq!(
+ subgraph_cfg,
+ Some(SubscriptionMode::Passthrough(
+ serde_json::from_value::(serde_json::json!({
+ "path": "/wss",
+ "protocol": "graphql_transport_ws"
+ }))
+ .unwrap()
+ ))
+ );
+
+ let config_with_both_mode: SubscriptionConfig = serde_json::from_value(serde_json::json!({
+ "mode": {
+ "preview_callback": {
+ "public_url": "http://localhost:4000",
+ "path": "/subscription/callback",
+ },
+ "passthrough": {
+ "subgraphs": {
+ "foo": {
+ "path": "/ws",
+ }
+ }
+ }
+ }
+ }))
+ .unwrap();
+
+ let subgraph_cfg = config_with_both_mode.mode.get_subgraph_config("test");
+ assert_eq!(
+ subgraph_cfg,
+ Some(SubscriptionMode::Callback(
+ serde_json::from_value::(serde_json::json!({
+ "public_url": "http://localhost:4000",
+ "path": "/subscription/callback",
+ }))
+ .unwrap()
+ ))
+ );
+
+ let config_with_passthrough_precedence: SubscriptionConfig =
+ serde_json::from_value(serde_json::json!({
+ "mode": {
+ "preview_callback": {
+ "public_url": "http://localhost:4000",
+ "path": "/subscription/callback",
+ },
+ "passthrough": {
+ "all": {
+ "path": "/wss",
+ "protocol": "graphql_transport_ws"
+ },
+ "subgraphs": {
+ "foo": {
+ "path": "/ws",
+ }
+ }
+ }
+ }
+ }))
+ .unwrap();
+
+ let subgraph_cfg = config_with_passthrough_precedence
+ .mode
+ .get_subgraph_config("test");
+ assert_eq!(
+ subgraph_cfg,
+ Some(SubscriptionMode::Passthrough(
+ serde_json::from_value::(serde_json::json!({
+ "path": "/wss",
+ "protocol": "graphql_transport_ws"
+ }))
+ .unwrap()
+ ))
+ );
+
+ let config_without_mode: SubscriptionConfig =
+ serde_json::from_value(serde_json::json!({})).unwrap();
+
+ let subgraph_cfg = config_without_mode.mode.get_subgraph_config("test");
+ assert_eq!(subgraph_cfg, None);
+ }
+}
+
+register_plugin!("apollo", "subscription", Subscription);
diff --git a/apollo-router/src/plugins/telemetry/apollo.rs b/apollo-router/src/plugins/telemetry/apollo.rs
index 329518d348..5ecef6cfdd 100644
--- a/apollo-router/src/plugins/telemetry/apollo.rs
+++ b/apollo-router/src/plugins/telemetry/apollo.rs
@@ -280,18 +280,25 @@ pub(crate) struct LicensedOperationCountByType {
#[derive(Debug, Serialize, PartialEq, Eq, Hash, Clone, Copy)]
#[serde(rename_all = "kebab-case")]
pub(crate) enum OperationSubType {
- // TODO
+ SubscriptionEvent,
+ SubscriptionRequest,
}
impl OperationSubType {
pub(crate) const fn as_str(&self) -> &'static str {
- ""
+ match self {
+ OperationSubType::SubscriptionEvent => "subscription-event",
+ OperationSubType::SubscriptionRequest => "subscription-request",
+ }
}
}
impl Display for OperationSubType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "")
+ match self {
+ OperationSubType::SubscriptionEvent => write!(f, "subscription-event"),
+ OperationSubType::SubscriptionRequest => write!(f, "subscription-request"),
+ }
}
}
diff --git a/apollo-router/src/plugins/telemetry/metrics/apollo.rs b/apollo-router/src/plugins/telemetry/metrics/apollo.rs
index 0cb9baa92d..f03e526922 100644
--- a/apollo-router/src/plugins/telemetry/metrics/apollo.rs
+++ b/apollo-router/src/plugins/telemetry/metrics/apollo.rs
@@ -73,7 +73,9 @@ mod test {
use crate::plugins::telemetry::apollo::ENDPOINT_DEFAULT;
use crate::plugins::telemetry::apollo_exporter::Sender;
use crate::plugins::telemetry::Telemetry;
+ use crate::plugins::telemetry::OPERATION_KIND;
use crate::plugins::telemetry::STUDIO_EXCLUDE;
+ use crate::query_planner::OperationKind;
use crate::services::SupergraphRequest;
use crate::Context;
use crate::TestHarness;
@@ -115,6 +117,23 @@ mod test {
Ok(())
}
+ #[tokio::test(flavor = "multi_thread")]
+ async fn apollo_metrics_for_subscription() -> Result<(), BoxError> {
+ let query = "subscription {userWasCreated{name}}";
+ let context = Context::new();
+ let _ = context
+ .insert(OPERATION_KIND, OperationKind::Subscription)
+ .unwrap();
+ let results = get_metrics_for_request(query, None, Some(context)).await?;
+ let mut settings = insta::Settings::clone_current();
+ settings.set_sort_maps(true);
+ settings.add_redaction("[].request_id", "[REDACTED]");
+ settings.bind(|| {
+ insta::assert_json_snapshot!(results);
+ });
+ Ok(())
+ }
+
#[tokio::test(flavor = "multi_thread")]
async fn apollo_metrics_multiple_operations() -> Result<(), BoxError> {
let query = "query {topProducts{name}} query {topProducts{name}}";
@@ -248,7 +267,7 @@ mod test {
async fn create_plugin_with_apollo_config(
apollo_config: apollo::Config,
) -> Result {
- Telemetry::new(PluginInit::new(
+ Telemetry::new(PluginInit::fake_new(
config::Conf {
logging: Default::default(),
metrics: None,
diff --git a/apollo-router/src/plugins/telemetry/metrics/apollo/studio.rs b/apollo-router/src/plugins/telemetry/metrics/apollo/studio.rs
index e5de294ef0..d7a85a5cd8 100644
--- a/apollo-router/src/plugins/telemetry/metrics/apollo/studio.rs
+++ b/apollo-router/src/plugins/telemetry/metrics/apollo/studio.rs
@@ -277,7 +277,6 @@ mod test {
let metric_1 = create_test_metric("client_1", "version_1", "report_key_1");
let metric_2 = create_test_metric("client_1", "version_1", "report_key_1");
let aggregated_metrics = Report::new(vec![metric_1, metric_2]);
-
insta::with_settings!({sort_maps => true}, {
insta::assert_json_snapshot!(aggregated_metrics);
});
diff --git a/apollo-router/src/plugins/telemetry/metrics/snapshots/apollo_router__plugins__telemetry__metrics__apollo__test__apollo_metrics_for_subscription.snap b/apollo-router/src/plugins/telemetry/metrics/snapshots/apollo_router__plugins__telemetry__metrics__apollo__test__apollo_metrics_for_subscription.snap
new file mode 100644
index 0000000000..c50e070b89
--- /dev/null
+++ b/apollo-router/src/plugins/telemetry/metrics/snapshots/apollo_router__plugins__telemetry__metrics__apollo__test__apollo_metrics_for_subscription.snap
@@ -0,0 +1,61 @@
+---
+source: apollo-router/src/plugins/telemetry/metrics/apollo.rs
+expression: results
+---
+[
+ {
+ "request_id": "[REDACTED]",
+ "stats": {
+ "# -\nsubscription{userWasCreated{name}}": {
+ "stats_with_context": {
+ "context": {
+ "client_name": "test_client",
+ "client_version": "1.0-test",
+ "operation_type": "subscription",
+ "operation_subtype": "subscription-request"
+ },
+ "query_latency_stats": {
+ "latency": {
+ "secs": 0,
+ "nanos": 100000000
+ },
+ "cache_hit": false,
+ "persisted_query_hit": null,
+ "cache_latency": null,
+ "root_error_stats": {
+ "children": {},
+ "errors_count": 0,
+ "requests_with_errors_count": 0
+ },
+ "has_errors": true,
+ "public_cache_ttl_latency": null,
+ "private_cache_ttl_latency": null,
+ "registered_operation": false,
+ "forbidden_operation": false,
+ "without_field_instrumentation": false
+ },
+ "per_type_stat": {}
+ },
+ "referenced_fields_by_type": {
+ "Subscription": {
+ "field_names": [
+ "userWasCreated"
+ ],
+ "is_interface": false
+ },
+ "User": {
+ "field_names": [
+ "name"
+ ],
+ "is_interface": false
+ }
+ }
+ }
+ },
+ "licensed_operation_count_by_type": {
+ "type": "subscription",
+ "subtype": "subscription-request",
+ "licensed_operation_count": 1
+ }
+ }
+]
diff --git a/apollo-router/src/plugins/telemetry/mod.rs b/apollo-router/src/plugins/telemetry/mod.rs
index 02204506df..daa4faf2c7 100644
--- a/apollo-router/src/plugins/telemetry/mod.rs
+++ b/apollo-router/src/plugins/telemetry/mod.rs
@@ -66,6 +66,7 @@ use self::reload::reload_fmt;
use self::reload::reload_metrics;
use self::reload::NullFieldFormatter;
use self::reload::OPENTELEMETRY_TRACER_HANDLE;
+use self::tracing::apollo_telemetry::APOLLO_PRIVATE_DURATION_NS;
use self::tracing::reload::ReloadTracer;
use crate::layers::ServiceBuilderExt;
use crate::plugin::Plugin;
@@ -126,7 +127,9 @@ pub(crate) const EXECUTION_SPAN_NAME: &str = "execution";
const CLIENT_NAME: &str = "apollo_telemetry::client_name";
const CLIENT_VERSION: &str = "apollo_telemetry::client_version";
const SUBGRAPH_FTV1: &str = "apollo_telemetry::subgraph_ftv1";
-const OPERATION_KIND: &str = "apollo_telemetry::operation_kind";
+pub(crate) const OPERATION_KIND: &str = "apollo_telemetry::operation_kind";
+pub(crate) const GRAPHQL_OPERATION_NAME_CONTEXT_KEY: &str =
+ "apollo_telemetry::graphql_operation_name";
pub(crate) const STUDIO_EXCLUDE: &str = "apollo_telemetry::studio::exclude";
pub(crate) const LOGGING_DISPLAY_HEADERS: &str = "apollo_telemetry::logging::display_headers";
pub(crate) const LOGGING_DISPLAY_BODY: &str = "apollo_telemetry::logging::display_body";
@@ -270,7 +273,7 @@ impl Plugin for Telemetry {
let response: Result = fut.await;
span.record(
- "apollo_private.duration_ns",
+ APOLLO_PRIVATE_DURATION_NS,
start.elapsed().as_nanos() as i64,
);
@@ -672,6 +675,11 @@ impl Telemetry {
.operation_name
.as_deref()
.unwrap_or_default();
+ if let Some(operation_name) = &http_request.body().operation_name {
+ let _ = request
+ .context
+ .insert(GRAPHQL_OPERATION_NAME_CONTEXT_KEY, operation_name.clone());
+ }
let span = info_span!(
SUPERGRAPH_SPAN_NAME,
@@ -1079,6 +1087,8 @@ impl Telemetry {
match result {
Err(e) => {
if !matches!(sender, Sender::Noop) {
+ let operation_subtype = (operation_kind == OperationKind::Subscription)
+ .then_some(OperationSubType::SubscriptionRequest);
Self::update_apollo_metrics(
ctx,
field_level_instrumentation_ratio,
@@ -1086,7 +1096,7 @@ impl Telemetry {
true,
start.elapsed(),
operation_kind,
- None,
+ operation_subtype,
);
}
let mut metric_attrs = Vec::new();
@@ -1116,29 +1126,60 @@ impl Telemetry {
}
Ok(router_response) => {
let mut has_errors = !router_response.response.status().is_success();
-
+ if operation_kind == OperationKind::Subscription {
+ Self::update_apollo_metrics(
+ ctx,
+ field_level_instrumentation_ratio,
+ sender.clone(),
+ has_errors,
+ start.elapsed(),
+ operation_kind,
+ Some(OperationSubType::SubscriptionRequest),
+ );
+ }
Ok(router_response.map(move |response_stream| {
let sender = sender.clone();
let ctx = ctx.clone();
response_stream
- .map(move |response| {
+ .enumerate()
+ .map(move |(idx, response)| {
if !response.errors.is_empty() {
has_errors = true;
}
- if !response.has_next.unwrap_or(false)
- && !matches!(sender, Sender::Noop)
- {
- Self::update_apollo_metrics(
- &ctx,
- field_level_instrumentation_ratio,
- sender.clone(),
- has_errors,
- start.elapsed(),
- operation_kind,
- None,
- );
+ if !matches!(sender, Sender::Noop) {
+ if operation_kind == OperationKind::Subscription {
+ // Don't send for the first empty response because it's a heartbeat
+ if idx != 0 {
+ // Only for subscription events
+ Self::update_apollo_metrics(
+ &ctx,
+ field_level_instrumentation_ratio,
+ sender.clone(),
+ has_errors,
+ response
+ .created_at
+ .map(|c| c.elapsed())
+ .unwrap_or_else(|| start.elapsed()),
+ operation_kind,
+ Some(OperationSubType::SubscriptionEvent),
+ );
+ }
+ } else {
+ // If it's the last response
+ if !response.has_next.unwrap_or(false) {
+ Self::update_apollo_metrics(
+ &ctx,
+ field_level_instrumentation_ratio,
+ sender.clone(),
+ has_errors,
+ start.elapsed(),
+ operation_kind,
+ None,
+ );
+ }
+ }
}
response
@@ -1665,6 +1706,7 @@ mod tests {
.create_instance(
&serde_json::json!({"apollo": {"schema_id":"abc"}, "tracing": {}}),
Default::default(),
+ Default::default(),
)
.await
.unwrap();
@@ -1813,6 +1855,7 @@ mod tests {
}
}),
Default::default(),
+ Default::default(),
)
.await
.unwrap();
@@ -1981,6 +2024,7 @@ mod tests {
)
.unwrap(),
Default::default(),
+ Default::default(),
)
.await
.unwrap();
@@ -2262,6 +2306,7 @@ mod tests {
)
.unwrap(),
Default::default(),
+ Default::default(),
)
.await
.unwrap();
diff --git a/apollo-router/src/plugins/telemetry/tracing/apollo_telemetry.rs b/apollo-router/src/plugins/telemetry/tracing/apollo_telemetry.rs
index c0e913001c..5a6ac3db25 100644
--- a/apollo-router/src/plugins/telemetry/tracing/apollo_telemetry.rs
+++ b/apollo-router/src/plugins/telemetry/tracing/apollo_telemetry.rs
@@ -31,8 +31,10 @@ use crate::axum_factory::utils::REQUEST_SPAN_NAME;
use crate::plugins::telemetry;
use crate::plugins::telemetry::apollo::ErrorConfiguration;
use crate::plugins::telemetry::apollo::ErrorsConfiguration;
+use crate::plugins::telemetry::apollo::OperationSubType;
use crate::plugins::telemetry::apollo::SingleReport;
use crate::plugins::telemetry::apollo_exporter::proto;
+use crate::plugins::telemetry::apollo_exporter::proto::reports::trace::http::Method;
use crate::plugins::telemetry::apollo_exporter::proto::reports::trace::http::Values;
use crate::plugins::telemetry::apollo_exporter::proto::reports::trace::query_plan_node::ConditionNode;
use crate::plugins::telemetry::apollo_exporter::proto::reports::trace::query_plan_node::DeferNode;
@@ -58,6 +60,8 @@ use crate::plugins::telemetry::EXECUTION_SPAN_NAME;
use crate::plugins::telemetry::ROUTER_SPAN_NAME;
use crate::plugins::telemetry::SUBGRAPH_SPAN_NAME;
use crate::plugins::telemetry::SUPERGRAPH_SPAN_NAME;
+use crate::query_planner::subscription::SUBSCRIPTION_EVENT_SPAN_NAME;
+use crate::query_planner::OperationKind;
use crate::query_planner::CONDITION_ELSE_SPAN_NAME;
use crate::query_planner::CONDITION_IF_SPAN_NAME;
use crate::query_planner::CONDITION_SPAN_NAME;
@@ -69,7 +73,8 @@ use crate::query_planner::FLATTEN_SPAN_NAME;
use crate::query_planner::PARALLEL_SPAN_NAME;
use crate::query_planner::SEQUENCE_SPAN_NAME;
-const APOLLO_PRIVATE_DURATION_NS: Key = Key::from_static_str("apollo_private.duration_ns");
+pub(crate) const APOLLO_PRIVATE_DURATION_NS: &str = "apollo_private.duration_ns";
+const APOLLO_PRIVATE_DURATION_NS_KEY: Key = Key::from_static_str(APOLLO_PRIVATE_DURATION_NS);
const APOLLO_PRIVATE_SENT_TIME_OFFSET: Key =
Key::from_static_str("apollo_private.sent_time_offset");
const APOLLO_PRIVATE_GRAPHQL_VARIABLES: Key =
@@ -144,8 +149,10 @@ pub(crate) struct Exporter {
errors_configuration: ErrorsConfiguration,
}
+#[derive(Debug)]
enum TreeData {
Request(Result, Error>),
+ SubscriptionEvent(Result, Error>),
Router {
http: Box,
client_name: Option,
@@ -211,7 +218,7 @@ impl Exporter {
duration_ns: 0,
root: None,
details: None,
- http: Some(http),
+ http: (http.method != Method::Unknown as i32).then_some(http),
..Default::default()
};
@@ -226,12 +233,14 @@ impl Exporter {
client_version,
duration_ns,
} => {
- let root_http = root_trace
- .http
- .as_mut()
- .expect("http was extracted earlier, qed");
- root_http.request_headers = http.request_headers;
- root_http.response_headers = http.response_headers;
+ if http.method != Method::Unknown as i32 {
+ let root_http = root_trace
+ .http
+ .as_mut()
+ .expect("http was extracted earlier, qed");
+ root_http.request_headers = http.request_headers;
+ root_http.response_headers = http.response_headers;
+ }
root_trace.client_name = client_name.unwrap_or_default();
root_trace.client_version = client_version.unwrap_or_default();
root_trace.duration_ns = duration_ns;
@@ -249,9 +258,24 @@ impl Exporter {
});
}
TreeData::Execution(operation_type) => {
+ if operation_type == OperationKind::Subscription.as_apollo_operation_type() {
+ root_trace.operation_subtype = if root_trace.http.is_some() {
+ OperationSubType::SubscriptionRequest.to_string()
+ } else {
+ OperationSubType::SubscriptionEvent.to_string()
+ };
+ }
root_trace.operation_type = operation_type;
}
- _ => panic!("should never have had other node types"),
+ TreeData::Trace(_) => {
+ continue;
+ }
+ other => {
+ tracing::error!(
+ "should never have had other node types, current type is: {other:?}"
+ );
+ return Err(Error::TraceParsingFailed);
+ }
}
}
@@ -262,13 +286,14 @@ impl Exporter {
self.extract_data_from_spans(&span)?
.pop()
.and_then(|node| {
- if let TreeData::Request(trace) = node {
- Some(trace)
- } else {
- None
+ match node {
+ TreeData::Request(trace) | TreeData::SubscriptionEvent(trace) => {
+ Some(trace)
+ }
+ _ => None
}
})
- .expect("root trace must exist because it is constructed on the request span, qed")
+ .expect("root trace must exist because it is constructed on the request or subscription_event span, qed")
}
fn extract_data_from_spans(&mut self, span: &LightSpanData) -> Result, Error> {
@@ -401,7 +426,7 @@ impl Exporter {
.and_then(extract_string),
duration_ns: span
.attributes
- .get(&APOLLO_PRIVATE_DURATION_NS)
+ .get(&APOLLO_PRIVATE_DURATION_NS_KEY)
.and_then(extract_i64)
.map(|e| e as u64)
.unwrap_or_default(),
@@ -481,6 +506,48 @@ impl Exporter {
));
child_nodes
}
+ SUBSCRIPTION_EVENT_SPAN_NAME => {
+ // To put the duration
+ child_nodes.push(TreeData::Router {
+ http: Box::new(extract_http_data(span)),
+ client_name: span.attributes.get(&CLIENT_NAME).and_then(extract_string),
+ client_version: span
+ .attributes
+ .get(&CLIENT_VERSION)
+ .and_then(extract_string),
+ duration_ns: span
+ .attributes
+ .get(&APOLLO_PRIVATE_DURATION_NS_KEY)
+ .and_then(extract_i64)
+ .map(|e| e as u64)
+ .unwrap_or_default(),
+ });
+
+ // To put the signature and operation name
+ child_nodes.push(TreeData::Supergraph {
+ operation_signature: span
+ .attributes
+ .get(&APOLLO_PRIVATE_OPERATION_SIGNATURE)
+ .and_then(extract_string)
+ .unwrap_or_default(),
+ operation_name: span
+ .attributes
+ .get(&OPERATION_NAME)
+ .and_then(extract_string)
+ .unwrap_or_default(),
+ variables_json: HashMap::new(),
+ });
+
+ child_nodes.push(TreeData::Execution(
+ OperationKind::Subscription
+ .as_apollo_operation_type()
+ .to_string(),
+ ));
+
+ vec![TreeData::SubscriptionEvent(
+ self.extract_root_trace(span, child_nodes),
+ )]
+ }
_ => child_nodes,
})
}
@@ -629,7 +696,7 @@ impl SpanExporter for Exporter {
// We may get spans that simply don't complete. These need to be cleaned up after a period. It's the price of using ftv1.
let mut traces: Vec<(String, proto::reports::Trace)> = Vec::new();
for span in batch {
- if span.name == REQUEST_SPAN_NAME {
+ if span.name == REQUEST_SPAN_NAME || span.name == SUBSCRIPTION_EVENT_SPAN_NAME {
match self.extract_trace(span.into()) {
Ok(mut trace) => {
let mut operation_signature = Default::default();
@@ -665,6 +732,7 @@ impl SpanExporter for Exporter {
.push(len, span.into());
}
}
+ tracing::info!(value.apollo_router_span_lru_size = self.spans_by_parent_id.len() as u64,);
let mut report = telemetry::apollo::Report::default();
report += SingleReport::Traces(TracesReport { traces });
let exporter = self.report_exporter.clone();
@@ -782,6 +850,7 @@ mod test {
for t in tree_data {
match t {
TreeData::Request(_) => elements.push("request"),
+ TreeData::SubscriptionEvent(_) => elements.push("subscription_event"),
TreeData::Supergraph { .. } => elements.push("supergraph"),
TreeData::QueryPlanNode(_) => elements.push("query_plan_node"),
TreeData::DeferPrimary(_) => elements.push("defer_primary"),
diff --git a/apollo-router/src/plugins/traffic_shaping/mod.rs b/apollo-router/src/plugins/traffic_shaping/mod.rs
index 09c954ec02..f183edf0f4 100644
--- a/apollo-router/src/plugins/traffic_shaping/mod.rs
+++ b/apollo-router/src/plugins/traffic_shaping/mod.rs
@@ -754,7 +754,7 @@ mod test {
)
.unwrap();
- let shaping_config = TrafficShaping::new(PluginInit::new(config, Default::default()))
+ let shaping_config = TrafficShaping::new(PluginInit::fake_builder().config(config).build())
.await
.unwrap();
diff --git a/apollo-router/src/protocols/mod.rs b/apollo-router/src/protocols/mod.rs
new file mode 100644
index 0000000000..c93950fdf2
--- /dev/null
+++ b/apollo-router/src/protocols/mod.rs
@@ -0,0 +1,2 @@
+pub(crate) mod multipart;
+pub(crate) mod websocket;
diff --git a/apollo-router/src/protocols/multipart.rs b/apollo-router/src/protocols/multipart.rs
new file mode 100644
index 0000000000..b617c5ff50
--- /dev/null
+++ b/apollo-router/src/protocols/multipart.rs
@@ -0,0 +1,208 @@
+use std::pin::Pin;
+use std::task::Poll;
+use std::time::Duration;
+
+use bytes::Bytes;
+use futures::stream::select;
+use futures::stream::StreamExt;
+use futures::Stream;
+use serde::Serialize;
+use serde_json_bytes::Value;
+use tokio_stream::wrappers::IntervalStream;
+
+use crate::graphql;
+
+#[cfg(test)]
+const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(10);
+#[cfg(not(test))]
+const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
+
+#[derive(thiserror::Error, Debug)]
+pub(crate) enum Error {
+ #[error("serialization error")]
+ SerdeError(#[from] serde_json::Error),
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub(crate) enum ProtocolMode {
+ Subscription,
+ Defer,
+}
+
+#[derive(Clone, Debug, Serialize)]
+struct SubscriptionPayload {
+ payload: Option,
+ #[serde(skip_serializing_if = "Vec::is_empty")]
+ errors: Vec,
+}
+
+pub(crate) struct Multipart {
+ stream: Pin> + Send>>,
+ is_first_chunk: bool,
+ is_terminated: bool,
+ mode: ProtocolMode,
+}
+
+impl Multipart {
+ pub(crate) fn new(stream: S, mode: ProtocolMode) -> Self
+ where
+ S: Stream- + Send + 'static,
+ {
+ let stream = match mode {
+ ProtocolMode::Subscription => select(
+ stream.map(Some),
+ IntervalStream::new(tokio::time::interval(HEARTBEAT_INTERVAL)).map(|_| None),
+ )
+ .boxed(),
+ ProtocolMode::Defer => stream.map(Some).boxed(),
+ };
+
+ Self {
+ stream,
+ is_first_chunk: true,
+ is_terminated: false,
+ mode,
+ }
+ }
+}
+
+impl Stream for Multipart {
+ type Item = Result
;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll> {
+ if self.is_terminated {
+ return Poll::Ready(None);
+ }
+ match self.stream.as_mut().poll_next(cx) {
+ Poll::Ready(message) => match message {
+ Some(None) => {
+ // It's the ticker for heartbeat for subscription
+ let buf = if self.is_first_chunk {
+ self.is_first_chunk = false;
+ Bytes::from_static(
+ &b"\r\n--graphql\r\ncontent-type: application/json\r\n\r\n{}\r\n--graphql\r\n"[..]
+ )
+ } else {
+ Bytes::from_static(
+ &b"content-type: application/json\r\n\r\n{}\r\n--graphql\r\n"[..],
+ )
+ };
+
+ Poll::Ready(Some(Ok(buf)))
+ }
+ Some(Some(mut response)) => {
+ let mut buf = if self.is_first_chunk {
+ self.is_first_chunk = false;
+ Vec::from(&b"\r\n--graphql\r\ncontent-type: application/json\r\n\r\n"[..])
+ } else {
+ Vec::from(&b"content-type: application/json\r\n\r\n"[..])
+ };
+ let is_still_open =
+ response.has_next.unwrap_or(false) || response.subscribed.unwrap_or(false);
+ match self.mode {
+ ProtocolMode::Subscription => {
+ let resp = SubscriptionPayload {
+ errors: if is_still_open {
+ Vec::new()
+ } else {
+ response.errors.drain(..).collect()
+ },
+ payload: match response.data {
+ None | Some(Value::Null) => None,
+ _ => response.into(),
+ },
+ };
+
+ serde_json::to_writer(&mut buf, &resp)?;
+ }
+ ProtocolMode::Defer => {
+ serde_json::to_writer(&mut buf, &response)?;
+ }
+ }
+
+ if is_still_open {
+ buf.extend_from_slice(b"\r\n--graphql\r\n");
+ } else {
+ self.is_terminated = true;
+ buf.extend_from_slice(b"\r\n--graphql--\r\n");
+ }
+
+ Poll::Ready(Some(Ok(buf.into())))
+ }
+ None => Poll::Ready(None),
+ },
+ Poll::Pending => Poll::Pending,
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use futures::stream;
+ use serde_json_bytes::ByteString;
+
+ use super::*;
+
+ // TODO add test with empty stream
+
+ #[tokio::test]
+ async fn test_heartbeat_and_boundaries() {
+ let responses = vec![
+ graphql::Response::builder()
+ .data(serde_json_bytes::Value::String(ByteString::from(
+ String::from("foo"),
+ )))
+ .subscribed(true)
+ .build(),
+ graphql::Response::builder()
+ .data(serde_json_bytes::Value::String(ByteString::from(
+ String::from("bar"),
+ )))
+ .subscribed(true)
+ .build(),
+ graphql::Response::builder()
+ .data(serde_json_bytes::Value::String(ByteString::from(
+ String::from("foobar"),
+ )))
+ .build(),
+ ];
+ let gql_responses = stream::iter(responses);
+
+ let mut protocol = Multipart::new(gql_responses, ProtocolMode::Subscription);
+ let heartbeat = String::from(
+ "\r\n--graphql\r\ncontent-type: application/json\r\n\r\n{}\r\n--graphql\r\n",
+ );
+ let mut curr_index = 0;
+ while let Some(resp) = protocol.next().await {
+ let res = String::from_utf8(resp.unwrap().to_vec()).unwrap();
+ if res == heartbeat {
+ continue;
+ } else {
+ match curr_index {
+ 0 => {
+ assert_eq!(res, "\r\n--graphql\r\ncontent-type: application/json\r\n\r\n{\"payload\":{\"data\":\"foo\"}}\r\n--graphql\r\n");
+ }
+ 1 => {
+ assert_eq!(
+ res,
+ "content-type: application/json\r\n\r\n{\"payload\":{\"data\":\"bar\"}}\r\n--graphql\r\n"
+ );
+ }
+ 2 => {
+ assert_eq!(
+ res,
+ "content-type: application/json\r\n\r\n{\"payload\":{\"data\":\"foobar\"}}\r\n--graphql--\r\n"
+ );
+ }
+ _ => {
+ panic!("should not happened, test failed");
+ }
+ }
+ curr_index += 1;
+ }
+ }
+ }
+}
diff --git a/apollo-router/src/protocols/websocket.rs b/apollo-router/src/protocols/websocket.rs
new file mode 100644
index 0000000000..d8fe3778db
--- /dev/null
+++ b/apollo-router/src/protocols/websocket.rs
@@ -0,0 +1,815 @@
+use std::pin::Pin;
+use std::task::Poll;
+use std::time::Duration;
+
+use futures::future;
+use futures::Future;
+use futures::Sink;
+use futures::SinkExt;
+use futures::Stream;
+use futures::StreamExt;
+use http::HeaderValue;
+use pin_project_lite::pin_project;
+use schemars::JsonSchema;
+use serde::Deserialize;
+use serde::Serialize;
+use serde_json_bytes::Value;
+use tokio::io::AsyncRead;
+use tokio::io::AsyncWrite;
+use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
+use tokio_tungstenite::tungstenite::protocol::CloseFrame;
+use tokio_tungstenite::tungstenite::Message;
+use tokio_tungstenite::WebSocketStream;
+
+use crate::graphql;
+
+const CONNECTION_ACK_TIMEOUT: Duration = Duration::from_secs(5);
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema, Copy)]
+#[serde(rename_all = "snake_case")]
+pub(crate) enum WebSocketProtocol {
+ // New one
+ GraphqlWs,
+ #[serde(rename = "graphql_transport_ws")]
+ // Old one
+ SubscriptionsTransportWs,
+}
+
+impl Default for WebSocketProtocol {
+ fn default() -> Self {
+ Self::GraphqlWs
+ }
+}
+
+impl From<WebSocketProtocol> for HeaderValue {
+ fn from(value: WebSocketProtocol) -> Self {
+ match value {
+ WebSocketProtocol::GraphqlWs => HeaderValue::from_static("graphql-transport-ws"),
+ WebSocketProtocol::SubscriptionsTransportWs => HeaderValue::from_static("graphql-ws"),
+ }
+ }
+}
+
+impl WebSocketProtocol {
+ fn subscribe(&self, id: String, payload: graphql::Request) -> ClientMessage {
+ match self {
+ // old
+ WebSocketProtocol::SubscriptionsTransportWs => ClientMessage::OldStart { id, payload },
+ // new
+ WebSocketProtocol::GraphqlWs => ClientMessage::Subscribe { id, payload },
+ }
+ }
+
+ fn complete(&self, id: String) -> ClientMessage {
+ match self {
+ // old
+ WebSocketProtocol::SubscriptionsTransportWs => ClientMessage::OldStop { id },
+ // new
+ WebSocketProtocol::GraphqlWs => ClientMessage::Complete { id },
+ }
+ }
+}
+
+/// A websocket message received from the client
+#[derive(Deserialize, Serialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+#[allow(clippy::large_enum_variant)] // Request is at fault
+pub(crate) enum ClientMessage {
+ /// A new connection
+ ConnectionInit {
+ /// Optional init payload from the client
+ payload: Option<Value>,
+ },
+ /// The start of a Websocket subscription
+ Subscribe {
+ /// Message ID
+ id: String,
+ /// The GraphQL Request - this can be modified by protocol implementors
+ /// to add files uploads.
+ payload: graphql::Request,
+ },
+ #[serde(rename = "start")]
+ /// For old protocol
+ OldStart {
+ /// Message ID
+ id: String,
+ /// The GraphQL Request - this can be modified by protocol implementors
+ /// to add files uploads.
+ payload: graphql::Request,
+ },
+ /// The end of a Websocket subscription
+ Complete {
+ /// Message ID
+ id: String,
+ },
+ /// For old protocol
+ #[serde(rename = "stop")]
+ OldStop {
+ /// Message ID
+ id: String,
+ },
+ /// Connection terminated by the client
+ ConnectionTerminate,
+ /// Useful for detecting failed connections, displaying latency metrics or
+ /// other types of network probing.
+ ///
+ /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping>
+ Ping {
+ /// Additional details about the ping.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ payload: Option<Value>,
+ },
+ /// The response to the Ping message.
+ ///
+ /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong>
+ Pong {
+ /// Additional details about the pong.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ payload: Option<Value>,
+ },
+}
+
+#[derive(Deserialize, Serialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub(crate) enum ServerMessage {
+ ConnectionAck,
+ /// subscriptions-transport-ws protocol alias for next payload
+ #[serde(alias = "data")]
+ /// graphql-ws protocol next payload
+ Next {
+ id: String,
+ payload: graphql::Response,
+ },
+ #[serde(alias = "connection_error")]
+ Error {
+ id: String,
+ payload: ServerError,
+ },
+ Complete {
+ id: String,
+ },
+ #[serde(alias = "ka")]
+ KeepAlive,
+ /// The response to the Ping message.
+ ///
+ /// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong
+ Pong {
+ payload: Option<Value>,
+ },
+ Ping {
+ payload: Option<Value>,
+ },
+}
+
+#[derive(Deserialize, Serialize, Debug, Clone)]
+#[serde(untagged)]
+pub(crate) enum ServerError {
+ Error(graphql::Error),
+ Errors(Vec<graphql::Error>),
+}
+
+impl From<ServerError> for Vec<graphql::Error> {
+ fn from(value: ServerError) -> Self {
+ match value {
+ ServerError::Error(e) => vec![e],
+ ServerError::Errors(e) => e,
+ }
+ }
+}
+
+impl ServerMessage {
+ fn into_graphql_response(self) -> (Option<graphql::Response>, bool) {
+ match self {
+ ServerMessage::Next { id: _, mut payload } => {
+ payload.subscribed = Some(true);
+ (Some(payload), false)
+ }
+ ServerMessage::Error { id: _, payload } => (
+ Some(
+ graphql::Response::builder()
+ .errors(payload.into())
+ .subscribed(false)
+ .build(),
+ ),
+ true,
+ ),
+ ServerMessage::Complete { .. } => (None, true),
+ ServerMessage::ConnectionAck | ServerMessage::Pong { .. } => (None, false),
+ ServerMessage::Ping { .. } => (None, false),
+ ServerMessage::KeepAlive => (None, false),
+ }
+ }
+
+ fn id(&self) -> Option<String> {
+ match self {
+ ServerMessage::ConnectionAck
+ | ServerMessage::KeepAlive
+ | ServerMessage::Ping { .. }
+ | ServerMessage::Pong { .. } => None,
+ ServerMessage::Next { id, .. }
+ | ServerMessage::Error { id, .. }
+ | ServerMessage::Complete { id } => Some(id.to_string()),
+ }
+ }
+}
+
+pin_project! {
+pub(crate) struct GraphqlWebSocket<S> {
+ #[pin]
+ stream: S,
+ id: String,
+ protocol: WebSocketProtocol,
+ // Booleans for state machine when closing the stream
+ completed: bool,
+ terminated: bool,
+}
+}
+
+impl<S> GraphqlWebSocket<S>
+where
+ S: Stream<Item = Result<ServerMessage, Error>> + Sink<ClientMessage> +
+ std::marker::Unpin,
+{
+ pub(crate) async fn new(
+ mut stream: S,
+ id: String,
+ protocol: WebSocketProtocol,
+ connection_params: Option<Value>,
+ ) -> Result<Self, graphql::Error> {
+ let connection_init_msg = match connection_params {
+ Some(connection_params) => ClientMessage::ConnectionInit {
+ payload: Some(serde_json_bytes::json!({
+ "connectionParams": connection_params
+ })),
+ },
+ None => ClientMessage::ConnectionInit { payload: None },
+ };
+ stream.send(connection_init_msg).await.map_err(|_err| {
+ graphql::Error::builder()
+ .message("cannot send connection init through websocket connection")
+ .extension_code("WEBSOCKET_INIT_ERROR")
+ .build()
+ })?;
+
+ let resp = tokio::time::timeout(CONNECTION_ACK_TIMEOUT, stream.next())
+ .await
+ .map_err(|_| {
+ graphql::Error::builder()
+ .message("cannot receive connection ack from websocket connection")
+ .extension_code("WEBSOCKET_ACK_ERROR_TIMEOUT")
+ .build()
+ })?;
+ if !matches!(resp, Some(Ok(ServerMessage::ConnectionAck))) {
+ return Err(graphql::Error::builder()
+ .message("didn't receive the connection ack from websocket connection")
+ .extension_code("WEBSOCKET_ACK_ERROR")
+ .build());
+ }
+
+ Ok(Self {
+ stream,
+ id,
+ protocol,
+ completed: false,
+ terminated: false,
+ })
+ }
+}
+
+#[derive(thiserror::Error, Debug)]
+pub(crate) enum Error {
+ #[error("websocket error")]
+ WebSocketError(#[from] tokio_tungstenite::tungstenite::Error),
+ #[error("deserialization/serialization error")]
+ SerdeError(#[from] serde_json::Error),
+}
+
+pub(crate) fn convert_websocket_stream<S>(
+ stream: WebSocketStream<S>,
+ id: String,
+) -> impl Stream