From 199415d65e543833c56c5ba409d1aabc515b2bf9 Mon Sep 17 00:00:00 2001 From: Marcel <34819524+MarcelCoding@users.noreply.github.com> Date: Fri, 17 Nov 2023 13:57:39 +0100 Subject: [PATCH] Rewrite (#105) * Rewrite * things * Added option to disable masking removed additional buffers * Improved websocket frame parsing, removed dynamic allocations of payload buffers * fixed typo * added messages, added splitting, basic close handling * improved errors * added missing payload len check * added connection cleanup to server * added comment * improved comment * sending close frane, error (not io) occurs while sending * Update README.md (#106) * Update README.md * Update README.md * improved errors * updated readme * moved websocket impl into its own lib `wsocket` * updated dependencies * applied klemens feedback from #107 * applied klemens feedback from #107 --- .editorconfig | 3 + Cargo.lock | 530 ++++++++++++++----------------- Cargo.toml | 1 + README.md | 15 +- zia-client/Cargo.toml | 17 +- zia-client/src/app.rs | 133 ++++++++ zia-client/src/cfg.rs | 9 + zia-client/src/handler.rs | 81 ----- zia-client/src/main.rs | 70 +++- zia-client/src/upstream/mod.rs | 44 --- zia-client/src/upstream/tcp.rs | 151 --------- zia-client/src/upstream/ws.rs | 100 ------ zia-common/Cargo.toml | 15 +- zia-common/src/forwarding/mod.rs | 5 - zia-common/src/forwarding/tcp.rs | 153 --------- zia-common/src/forwarding/ws.rs | 137 -------- zia-common/src/lib.rs | 23 +- zia-common/src/pool.rs | 82 +++++ zia-common/src/read.rs | 98 ++++++ zia-common/src/stream.rs | 173 ---------- zia-common/src/write.rs | 114 +++++++ zia-server/Cargo.toml | 15 +- zia-server/src/cfg.rs | 15 +- zia-server/src/listener/mod.rs | 10 - zia-server/src/listener/tcp.rs | 56 ---- zia-server/src/listener/ws.rs | 57 ---- zia-server/src/main.rs | 136 +++++++- 27 files changed, 917 insertions(+), 1326 deletions(-) create mode 100644 .editorconfig create mode 100644 zia-client/src/app.rs delete mode 100644 zia-client/src/handler.rs delete mode 100644 zia-client/src/upstream/mod.rs delete mode 100644 zia-client/src/upstream/tcp.rs delete mode 100644 zia-client/src/upstream/ws.rs delete mode 100644 zia-common/src/forwarding/mod.rs delete mode 100644 zia-common/src/forwarding/tcp.rs delete mode 100644 zia-common/src/forwarding/ws.rs create mode 100644 zia-common/src/pool.rs create mode 100644 zia-common/src/read.rs delete mode 100644 zia-common/src/stream.rs create mode 100644 zia-common/src/write.rs delete mode 100644 zia-server/src/listener/mod.rs delete mode 100644 zia-server/src/listener/tcp.rs delete mode 100644 zia-server/src/listener/ws.rs diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..0da8f80 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,3 @@ +[*] +indent_size = 2 +indent_style = space diff --git a/Cargo.lock b/Cargo.lock index e8280b1..f70ea7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4fa78e18c64fce05e902adecd7a5eed15a5e0a3439f7b0e169f0252214865e3" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" dependencies = [ "gimli", ] @@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "anstream" -version = "0.5.0" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f58811cfac344940f1a400b6e6231ce35171f614f26439e80f8c1465c5cc0c" +checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" dependencies = [ "anstyle", "anstyle-parse", @@ -33,15 +33,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.1" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" [[package]] name = "anstyle-parse" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" +checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" dependencies = [ "utf8parse", ] @@ -57,9 +57,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "2.1.0" +version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58f54d10c6dfa51283a066ceab3ec1ab78d13fae00aa49243a45e4571fb79dfd" +checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" dependencies = [ "anstyle", "windows-sys", @@ -67,9 +67,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.74" +version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c6f84b74db2535ebae81eede2f39b947dcbf01d093ae5f791e5dd414a1bf289" +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" [[package]] name = "async-http-proxy" @@ -77,34 +77,17 @@ version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29faa5d4d308266048bd7505ba55484315a890102f9345b9ff4b87de64201592" dependencies = [ - "base64", + "base64 0.13.1", "httparse", "thiserror", "tokio", ] -[[package]] -name = "async-trait" -version = "0.1.73" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - [[package]] name = "backtrace" -version = "0.3.68" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4319208da049c43661739c5fade2ba182f09d1dc2299b32298d3a31692b17e12" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" dependencies = [ "addr2line", "cc", @@ -121,6 +104,12 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +[[package]] +name = "base64" +version = "0.21.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" + [[package]] name = "block-buffer" version = "0.10.4" @@ -130,29 +119,17 @@ dependencies = [ "generic-array", ] -[[package]] -name = "bumpalo" -version = "3.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" - -[[package]] -name = "byteorder" -version = "1.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" - [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" -version = "1.0.82" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "305fe645edc1442a0fa8b6726ba61d422798d37a52e12eaecf4b022ebbb88f01" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ "libc", ] @@ -165,9 +142,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.4.2" +version = "4.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a13b88d2c62ff462f88e4a121f17a82c1af05693a2f192b5c38d14de73c19f6" +checksum = "2275f18819641850fa26c89acc84d465c1bf91ce57bc2748b28c420473352f64" dependencies = [ "clap_builder", "clap_derive", @@ -175,9 +152,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.2" +version = "4.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bb9faaa7c2ef94b2743a21f5a29e6f0010dff4caa69ac8e9d6cf8b6fa74da08" +checksum = "07cdf1b148b25c1e1f7a42225e30a0d99a615cd4637eae7365548dd4529b95bc" dependencies = [ "anstream", "anstyle", @@ -187,9 +164,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.4.2" +version = "4.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0862016ff20d69b84ef8247369fabf5c008a7417002411897d40ee1f4532b873" +checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" dependencies = [ "heck", "proc-macro2", @@ -199,9 +176,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" [[package]] name = "colorchoice" @@ -211,9 +188,9 @@ checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "cpufeatures" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1" +checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0" dependencies = [ "libc", ] @@ -228,12 +205,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "data-encoding" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" - [[package]] name = "digest" version = "0.10.7" @@ -244,6 +215,21 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "fastwebsockets" +version = "0.4.4" +source = "git+https://github.com/MarcelCoding/fastwebsockets?branch=split#fa0521b583d88e88a74ac1e0b50957a4d3244c45" +dependencies = [ + "base64 0.21.5", + "hyper", + "pin-project", + "rand", + "sha1", + "thiserror", + "tokio", + "utf-8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -260,35 +246,36 @@ dependencies = [ ] [[package]] -name = "futures-core" -version = "0.3.28" +name = "futures-channel" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" +dependencies = [ + "futures-core", +] [[package]] -name = "futures-sink" -version = "0.3.28" +name = "futures-core" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" dependencies = [ "futures-core", - "futures-sink", "futures-task", "pin-project-lite", "pin-utils", - "slab", ] [[package]] @@ -303,9 +290,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "libc", @@ -314,9 +301,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.27.3" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" [[package]] name = "heck" @@ -326,27 +313,67 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "http" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", "itoa", ] +[[package]] +name = "http-body" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +dependencies = [ + "bytes", + "http", + "pin-project-lite", +] + [[package]] name = "httparse" version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "0.14.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2 0.4.10", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "idna" version = "0.4.0" @@ -363,15 +390,6 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" -[[package]] -name = "js-sys" -version = "0.3.64" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" -dependencies = [ - "wasm-bindgen", -] - [[package]] name = "lazy_static" version = "1.4.0" @@ -380,9 +398,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.147" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "log" @@ -392,9 +410,9 @@ checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "miniz_oxide" @@ -407,9 +425,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" dependencies = [ "libc", "wasi", @@ -438,9 +456,9 @@ dependencies = [ [[package]] name = "object" -version = "0.31.1" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bda667d9f2b5051b8833f59f3bf748b28ef54f850f4fcb389a252aa383866d1" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" dependencies = [ "memchr", ] @@ -485,9 +503,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12cc1b0bf1727a77a54b6654e7b5f1af8604923edc8b81885f8ec92f9e3f0a05" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -503,18 +521,18 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.66" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.32" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50f3b39ccfb720540debaa0164757101c08ecb8d326b15358ce76a62c7e85965" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] @@ -551,17 +569,16 @@ dependencies = [ [[package]] name = "ring" -version = "0.16.20" +version = "0.17.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b" dependencies = [ "cc", + "getrandom", "libc", - "once_cell", "spin", "untrusted", - "web-sys", - "winapi", + "windows-sys", ] [[package]] @@ -572,9 +589,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustls" -version = "0.21.6" +version = "0.21.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb" +checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9" dependencies = [ "log", "ring", @@ -584,9 +601,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.101.4" +version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ "ring", "untrusted", @@ -594,9 +611,9 @@ dependencies = [ [[package]] name = "sct" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ "ring", "untrusted", @@ -604,18 +621,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.183" +version = "1.0.192" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32ac8da02677876d532745a130fc9d8e6edfa81a269b107c5b00829b91d8eb3c" +checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.183" +version = "1.0.192" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aafe972d60b0b9bee71a91b92fee2d4fb3c9d7e8f6b179aa99f27203d99a4816" +checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" dependencies = [ "proc-macro2", "quote", @@ -624,9 +641,9 @@ dependencies = [ [[package]] name = "sha1" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", @@ -635,9 +652,9 @@ dependencies = [ [[package]] name = "sharded-slab" -version = "0.1.4" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" dependencies = [ "lazy_static", ] @@ -652,25 +669,26 @@ dependencies = [ ] [[package]] -name = "slab" -version = "0.4.8" +name = "smallvec" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" -dependencies = [ - "autocfg", -] +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] -name = "smallvec" -version = "1.11.0" +name = "socket2" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" +dependencies = [ + "libc", + "winapi", +] [[package]] name = "socket2" -version = "0.5.3" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", "windows-sys", @@ -678,9 +696,9 @@ dependencies = [ [[package]] name = "spin" -version = "0.5.2" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "strsim" @@ -690,9 +708,9 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" -version = "2.0.28" +version = "2.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" +checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" dependencies = [ "proc-macro2", "quote", @@ -701,18 +719,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.46" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9207952ae1a003f42d3d5e892dac3c6ba42aa6ac0c79a6a91a2b5cb4253e75c" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.46" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1728216d3244de4f14f14f8c15c79be1a7c67867d28d69b719690e2a19fb445" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", @@ -746,9 +764,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.32.0" +version = "1.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17ed6077ed6cd6c74735e21f37eb16dc3935f96878b1fe961074089cc80893f9" +checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" dependencies = [ "backtrace", "bytes", @@ -757,16 +775,16 @@ dependencies = [ "num_cpus", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.5.5", "tokio-macros", "windows-sys", ] [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", @@ -784,24 +802,17 @@ dependencies = [ ] [[package]] -name = "tokio-tungstenite" -version = "0.20.0" +name = "tower-service" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2dbec703c26b00d74844519606ef15d09a7d6857860f84ad223dec002ddea2" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite", -] +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "cfg-if", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -809,9 +820,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", @@ -820,9 +831,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", "valuable", @@ -830,20 +841,20 @@ dependencies = [ [[package]] name = "tracing-log" -version = "0.1.3" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" dependencies = [ - "lazy_static", "log", + "once_cell", "tracing-core", ] [[package]] name = "tracing-subscriber" -version = "0.3.17" +version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ "nu-ansi-term", "sharded-slab", @@ -854,29 +865,16 @@ dependencies = [ ] [[package]] -name = "tungstenite" -version = "0.20.0" +name = "try-lock" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e862a1c4128df0112ab625f55cd5c934bcb4312ba80b39ae4b4835a3fd58e649" -dependencies = [ - "byteorder", - "bytes", - "data-encoding", - "http", - "httparse", - "log", - "rand", - "sha1", - "thiserror", - "url", - "utf-8", -] +checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" @@ -886,9 +884,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -901,15 +899,15 @@ dependencies = [ [[package]] name = "untrusted" -version = "0.7.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" dependencies = [ "form_urlencoded", "idna", @@ -942,74 +940,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "wasm-bindgen" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" -dependencies = [ - "cfg-if", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" -dependencies = [ - "bumpalo", - "log", - "once_cell", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.87" +name = "want" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" dependencies = [ - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-backend", - "wasm-bindgen-shared", + "try-lock", ] [[package]] -name = "wasm-bindgen-shared" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" - -[[package]] -name = "web-sys" -version = "0.3.64" +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" -dependencies = [ - "js-sys", - "wasm-bindgen", -] +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "webpki-roots" @@ -1050,9 +993,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1eeca1c172a285ee6c2c84c341ccea837e7c01b12fbb2d0fe3c9e550ce49ec8" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -1065,59 +1008,75 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b10d0c968ba7f6166195e13d593af609ec2e3d24f916f081690695cf5eaffb2f" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_msvc" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "571d8d4e62f26d4932099a9efe89660e8bd5087775a2ab5cdd8b747b811f1058" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_i686_gnu" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2229ad223e178db5fbbc8bd8d3835e51e566b8474bfca58d2e6150c48bb723cd" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_msvc" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "600956e2d840c194eedfc5d18f8242bc2e17c7775b6684488af3a9fff6fe3287" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_x86_64_gnu" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea99ff3f8b49fb7a8e0d305e5aec485bd068c2ba691b6e277d29eaeac945868a" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1a05a1ece9a7a0d5a7ccf30ba2c33e3a61a30e042ffd247567d1de1d94120d" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_msvc" -version = "0.48.2" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "wsocket" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d419259aba16b663966e29e6d7c6ecfa0bb8425818bb96f6f1f3c3eb71a6e7b9" +checksum = "c8405de0c5f7bdd3e25b537e91e3bba56b4a8efd9d5599a369ac4104e739023b" +dependencies = [ + "rand", + "thiserror", + "tokio", + "tracing", +] [[package]] name = "zia-client" version = "0.0.0-git" dependencies = [ "anyhow", - "async-trait", + "async-http-proxy", "clap", - "futures-util", + "fastwebsockets", + "hyper", + "once_cell", "tokio", - "tokio-tungstenite", + "tokio-rustls", "tracing", "tracing-subscriber", "url", + "webpki-roots", + "wsocket", "zia-common", ] @@ -1126,16 +1085,10 @@ name = "zia-common" version = "0.0.0-git" dependencies = [ "anyhow", - "async-http-proxy", - "futures-util", - "once_cell", - "pin-project", + "hyper", "tokio", - "tokio-rustls", - "tokio-tungstenite", "tracing", - "url", - "webpki-roots", + "wsocket", ] [[package]] @@ -1143,12 +1096,15 @@ name = "zia-server" version = "0.0.0-git" dependencies = [ "anyhow", - "async-trait", "clap", - "futures-util", + "fastwebsockets", + "hyper", + "once_cell", + "pin-project", "tokio", - "tokio-tungstenite", "tracing", "tracing-subscriber", + "webpki-roots", + "wsocket", "zia-common", ] diff --git a/Cargo.toml b/Cargo.toml index af9ae16..a16ff00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] +resolver = "2" members = [ "zia-client", "zia-common", diff --git a/README.md b/README.md index f25d85b..1d2b9d6 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,7 @@ graph LR C ---|UDP| D[Wireguard Server] ``` -The benefit is that Websocket uses Http. If you are in a restricted network where you can only access external services, -and you can only use a provided Http proxy you can proxy your Wireguard Udp traffic over Websocket. +The benefit is that Websocket uses Http. If you are in a restricted network where you can only access external services, and you can only use a provided Http proxy you can proxy your Wireguard Udp traffic over Websocket. ```mermaid graph LR @@ -24,14 +23,12 @@ graph LR ## Mode -| Name | Description | -|-----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Websocket | The UDP datagrams are wrapped inside websocket frames. These frames are then transmitted to the server and there unwrapped. | -| TCP | The UDP datagrams are prefixed with a 8 bit length of the datagram and then transmitted to the server in TCP packages. At the server these packages are unwrapped and forwarded to the actual UDP upstream. | +| Name | Description | +|-----------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Websocket | The UDP datagrams are wrapped inside websocket frames. These frames are then transmitted to the server and there unwrapped. | +| TCP | The UDP datagrams are prefixed with a 16 bit length of the datagram and then transmitted to the server in TCP packages. At the server these packages are unwrapped and forwarded to the actual UDP upstream. | -The client is capable of doing an TLSv2 or TLSv3 handshake, the server isn't. The client is also able to do a TLS -handshake between the HTTPS proxy and the Server. In a case where an end to end (zia-client <-> zia-server) TLS -encryption should happen, you have to proxy the traffic for the server using a reverse proxy. +The client is capable of doing a TLSv2 or TLSv3 handshake, the server isn't able to handle TLS requests. In a case where an end to end (zia-client <-> zia-server) TLS encryption should happen, you have to proxy the traffic for the server using a reverse proxy. ## Client diff --git a/zia-client/Cargo.toml b/zia-client/Cargo.toml index 86a63ca..fdec144 100644 --- a/zia-client/Cargo.toml +++ b/zia-client/Cargo.toml @@ -7,20 +7,23 @@ license = "AGPL-3.0" description = "Proxy udp over websocket, useful to use Wireguard in restricted networks." [dependencies] -tokio = { version = "1.32", default-features = false, features = ["macros", "net", "rt-multi-thread", "signal", "sync"] } +fastwebsockets = { git = "https://github.com/MarcelCoding/fastwebsockets", branch = "split", default-features = false, features = ["upgrade"] } +tokio = { version = "1.34", default-features = false, features = ["rt-multi-thread", "macros", "net", "sync", "time", "signal"] } +async-http-proxy = { version = "1.2", default-features = false, features = ["runtime-tokio", "basic-auth"] } +hyper = { version = "0.14", default-features = false, features = [] } tracing-subscriber = { version = "0.3", features = ["tracing-log"] } -tokio-tungstenite = { version = "0.20", default-features = false, features = ["handshake"] } -futures-util = { version = "0.3", default-features = false } clap = { version = "4.4", features = ["derive", "env"] } +wsocket = { version = "0.1", features = ["client"] } url = { version = "2.4", features = ["serde"] } -async-trait = "0.1" +webpki-roots = "0.25" +tokio-rustls = "0.24" +once_cell = "1.18" tracing = "0.1" anyhow = "1.0" -zia-common = { path = "../zia-common" } +zia-common = { path = '../zia-common' } [package.metadata.generate-rpm] assets = [ - # { source = "target/release/status-node", dest = "/usr/bin/status-node", mode = "0755" }, - { source = "../LICENSE", dest = "/usr/share/doc/zia-client/LICENSE", doc = true, mode = "0644" }, + { source = "../LICENSE", dest = "/usr/share/doc/zia-client/LICENSE", doc = true, mode = "0644" }, ] diff --git a/zia-client/src/app.rs b/zia-client/src/app.rs new file mode 100644 index 0000000..08cacd9 --- /dev/null +++ b/zia-client/src/app.rs @@ -0,0 +1,133 @@ +use std::future::Future; +use std::sync::Arc; + +use anyhow::anyhow; +use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth}; +use hyper::header::{ + CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE, USER_AGENT, +}; +use hyper::upgrade::Upgraded; +use hyper::{Body, Request}; +use once_cell::sync::Lazy; +use tokio::io::BufStream; +use tokio::net::TcpStream; +use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; +use tokio_rustls::TlsConnector; +use tracing::info; +use url::Url; +use wsocket::WebSocket; + +use zia_common::{ReadConnection, WriteConnection, MAX_DATAGRAM_SIZE}; + +static TLS_CONNECTOR: Lazy<TlsConnector> = Lazy::new(|| { + let mut store = RootCertStore::empty(); + store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints) + })); + + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(store) + .with_no_client_auth(); + + TlsConnector::from(Arc::new(config)) +}); + +pub(crate) async fn open_connection( + upstream: &Url, + proxy: &Option<Url>, + websocket_masking: bool, +) -> anyhow::Result<(ReadConnection<Upgraded>, WriteConnection<Upgraded>)> { + let upstream_host = upstream + .host_str() + .ok_or_else(|| anyhow!("Upstream url is missing host"))?; + let upstream_port = upstream + .port_or_known_default() + .ok_or_else(|| anyhow!("Upstream url is missing port"))?; + + let stream = match proxy { + None => { + let stream = TcpStream::connect((upstream_host, upstream_port)).await?; + stream.set_nodelay(true)?; + + info!("Connected to tcp"); + + stream + } + Some(proxy) => { + assert_eq!(proxy.scheme(), "http"); + + let proxy_host = proxy + .host_str() + .ok_or_else(|| anyhow!("Proxy url is missing host"))?; + let proxy_port = proxy + .port_or_known_default() + .ok_or_else(|| anyhow!("Proxy url is missing port"))?; + + let mut stream = TcpStream::connect((proxy_host, proxy_port)).await?; + stream.set_nodelay(true)?; + + info!("Connected to tcp"); + + match proxy.password() { + Some(password) => { + http_connect_tokio_with_basic_auth( + &mut stream, + upstream_host, + upstream_port, + proxy.username(), + password, + ) + .await? + } + None => http_connect_tokio(&mut stream, upstream_host, upstream_port).await?, + }; + + info!("Proxy handshake"); + + stream + } + }; + + let stream = BufStream::new(stream); + + let req = Request::get(upstream.to_string()) + .header(HOST, format!("{}:{}", upstream_host, upstream_port)) + .header(UPGRADE, "websocket") + .header(CONNECTION, "upgrade") + .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) + .header(SEC_WEBSOCKET_VERSION, "13") + .header(USER_AGENT, "zia") + .body(Body::empty())?; + + let (ws, _) = if upstream.scheme() == "wss" { + let domain = ServerName::try_from(upstream_host)?; + let stream = TLS_CONNECTOR.connect(domain, stream).await?; + info!("Upgraded to tls"); + + fastwebsockets::handshake::client(&SpawnExecutor, req, stream).await? + } else { + fastwebsockets::handshake::client(&SpawnExecutor, req, stream).await? + }; + + info!("Finished websocket handshake"); + + let ws = WebSocket::client(ws.into_inner(), MAX_DATAGRAM_SIZE, websocket_masking); + + let (read, write) = ws.split(); + + Ok((ReadConnection::new(read), WriteConnection::new(write))) +} + +// Tie hyper's executor to tokio runtime +struct SpawnExecutor; + +impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + fn execute(&self, fut: Fut) { + tokio::task::spawn(fut); + } +} diff --git a/zia-client/src/cfg.rs b/zia-client/src/cfg.rs index 9384287..83f9555 100644 --- a/zia-client/src/cfg.rs +++ b/zia-client/src/cfg.rs @@ -12,4 +12,13 @@ pub(crate) struct ClientCfg { pub(crate) upstream: Url, #[arg(short, long, env = "ZIA_PROXY")] pub(crate) proxy: Option<Url>, + #[arg(short, long, env = "ZIA_COUNT")] + pub(crate) count: usize, + #[arg( + short = 'm', + long, + env = "ZIA_WEBSOCKET_MASKING", + default_value = "false" + )] + pub(crate) websocket_masking: bool, } diff --git a/zia-client/src/handler.rs b/zia-client/src/handler.rs deleted file mode 100644 index 77d008f..0000000 --- a/zia-client/src/handler.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::Arc; - -use tokio::net::UdpSocket; -use tokio::sync::Mutex; -use tracing::{error, info, warn}; - -use crate::upstream::{Upstream, UpstreamSink, UpstreamStream}; - -pub(crate) struct UdpHandler<U: Upstream> { - upstream: U, - known: Arc<Mutex<HashMap<SocketAddr, U::Sink>>>, -} - -impl<U: Upstream + Send> UdpHandler<U> { - pub(crate) fn new(upstream: U) -> Self { - UdpHandler { - upstream, - known: Arc::new(Mutex::new(HashMap::new())), - } - } - - pub(crate) async fn listen(self, socket: Arc<UdpSocket>) -> anyhow::Result<()> - where - <U as Upstream>::Stream: Send + 'static, - <U as Upstream>::Sink: Send + 'static, - { - let mut buf = U::buffer(); - - loop { - let (read, addr) = socket.recv_from(&mut buf[U::BUF_SKIP..]).await?; - - let mut known = self.known.lock().await; - let mut upstream = match known.entry(addr) { - Entry::Occupied(occupied) => occupied, - Entry::Vacant(vacant) => { - info!("New socket at {}/udp, opening upstream connection...", addr); - - let (sink, mut stream) = match self.upstream.open().await { - Ok(conn) => conn, - Err(err) => { - warn!("Error while opening upstream connection: {err}"); - continue; - } - }; - - let known = self.known.clone(); - let socket = socket.clone(); - - tokio::spawn(async move { - if let Err(err) = stream.connect(&socket, addr).await { - warn!( - "Unable to read from upstream or write to udp socket: {}", - err - ); - } - - if let Some(mut sink) = known.lock().await.remove(&addr) { - if let Err(err) = sink.close().await { - error!("Unable to close upstream sink, closing...: {}", err); - } - } - }); - - vacant.insert_entry(sink) - } - }; - - let sink = upstream.get_mut(); - if let Err(err) = sink.write(&mut buf[..U::BUF_SKIP + read]).await { - warn!("Unable to write to upstream, closing...: {}", err); - - if let Err(err) = upstream.remove().close().await { - error!("Unable to close upstream sink: {}", err); - } - }; - } - } -} diff --git a/zia-client/src/main.rs b/zia-client/src/main.rs index 18c03cf..953c386 100644 --- a/zia-client/src/main.rs +++ b/zia-client/src/main.rs @@ -1,19 +1,22 @@ -#![feature(entry_insert)] - use std::net::SocketAddr; -use clap::Parser; +use std::sync::Arc; +use clap::Parser; use tokio::net::UdpSocket; use tokio::select; use tokio::signal::ctrl_c; +use tokio::sync::RwLock; +use tokio::task::JoinSet; use tracing::info; use url::Url; +use zia_common::{ReadPool, WritePool}; + +use crate::app::open_connection; use crate::cfg::ClientCfg; +mod app; mod cfg; -mod handler; -mod upstream; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -22,13 +25,13 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); select! { - result = tokio::spawn(listen(config.listen_addr, config.upstream, config.proxy)) => { + result = tokio::spawn(listen(config.listen_addr, config.upstream, config.proxy, config.count, config.websocket_masking)) => { result??; info!("Socket closed, quitting..."); }, result = shutdown_signal() => { result?; - info!("Termination signal received, quitting...") + info!("Termination signal received, quitting..."); } } @@ -62,19 +65,52 @@ async fn shutdown_signal() -> anyhow::Result<()> { } } -async fn listen(addr: SocketAddr, upstream: Url, proxy: Option<Url>) -> anyhow::Result<()> { - let inbound = UdpSocket::bind(addr).await?; - info!("Listening on {}/udp", inbound.local_addr()?); +async fn listen( + addr: SocketAddr, + upstream: Url, + proxy: Option<Url>, + connection_count: usize, + websocket_masking: bool, +) -> anyhow::Result<()> { + let socket = Arc::new(UdpSocket::bind(addr).await?); + + let upstream = Arc::new(upstream); + let proxy = Arc::new(proxy); + + let mut conns = JoinSet::new(); + for _ in 0..connection_count { + let upstream = upstream.clone(); + let proxy = proxy.clone(); + conns.spawn(async move { open_connection(&upstream, &proxy, websocket_masking).await }); + } + + let addr = Arc::new(RwLock::new(Option::None)); + + let write_pool = WritePool::new(socket.clone(), addr.clone()); + let read_pool = ReadPool::new(socket, addr); - if let Some(proxy) = &proxy { - info!("Using upstream at {} via proxy {}...", upstream, proxy); - } else { - info!("Using upstream at {}...", upstream); + while let Some(connection) = conns.join_next().await.transpose()? { + let (read, write) = connection?; + read_pool.push(read).await; + write_pool.push(write).await; } - upstream::transmit(inbound, &upstream, &proxy).await?; + info!("Connected to upstream"); - info!("Transmission via {} closed", upstream); + let write_handle = tokio::spawn(async move { + loop { + write_pool.execute().await?; + } + }); - Ok(()) + select! { + result = write_handle => { + info!("Write pool finished"); + result? + }, + result = read_pool.join() => { + info!("Read pool finished"); + result + }, + } } diff --git a/zia-client/src/upstream/mod.rs b/zia-client/src/upstream/mod.rs deleted file mode 100644 index d24d145..0000000 --- a/zia-client/src/upstream/mod.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::net::SocketAddr; - -use anyhow::anyhow; -use async_trait::async_trait; -use tokio::net::UdpSocket; -use url::Url; - -mod tcp; -mod ws; - -pub(crate) async fn transmit( - socket: UdpSocket, - upstream: &Url, - proxy: &Option<Url>, -) -> anyhow::Result<()> { - match upstream.scheme().to_lowercase().as_str() { - "tcp" | "tcps" => tcp::transmit(socket, upstream, proxy).await, - "ws" | "wss" => ws::transmit(socket, upstream, proxy).await, - _ => Err(anyhow!("Unsupported upstream scheme {}", upstream.scheme())), - } -} - -#[async_trait] -pub(crate) trait Upstream { - type Sink: UpstreamSink; - type Stream: UpstreamStream; - - const BUF_SIZE: usize; - const BUF_SKIP: usize; - - fn buffer() -> Box<[u8]>; - async fn open(&self) -> anyhow::Result<(Self::Sink, Self::Stream)>; -} - -#[async_trait] -pub(crate) trait UpstreamSink { - async fn write(&mut self, buf: &mut [u8]) -> anyhow::Result<()>; - async fn close(&mut self) -> anyhow::Result<()>; -} - -#[async_trait] -pub(crate) trait UpstreamStream { - async fn connect(&mut self, socket: &UdpSocket, addr: SocketAddr) -> anyhow::Result<()>; -} diff --git a/zia-client/src/upstream/tcp.rs b/zia-client/src/upstream/tcp.rs deleted file mode 100644 index 5e03dc8..0000000 --- a/zia-client/src/upstream/tcp.rs +++ /dev/null @@ -1,151 +0,0 @@ -use std::mem; -use std::net::SocketAddr; -use std::sync::Arc; - -use anyhow::anyhow; -use async_trait::async_trait; -use tokio::io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}; -use tokio::net::{TcpStream, UdpSocket}; -use url::Url; - -use zia_common::Stream; - -use crate::handler::UdpHandler; -use crate::upstream::{Upstream, UpstreamSink, UpstreamStream}; - -pub(crate) async fn transmit( - socket: UdpSocket, - upstream: &Url, - proxy: &Option<Url>, -) -> anyhow::Result<()> { - let upstream = TcpUpstream { - url: upstream.clone(), - proxy: proxy.clone(), - }; - - let handler = UdpHandler::new(upstream); - handler.listen(Arc::new(socket)).await -} - -struct TcpUpstream { - url: Url, - proxy: Option<Url>, -} - -struct TcpUpstreamSink { - inner: WriteHalf<Stream<TcpStream>>, -} - -struct TcpUpstreamStream { - inner: ReadHalf<Stream<TcpStream>>, -} - -#[async_trait] -impl Upstream for TcpUpstream { - type Sink = TcpUpstreamSink; - type Stream = TcpUpstreamStream; - - /// A UDP datagram header has a 16 bit field containing an unsigned integer - /// describing the length of the datagram (including the header itself). - /// The max value is 2^16 = 65536 bytes. But since that includes the - /// UDP header, this constant is 8 bytes more than any UDP socket - /// read operation would ever return. We are going to use that extra space. - const BUF_SIZE: usize = u16::MAX as usize; - const BUF_SKIP: usize = mem::size_of::<u16>(); - - fn buffer() -> Box<[u8]> { - Box::new([0_u8; Self::BUF_SIZE]) - } - - async fn open(&self) -> anyhow::Result<(Self::Sink, Self::Stream)> { - let stream = Stream::connect(&self.url, &self.proxy).await?; - let (stream, sink) = split(stream); - - let sink = Self::Sink { inner: sink }; - let stream = Self::Stream { inner: stream }; - - Ok((sink, stream)) - } -} - -#[async_trait] -impl UpstreamSink for TcpUpstreamSink { - async fn write(&mut self, buf: &mut [u8]) -> anyhow::Result<()> { - let len = (buf.len() - TcpUpstream::BUF_SKIP) as u16; - - buf[0..TcpUpstream::BUF_SKIP].copy_from_slice(&len.to_le_bytes()); - - self.inner.write_all(buf).await?; - - Ok(()) - } - - async fn close(&mut self) -> anyhow::Result<()> { - Ok(self.inner.shutdown().await?) - } -} - -#[async_trait] -impl UpstreamStream for TcpUpstreamStream { - async fn connect(&mut self, socket: &UdpSocket, addr: SocketAddr) -> anyhow::Result<()> { - let mut buf = TcpUpstream::buffer(); - let mut unprocessed = 0; - - loop { - let read = self.inner.read(&mut buf[unprocessed..]).await?; - - if read == 0 { - return Err(anyhow!("End of tcp upstream stream.")); - } - - unprocessed += read; - - let processed = forward(socket, addr, &buf[..unprocessed]).await?; - - // discard processed bytes - if unprocessed > processed { - buf.copy_within(processed..unprocessed, 0); - } - - unprocessed -= processed; - } - } -} - -async fn forward(socket: &UdpSocket, addr: SocketAddr, buf: &[u8]) -> anyhow::Result<usize> { - let mut start = 0; - - loop { - let body = start + TcpUpstream::BUF_SKIP; - - let len: [u8; TcpUpstream::BUF_SKIP] = match buf.get(start..body) { - Some(header) => header.try_into().unwrap(), - // not enough bytes for a complete header - None => return Ok(start), - }; - - let len = u16::from_le_bytes(len) as usize; - let end = body + len; - - let data = match buf.get(body..end) { - Some(data) => data, - // not enough bytes for a complete dataframe - None => return Ok(start), - }; - - let written = socket.send_to(data, addr).await?; - assert_eq!(len, written, "Did not send entire UDP datagram"); - - start = end; - } -} - -// Creates and returns a buffer on the heap with enough space to contain any possible -// UDP datagram. -// -// This is put on the heap and in a separate function to avoid the 64k buffer from ending -// up on the stack and blowing up the size of the futures using it. -// #[inline(never)] -// fn datagram_buffer<U: Upstream>() -> Box<[u8; U::BUF_SIZE]> { -// Box::new([0u8; U::BUF_SIZE]) -// } diff --git a/zia-client/src/upstream/ws.rs b/zia-client/src/upstream/ws.rs deleted file mode 100644 index bc52829..0000000 --- a/zia-client/src/upstream/ws.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::mem; -use std::net::SocketAddr; -use std::sync::Arc; - -use async_trait::async_trait; -use futures_util::stream::{SplitSink, SplitStream}; -use futures_util::{SinkExt, StreamExt}; -use tokio::net::{TcpStream, UdpSocket}; -use tokio_tungstenite::tungstenite::Message; -use tokio_tungstenite::{client_async, WebSocketStream}; -use url::Url; - -use zia_common::Stream; - -use crate::handler::UdpHandler; -use crate::upstream::{Upstream, UpstreamSink, UpstreamStream}; - -pub(crate) async fn transmit( - socket: UdpSocket, - upstream: &Url, - proxy: &Option<Url>, -) -> anyhow::Result<()> { - let upstream = WsUpstream { - url: upstream.clone(), - proxy: proxy.clone(), - }; - - let handler = UdpHandler::new(upstream); - handler.listen(Arc::new(socket)).await -} - -struct WsUpstream { - url: Url, - proxy: Option<Url>, -} - -struct WsUpstreamSink { - inner: SplitSink<WebSocketStream<Stream<TcpStream>>, Message>, -} - -struct WsUpstreamStream { - inner: SplitStream<WebSocketStream<Stream<TcpStream>>>, -} - -#[async_trait] -impl Upstream for WsUpstream { - type Sink = WsUpstreamSink; - type Stream = WsUpstreamStream; - - /// A UDP datagram header has a 16 bit field containing an unsigned integer - /// describing the length of the datagram (including the header itself). - /// The max value is 2^16 = 65536 bytes. But since that includes the - /// UDP header, this constant is 8 bytes more than any UDP socket - /// read operation would ever return. We are going to save that extra space. - const BUF_SIZE: usize = u16::MAX as usize - mem::size_of::<u16>(); - const BUF_SKIP: usize = 0; - - fn buffer() -> Box<[u8]> { - Box::new([0_u8; Self::BUF_SIZE]) - } - - async fn open(&self) -> anyhow::Result<(Self::Sink, Self::Stream)> { - let stream = Stream::connect(&self.url, &self.proxy).await?; - let (stream, _) = client_async(&self.url, stream).await?; - let (sink, stream) = stream.split(); - - let sink = Self::Sink { inner: sink }; - let stream = Self::Stream { inner: stream }; - - Ok((sink, stream)) - } -} - -#[async_trait] -impl UpstreamSink for WsUpstreamSink { - async fn write(&mut self, buf: &mut [u8]) -> anyhow::Result<()> { - Ok(self.inner.send(Message::Binary(buf.to_vec())).await?) - } - - async fn close(&mut self) -> anyhow::Result<()> { - Ok(self.inner.close().await?) - } -} - -#[async_trait] -impl UpstreamStream for WsUpstreamStream { - async fn connect(&mut self, socket: &UdpSocket, addr: SocketAddr) -> anyhow::Result<()> { - loop { - match self.inner.next().await { - Some(Ok(message)) => { - socket.send_to(&message.into_data(), addr).await?; - } - Some(err @ Err(_)) => { - err?; - } - None => {} - } - } - } -} diff --git a/zia-common/Cargo.toml b/zia-common/Cargo.toml index 50d7adb..8f3d4d3 100644 --- a/zia-common/Cargo.toml +++ b/zia-common/Cargo.toml @@ -2,16 +2,13 @@ name = "zia-common" version = "0.0.0-git" edition = "2021" +authors = ["Marcel <https://m4rc3l.de>"] +license = "AGPL-3.0" +description = "Proxy udp over websocket, useful to use Wireguard in restricted networks." [dependencies] -async-http-proxy = { version = "1.2", features = ["runtime-tokio", "basic-auth"] } -tokio = { version = "1.32", default-features = false, features = ["net", "time"] } -tokio-tungstenite = { version = "0.20", default-features = false } -futures-util = { version = "0.3", default-features = false } -tokio-rustls = "0.24" -webpki-roots = "0.25" -pin-project = "1.1" -once_cell = "1.18" +tokio = { version = "1.34", default-features = false, features = ["net", "sync"] } +hyper = { version = "0.14", default-features = false, features = [] } +wsocket = "0.1" tracing = "0.1" anyhow = "1.0" -url = "2.4" diff --git a/zia-common/src/forwarding/mod.rs b/zia-common/src/forwarding/mod.rs deleted file mode 100644 index aa409fc..0000000 --- a/zia-common/src/forwarding/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub use crate::forwarding::tcp::*; -pub use crate::forwarding::ws::*; - -mod tcp; -mod ws; diff --git a/zia-common/src/forwarding/tcp.rs b/zia-common/src/forwarding/tcp.rs deleted file mode 100644 index 05c8f87..0000000 --- a/zia-common/src/forwarding/tcp.rs +++ /dev/null @@ -1,153 +0,0 @@ -// see: https://github.com/mullvad/udp-over-tcp/blob/main/src/forward_traffic.rs - -use std::convert::{Infallible, TryFrom}; -use std::mem; -use std::sync::Arc; - -use anyhow::Context; -use futures_util::future::select; -use futures_util::pin_mut; -use tokio::io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}; -use tokio::net::{TcpStream, UdpSocket}; -use tracing::error; - -use crate::{Stream}; - -/// A UDP datagram header has a 16 bit field containing an unsigned integer -/// describing the length of the datagram (including the header itself). -/// The max value is 2^16 = 65536 bytes. But since that includes the -/// UDP header, this constant is 8 bytes more than any UDP socket -/// read operation would ever return. We are going to use that extra space -/// to store our 2 byte udp-over-tcp header. -const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize; -const HEADER_LEN: usize = mem::size_of::<u16>(); - -/// Forward traffic between the given UDP and TCP sockets in both directions. -/// This async function runs until one of the sockets are closed or there is an error. -/// Both sockets are closed before returning. -pub async fn process_udp_over_tcp( - udp_socket: UdpSocket, - tcp_stream: Stream<TcpStream>, -) { - let udp_in = Arc::new(udp_socket); - let udp_out = udp_in.clone(); - let (tcp_in, tcp_out) = split(tcp_stream); - - let tcp2udp = tokio::spawn(async move { - if let Err(error) = process_tcp2udp(tcp_in, udp_out).await { - error!("Error: {}", error); - } - }); - - let udp2tcp = tokio::spawn(async move { - if let Err(error) = process_udp2tcp(udp_in, tcp_out).await { - error!("Error: {}", error); - } - }); - - pin_mut!(tcp2udp); - pin_mut!(udp2tcp); - - // Wait until the UDP->TCP or TCP->UDP future terminates. - select(tcp2udp, udp2tcp).await; -} - -/// Reads from `tcp_in` and extracts UDP datagrams. Writes the datagrams to `udp_out`. -/// Returns if the TCP socket is closed, or an IO error happens on either socket. -async fn process_tcp2udp( - mut tcp_in: ReadHalf<Stream<TcpStream>>, - udp_out: Arc<UdpSocket> -) -> anyhow::Result<()> { - let mut buffer = datagram_buffer(); - // `buffer` has unprocessed data from the TCP socket up until this index. - let mut unprocessed_i = 0; - loop { - let tcp_read_len = tcp_in.read(&mut buffer[unprocessed_i..]) - .await - .context("Failed reading from TCP")?; - if tcp_read_len == 0 { - break; - } - unprocessed_i += tcp_read_len; - - let processed_i = forward_datagrams_in_buffer(&udp_out, &buffer[..unprocessed_i]) - .await - .context("Failed writing to UDP")?; - - // If we have read data that was not forwarded, because it was not a complete datagram, - // move it to the start of the buffer and start over - if unprocessed_i > processed_i { - buffer.copy_within(processed_i..unprocessed_i, 0); - } - unprocessed_i -= processed_i; - } - - Ok(()) -} - -/// Forward all complete datagrams in `buffer` to `udp_out`. -/// Returns the number of processed bytes. -async fn forward_datagrams_in_buffer(udp_out: &UdpSocket, buffer: &[u8]) -> anyhow::Result<usize> { - let mut header_start = 0; - loop { - let header_end = header_start + HEADER_LEN; - // "parse" the header - let header = match buffer.get(header_start..header_end) { - Some(header) => <[u8; HEADER_LEN]>::try_from(header).unwrap(), - // Buffer does not contain entire header for next datagram - None => break Ok(header_start), - }; - let datagram_len = usize::from(u16::from_le_bytes(header)); - let datagram_start = header_end; - let datagram_end = datagram_start + datagram_len; - - let datagram_data = match buffer.get(datagram_start..datagram_end) { - Some(datagram_data) => datagram_data, - // The buffer does not contain the entire datagram - None => break Ok(header_start), - }; - - let udp_write_len = udp_out.send(datagram_data).await?; - assert_eq!( - udp_write_len, datagram_len, - "Did not send entire UDP datagram" - ); - - header_start = datagram_end; - } -} - -/// Reads datagrams from `udp_in` and writes them (with the 16 bit header containing the length) -/// to `tcp_out` indefinitely, or until an IO error happens on either socket. -async fn process_udp2tcp( - udp_in: Arc<UdpSocket>, - mut tcp_out: WriteHalf<Stream<TcpStream>>, -) -> anyhow::Result<Infallible> { - // A buffer large enough to hold any possible UDP datagram plus its 16 bit length header. - let mut buffer = datagram_buffer(); - loop { - let udp_read_len = udp_in - .recv(&mut buffer[HEADER_LEN..]) - .await - .context("Failed reading from UDP")?; - - // Set the "header" to the length of the datagram. - let datagram_len = u16::try_from(udp_read_len).expect("UDP datagram can't be larger than 2^16"); - buffer[..HEADER_LEN].copy_from_slice(&datagram_len.to_le_bytes()[..]); - - tcp_out - .write_all(&buffer[..HEADER_LEN + udp_read_len]) - .await - .context("Failed writing to TCP")?; - } -} - -/// Creates and returns a buffer on the heap with enough space to contain any possible -/// UDP datagram. -/// -/// This is put on the heap and in a separate function to avoid the 64k buffer from ending -/// up on the stack and blowing up the size of the futures using it. -#[inline(never)] -fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { - Box::new([0u8; MAX_DATAGRAM_SIZE]) -} diff --git a/zia-common/src/forwarding/ws.rs b/zia-common/src/forwarding/ws.rs deleted file mode 100644 index b556360..0000000 --- a/zia-common/src/forwarding/ws.rs +++ /dev/null @@ -1,137 +0,0 @@ -// see: https://github.com/mullvad/udp-over-tcp/blob/main/src/forward_traffic.rs - -use std::convert::Infallible; -use std::mem; -use std::sync::Arc; - -use anyhow::Context; -use futures_util::future::select; -use futures_util::pin_mut; -use futures_util::stream::{SplitSink, SplitStream}; -use futures_util::{SinkExt, StreamExt}; -use tokio::net::{TcpStream, UdpSocket}; -use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal; -use tokio_tungstenite::tungstenite::protocol::CloseFrame; -use tokio_tungstenite::tungstenite::Message; -use tokio_tungstenite::WebSocketStream; -use tracing::error; - -use crate::{Stream}; - -/// A UDP datagram header has a 16 bit field containing an unsigned integer -/// describing the length of the datagram (including the header itself). -/// The max value is 2^16 = 65536 bytes. But since that includes the -/// UDP header, this constant is 8 bytes more than any UDP socket -/// read operation would ever return. We are going to save that extra space. -const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize - mem::size_of::<u16>(); - -/// Forward traffic between the given UDP and WS sockets in both directions. -/// This async function runs until one of the sockets are closed or there is an error. -/// Both sockets are closed before returning. -pub async fn process_udp_over_ws( - udp_socket: UdpSocket, - ws_stream: WebSocketStream<Stream<TcpStream>>, -) { - let (mut ws_out, mut ws_in) = ws_stream.split(); - - { - let udp_in = Arc::new(udp_socket); - let udp_out = udp_in.clone(); - - let ws2udp = async { - if let Err(error) = process_ws2udp(&mut ws_in, udp_out).await { - error!("Error: {}", error); - } - }; - - let udp2ws = async { - if let Err(error) = process_udp2ws(udp_in, &mut ws_out).await { - error!("Error: {}", error); - } - }; - - pin_mut!(ws2udp); - pin_mut!(udp2ws); - - // Wait until the UDP->WS or WS->UDP future terminates. - select(ws2udp, udp2ws).await; - } - - if let Err(err) = ws_in - .reunite(ws_out) - .unwrap() - .close(Some(CloseFrame { - code: Normal, - reason: "Timeout".into(), - })) - .await - { - error!("Unable to close ws stream: {}", err); - } -} - -/// Reads from `ws_in` and extracts UDP datagrams. Writes the datagrams to `udp_out`. -/// Returns if the WS socket is closed, or an IO error happens on either socket. -async fn process_ws2udp( - ws_in: &mut SplitStream<WebSocketStream<Stream<TcpStream>>>, - udp_out: Arc<UdpSocket>, -) -> anyhow::Result<()> { - while let Some(message) = ws_in.next() - .await - .transpose() - .context("Failed reading from WS")? - { - let data = message.into_data(); - - forward_datagrams_in_buffer(&udp_out, &data) - .await - .context("Failed writing to UDP")?; - } - - Ok(()) -} - -/// Forward the datagram in `buffer` to `udp_out`. -/// Returns the number of processed bytes. -async fn forward_datagrams_in_buffer(udp_out: &UdpSocket, buffer: &Vec<u8>) -> anyhow::Result<()> { - let udp_write_len = udp_out.send(buffer).await?; - assert_eq!( - udp_write_len, - buffer.len(), - "Did not send entire UDP datagram" - ); - - Ok(()) -} - -/// Reads datagrams from `udp_in` and writes them (with the 16 bit header containing the length) -/// to `ws_out` indefinitely, or until an IO error happens on either socket. -async fn process_udp2ws( - udp_in: Arc<UdpSocket>, - ws_out: &mut SplitSink<WebSocketStream<Stream<TcpStream>>, Message>, -) -> anyhow::Result<Infallible> { - // A buffer large enough to hold any possible UDP datagram plus its 16 bit length header. - let mut buffer = datagram_buffer(); - - loop { - let udp_read_len = udp_in - .recv(&mut buffer[..]) - .await - .context("Failed reading from UDP")?; - - ws_out - .send(Message::Binary(buffer[..udp_read_len].to_vec())) - .await - .context("Failed writing to WS")?; - } -} - -/// Creates and returns a buffer on the heap with enough space to contain any possible -/// UDP datagram. -/// -/// This is put on the heap and in a separate function to avoid the 64k buffer from ending -/// up on the stack and blowing up the size of the futures using it. -#[inline(never)] -fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { - Box::new([0u8; MAX_DATAGRAM_SIZE]) -} diff --git a/zia-common/src/lib.rs b/zia-common/src/lib.rs index 6619fa8..c064096 100644 --- a/zia-common/src/lib.rs +++ b/zia-common/src/lib.rs @@ -1,5 +1,20 @@ -pub use crate::forwarding::*; -pub use crate::stream::*; +use std::mem; -mod forwarding; -mod stream; +pub use read::*; +pub use write::*; + +mod pool; +mod read; +mod write; + +pub const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize - mem::size_of::<u16>(); + +/// Creates and returns a buffer on the heap with enough space to contain any possible +/// UDP datagram. +/// +/// This is put on the heap and in a separate function to avoid the 64k buffer from ending +/// up on the stack and blowing up the size of the futures using it. +#[inline(never)] +pub fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { + Box::new([0u8; MAX_DATAGRAM_SIZE]) +} diff --git a/zia-common/src/pool.rs b/zia-common/src/pool.rs new file mode 100644 index 0000000..19ac6c0 --- /dev/null +++ b/zia-common/src/pool.rs @@ -0,0 +1,82 @@ +use std::ops::{Deref, DerefMut}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use tokio::sync::{mpsc, Mutex}; + +pub trait PoolEntry { + fn is_closed(&self) -> bool; +} + +pub struct PoolGuard<T: PoolEntry> { + inner: Option<T>, + back: mpsc::UnboundedSender<T>, + pool_size: Arc<AtomicUsize>, +} + +unsafe impl<T: Send + PoolEntry> Send for PoolGuard<T> {} + +impl<T: PoolEntry> Drop for PoolGuard<T> { + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + if inner.is_closed() { + self.pool_size.fetch_sub(1, Ordering::Relaxed); + } else if let Err(err) = self.back.send(inner) { + panic!("Could not put PoolGuard back to pool: {:?}", err); + } + } + } +} + +impl<T: PoolEntry> Deref for PoolGuard<T> { + type Target = T; + fn deref(&self) -> &Self::Target { + self.inner.as_ref().unwrap() + } +} + +impl<T: PoolEntry> DerefMut for PoolGuard<T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.inner.as_mut().unwrap() + } +} + +pub struct Pool<T: PoolEntry> { + size: Arc<AtomicUsize>, + tx: mpsc::UnboundedSender<T>, + rx: Mutex<mpsc::UnboundedReceiver<T>>, +} + +impl<T: PoolEntry> Pool<T> { + pub fn new() -> Self { + let (tx, rx) = mpsc::unbounded_channel(); + Self { + size: Arc::new(AtomicUsize::new(0)), + tx, + rx: Mutex::new(rx), + } + } + + pub async fn acquire(&self) -> Option<PoolGuard<T>> { + if self.size.load(Ordering::Relaxed) == 0 { + return None; + } + + let inner = self.rx.lock().await.recv().await.unwrap(); + + Some(PoolGuard { + inner: Some(inner), + back: self.tx.clone(), + pool_size: self.size.clone(), + }) + } +} + +impl<T: PoolEntry> Pool<T> { + pub fn push(&self, inner: T) { + self.size.fetch_add(1, Ordering::Relaxed); + if let Err(err) = self.tx.send(inner) { + panic!("Could not put Inner into to pool: {:?}", err); + } + } +} diff --git a/zia-common/src/read.rs b/zia-common/src/read.rs new file mode 100644 index 0000000..8371b08 --- /dev/null +++ b/zia-common/src/read.rs @@ -0,0 +1,98 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use tokio::io::{AsyncRead, ReadHalf}; +use tokio::net::UdpSocket; +use tokio::select; +use tokio::sync::{Mutex, RwLock}; +use tokio::task::{JoinError, JoinSet}; +use tracing::{error, warn}; +use wsocket::{Message, WebSocket}; + +use crate::datagram_buffer; + +pub struct ReadConnection<R> { + read: WebSocket<ReadHalf<R>>, +} + +impl<R: AsyncRead> ReadConnection<R> { + pub fn new(read: WebSocket<ReadHalf<R>>) -> Self { + Self { read } + } + + async fn handle_frame( + &mut self, + socket: &UdpSocket, + addr: &RwLock<Option<SocketAddr>>, + buf: &mut [u8], + ) -> anyhow::Result<()> { + let message = self.read.recv(buf).await?; + + match message { + Message::Binary(data) => { + let addr = addr.read().await.unwrap(); + socket.send_to(data, addr).await?; + } + _ => unimplemented!(), + } + + Ok(()) + } +} + +pub struct ReadPool { + socket: Arc<UdpSocket>, + addr: Arc<RwLock<Option<SocketAddr>>>, + tasks: Mutex<JoinSet<anyhow::Result<()>>>, +} + +impl ReadPool { + pub fn new(socket: Arc<UdpSocket>, addr: Arc<RwLock<Option<SocketAddr>>>) -> Self { + Self { + socket, + addr, + tasks: Mutex::new(JoinSet::new()), + } + } + + async fn wait_for_connections_to_close(&self) -> Option<Result<anyhow::Result<()>, JoinError>> { + let mut set = self.tasks.lock().await; + select! { + result = set.join_next() => result, + _result = tokio::time::sleep(Duration::from_millis(200)) => Some(Ok(Ok(()))), + } + } + + pub async fn join(&self) -> anyhow::Result<()> { + // hack + loop { + while let Some(result) = self.wait_for_connections_to_close().await { + if let Err(err) = result? { + error!("Error while handling websocket frame: {}", err); + // TODO: close and remove from write pool + } + } + + // hack + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + + pub async fn push<R: AsyncRead + Send + 'static>(&self, mut conn: ReadConnection<R>) { + let socket = self.socket.clone(); + let addr = self.addr.clone(); + + self.tasks.lock().await.spawn(async move { + let mut buf = datagram_buffer(); + loop { + if conn.read.is_closed() { + warn!("Read connection closed"); + // TODO: open new connection on client + return Ok(()); + } + conn.handle_frame(&socket, &addr, buf.as_mut()).await?; + } + }); + } +} diff --git a/zia-common/src/stream.rs b/zia-common/src/stream.rs deleted file mode 100644 index deefca1..0000000 --- a/zia-common/src/stream.rs +++ /dev/null @@ -1,173 +0,0 @@ -use std::io::{Error, IoSlice}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use anyhow::anyhow; -use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth}; -use once_cell::sync::Lazy; -use pin_project::pin_project; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio::net::TcpStream; -use tokio_rustls::client::TlsStream; -use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; -use tokio_rustls::TlsConnector; -use url::Url; - -static TLS_CONNECTOR: Lazy<TlsConnector> = Lazy::new(|| { - let mut store = RootCertStore::empty(); - store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints) - })); - - let config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(store) - .with_no_client_auth(); - - TlsConnector::from(Arc::new(config)) -}); - -#[pin_project(project = EnumProj)] -pub enum Stream<IO> { - Plain(#[pin] IO), - Tls(#[pin] TlsStream<IO>), - TlsOverTls(#[pin] TlsStream<TlsStream<IO>>), -} - -impl Stream<TcpStream> { - pub async fn connect(upstream: &Url, proxy: &Option<Url>) -> anyhow::Result<Self> { - let upstream_host = upstream - .host_str() - .ok_or_else(|| anyhow!("Upstream url is missing host"))?; - let upstream_port = upstream - .port_or_known_default() - .ok_or_else(|| anyhow!("Upstream url is missing port"))?; - - let mut stream = match proxy { - None => { - let stream = TcpStream::connect((upstream_host, upstream_port)).await?; - stream.set_nodelay(true)?; - - Self::Plain(stream) - } - Some(proxy) => { - let proxy_host = proxy - .host_str() - .ok_or_else(|| anyhow!("Proxy url is missing host"))?; - let proxy_port = proxy - .port_or_known_default() - .ok_or_else(|| anyhow!("Proxy url is missing port"))?; - - let stream = TcpStream::connect((proxy_host, proxy_port)).await?; - stream.set_nodelay(true)?; - - let mut stream = Self::Plain(stream); - - if proxy.scheme() == "https" { - stream = stream.upgrade_to_tls(proxy_host).await?; - }; - - match proxy.password() { - Some(password) => { - http_connect_tokio_with_basic_auth( - &mut stream, - upstream_host, - upstream_port, - proxy.username(), - password, - ) - .await? - } - None => http_connect_tokio(&mut stream, upstream_host, upstream_port).await?, - }; - - stream - } - }; - - if upstream.scheme() == "wss" || upstream.scheme() == "tcps" { - stream = stream.upgrade_to_tls(upstream_host).await?; - } - - Ok(stream) - } -} - -impl<IO: AsyncRead + AsyncWrite + Unpin> Stream<IO> { - pub async fn upgrade_to_tls(self, host: &str) -> anyhow::Result<Self> { - let domain = ServerName::try_from(host)?; - - let stream = match self { - Self::Plain(stream) => Self::Tls(TLS_CONNECTOR.connect(domain, stream).await?), - Self::Tls(stream) => Self::TlsOverTls(TLS_CONNECTOR.connect(domain, stream).await?), - Self::TlsOverTls(_) => unimplemented!(), - }; - - Ok(stream) - } -} - -impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<IO> { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll<std::io::Result<()>> { - match self.project() { - EnumProj::Plain(stream) => stream.poll_read(cx, buf), - EnumProj::Tls(stream) => stream.poll_read(cx, buf), - EnumProj::TlsOverTls(stream) => stream.poll_read(cx, buf), - } - } -} - -impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<IO> { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll<Result<usize, Error>> { - match self.project() { - EnumProj::Plain(stream) => stream.poll_write(cx, buf), - EnumProj::Tls(stream) => stream.poll_write(cx, buf), - EnumProj::TlsOverTls(stream) => stream.poll_write(cx, buf), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { - match self.project() { - EnumProj::Plain(stream) => stream.poll_flush(cx), - EnumProj::Tls(stream) => stream.poll_flush(cx), - EnumProj::TlsOverTls(stream) => stream.poll_flush(cx), - } - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { - match self.project() { - EnumProj::Plain(stream) => stream.poll_shutdown(cx), - EnumProj::Tls(stream) => stream.poll_shutdown(cx), - EnumProj::TlsOverTls(stream) => stream.poll_shutdown(cx), - } - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll<Result<usize, Error>> { - match self.project() { - EnumProj::Plain(stream) => stream.poll_write_vectored(cx, bufs), - EnumProj::Tls(stream) => stream.poll_write_vectored(cx, bufs), - EnumProj::TlsOverTls(stream) => stream.poll_write_vectored(cx, bufs), - } - } - - fn is_write_vectored(&self) -> bool { - match self { - Self::Plain(stream) => stream.is_write_vectored(), - Self::Tls(stream) => stream.is_write_vectored(), - Self::TlsOverTls(stream) => stream.is_write_vectored(), - } - } -} diff --git a/zia-common/src/write.rs b/zia-common/src/write.rs new file mode 100644 index 0000000..ccc2744 --- /dev/null +++ b/zia-common/src/write.rs @@ -0,0 +1,114 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use tokio::io::{AsyncWrite, WriteHalf}; +use tokio::net::UdpSocket; +use tokio::sync::RwLock; +use tracing::{error, warn}; +use wsocket::{Message, WebSocket}; + +use crate::pool::{Pool, PoolEntry}; +use crate::{datagram_buffer, MAX_DATAGRAM_SIZE}; + +pub struct WriteConnection<W> { + write: WebSocket<WriteHalf<W>>, + buf: Box<[u8; MAX_DATAGRAM_SIZE]>, +} + +impl<W: AsyncWrite> WriteConnection<W> { + pub fn new(write: WebSocket<WriteHalf<W>>) -> Self { + Self { + buf: datagram_buffer(), + write, + } + } + + async fn flush(&mut self, size: usize) -> anyhow::Result<()> { + assert!(size <= MAX_DATAGRAM_SIZE); + + let message = Message::Binary(&self.buf[..size]); + self.write.send(message).await?; + + Ok(()) + } +} + +impl<W> PoolEntry for WriteConnection<W> { + fn is_closed(&self) -> bool { + self.write.is_closed() + // TODO: open new connection on client - maybe fancy login in "abstract" pool + } +} + +pub struct WritePool<W> { + socket: Arc<UdpSocket>, + pool: Pool<WriteConnection<W>>, + addr: Arc<RwLock<Option<SocketAddr>>>, +} + +impl<W: AsyncWrite + Send + 'static> WritePool<W> { + pub fn new(socket: Arc<UdpSocket>, addr: Arc<RwLock<Option<SocketAddr>>>) -> Self { + Self { + socket, + pool: Pool::new(), + addr, + } + } + + async fn update_addr(&self, addr: SocketAddr) { + let is_outdated = self + .addr + .read() + .await + .map(|last_addr| last_addr != addr) + .unwrap_or(true); + + if is_outdated { + *(self.addr.write().await) = Some(addr); + } + } + + pub async fn push(&self, conn: WriteConnection<W>) { + self.pool.push(conn); + } + + pub async fn execute(&self) -> anyhow::Result<()> { + loop { + let conn = self.pool.acquire().await; + + // TODO: + // maybe just block until it is not empty anymore + // .revc() in self.pool.acquire() would be "blocking" (asynchronous) + // until a connection becomes available, therefore + // this would be appropriate + // - better would be an option to enable the blocking + // only on the server and on the client a `None` + // returned would open a new connection + let mut conn = match conn { + Some(conn) => conn, + None => { + warn!("Write pool is empty, waiting 1s"); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + }; + + if conn.is_closed() { + continue; + } + + // read from udp socket and save to buf of selected conn + let (read, addr) = self.socket.recv_from(conn.buf.as_mut()).await.unwrap(); + + self.update_addr(addr).await; + + // flush buf of conn asynchronously to read again from udp socket in parallel + tokio::spawn(async move { + if let Err(err) = conn.flush(read).await { + error!("Unable to flush websocket buf: {:?}", err); + } + }); + } + } +} diff --git a/zia-server/Cargo.toml b/zia-server/Cargo.toml index 6861e3e..d9f540e 100644 --- a/zia-server/Cargo.toml +++ b/zia-server/Cargo.toml @@ -7,16 +7,20 @@ license = "AGPL-3.0" description = "Proxy udp over websocket, useful to use Wireguard in restricted networks." [dependencies] -tokio = { version = "1.32", default-features = false, features = ["macros", "rt-multi-thread", "net", "signal", "sync"] } -tokio-tungstenite = { version = "0.20", default-features = false, features = ["handshake"] } +tokio = { version = "1.34", default-features = false, features = ["rt-multi-thread", "macros", "net","sync", "time", "signal"] } +fastwebsockets = { git = "https://github.com/MarcelCoding/fastwebsockets", branch = "split", default-features = false, features = ["upgrade"] } +hyper = { version = "0.14", default-features = false, features = ["tcp"] } tracing-subscriber = { version = "0.3", features = ["tracing-log"] } -futures-util = { version = "0.3", default-features = false } clap = { version = "4.4", features = ["derive", "env"] } -zia-common = { path = "../zia-common" } -async-trait = "0.1" +webpki-roots = "0.25" +pin-project = "1.1" +once_cell = "1.18" +wsocket = "0.1" tracing = "0.1" anyhow = "1.0" +zia-common = { path = '../zia-common' } + [package.metadata.deb] maintainer-scripts = "debian/" systemd-units = { enable = false } @@ -27,6 +31,5 @@ assets = [ [package.metadata.generate-rpm] assets = [ - # { source = "target/release/status-node", dest = "/usr/bin/status-node", mode = "0755" }, { source = "../LICENSE", dest = "/usr/share/doc/zia-server/LICENSE", doc = true, mode = "0644" }, ] diff --git a/zia-server/src/cfg.rs b/zia-server/src/cfg.rs index 975e0a6..5f21587 100644 --- a/zia-server/src/cfg.rs +++ b/zia-server/src/cfg.rs @@ -5,26 +5,33 @@ use clap::{Parser, ValueEnum}; #[derive(Parser)] #[clap(version)] -pub(crate) struct ClientCfg { +pub(crate) struct ServerCfg { #[arg(short, long, env = "ZIA_LISTEN_ADDR", default_value = "0.0.0.0:1234")] pub(crate) listen_addr: SocketAddr, #[arg(short, long, env = "ZIA_UPSTREAM")] pub(crate) upstream: String, - #[arg(short, long, env = "ZIA_MODE", default_value = "WS", value_enum, ignore_case(true))] + #[arg( + short, + long, + env = "ZIA_MODE", + default_value = "WS", + value_enum, + ignore_case(true) + )] pub(crate) mode: Mode, } #[derive(ValueEnum, Clone)] pub(crate) enum Mode { Ws, - Tcp, + // Tcp, } impl Display for Mode { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::Ws => f.write_str("ws"), - Self::Tcp => f.write_str("tcp"), + // Self::Tcp => f.write_str("tcp"), } } } diff --git a/zia-server/src/listener/mod.rs b/zia-server/src/listener/mod.rs deleted file mode 100644 index 4fe6811..0000000 --- a/zia-server/src/listener/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -pub(crate) use self::tcp::*; -pub(crate) use self::ws::*; - -mod tcp; -mod ws; - -#[async_trait::async_trait] -pub(crate) trait Listener { - async fn listen(&self, upstream: &str) -> anyhow::Result<()>; -} diff --git a/zia-server/src/listener/tcp.rs b/zia-server/src/listener/tcp.rs deleted file mode 100644 index a14c7b0..0000000 --- a/zia-server/src/listener/tcp.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::net::SocketAddr; - -use tokio::net::{TcpStream, UdpSocket}; -use tracing::{error, info}; - -use zia_common::process_udp_over_tcp; -use zia_common::Stream::Plain; - -use crate::listener::Listener; - -pub(crate) struct TcpListener { - pub(crate) addr: SocketAddr, -} - -#[async_trait::async_trait] -impl Listener for TcpListener { - async fn listen(&self, upstream: &str) -> anyhow::Result<()> { - let listener = tokio::net::TcpListener::bind(self.addr).await?; - - loop { - let (sock, _) = listener.accept().await?; - let upstream = upstream.to_string(); - - tokio::spawn(async move { - if let Err(err) = Self::handle(sock, upstream).await { - error!("Error while handling connection: {:?}", err); - } - }); - } - } -} - -impl TcpListener { - async fn handle(downstream: TcpStream, upstream_addr: String) -> anyhow::Result<()> { - downstream.set_nodelay(true)?; - let downstream_addr = downstream.peer_addr()?; - info!("New downstream connection: {}", downstream_addr); - - let upstream = UdpSocket::bind("0.0.0.0:0").await?; // TODO: maybe make this configurable - - upstream.connect(upstream_addr).await?; - - info!( - "Connected to udp upstream (local: {}/udp, peer: {}/udp) for downstream {}", - upstream.local_addr()?, - upstream.peer_addr()?, - downstream_addr - ); - - process_udp_over_tcp(upstream, Plain(downstream)).await; - - info!("Connection with downstream {} closed...", downstream_addr); - - Ok(()) - } -} diff --git a/zia-server/src/listener/ws.rs b/zia-server/src/listener/ws.rs deleted file mode 100644 index 02a066e..0000000 --- a/zia-server/src/listener/ws.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::net::SocketAddr; - -use tokio::net::{TcpListener, TcpStream, UdpSocket}; -use tracing::{error, info}; -use zia_common::process_udp_over_ws; -use zia_common::Stream::Plain; - -use crate::listener::Listener; - -pub(crate) struct WsListener { - pub(crate) addr: SocketAddr, -} - -#[async_trait::async_trait] -impl Listener for WsListener { - async fn listen(&self, upstream: &str) -> anyhow::Result<()> { - let listener = TcpListener::bind(self.addr).await?; - info!("Listening on ws://{}...", listener.local_addr()?); - - loop { - let (sock, _) = listener.accept().await?; - let upstream = upstream.to_string(); - - tokio::spawn(async move { - if let Err(err) = Self::handle(sock, upstream).await { - error!("Error while handling connection: {:?}", err); - } - }); - } - } -} - -impl WsListener { - async fn handle(downstream: TcpStream, upstream_addr: String) -> anyhow::Result<()> { - downstream.set_nodelay(true)?; - let downstream_addr = downstream.peer_addr()?; - info!("New downstream connection: {}", downstream_addr); - let downstream = tokio_tungstenite::accept_async(Plain(downstream)).await?; - - let upstream = UdpSocket::bind("0.0.0.0:0").await?; // TODO: maybe make this configurable - - upstream.connect(upstream_addr).await?; - - info!( - "Connected to udp upstream (local: {}/udp, peer: {}/udp) for downstream {}", - upstream.local_addr()?, - upstream.peer_addr()?, - downstream_addr - ); - - process_udp_over_ws(upstream, downstream).await; - - info!("Connection with downstream {} closed...", downstream_addr); - - Ok(()) - } -} diff --git a/zia-server/src/main.rs b/zia-server/src/main.rs index 83b4fe1..9b292c6 100644 --- a/zia-server/src/main.rs +++ b/zia-server/src/main.rs @@ -1,39 +1,143 @@ +use std::convert::Infallible; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + use clap::Parser; +use hyper::service::{make_service_fn, Service}; +use hyper::upgrade::Upgraded; +use hyper::{Body, Request, Response, Server, StatusCode}; +use tokio::net::UdpSocket; use tokio::select; use tokio::signal::ctrl_c; +use tokio::sync::RwLock; +use tokio::task::JoinHandle; use tracing::info; +use wsocket::WebSocket; + +use zia_common::{ReadConnection, ReadPool, WriteConnection, WritePool, MAX_DATAGRAM_SIZE}; -use crate::cfg::{ClientCfg, Mode}; -use crate::listener::{Listener, TcpListener, WsListener}; +use crate::cfg::ServerCfg; mod cfg; -mod listener; + +#[pin_project::pin_project] +struct HandleRequestFuture { + req: Request<Body>, + read: Arc<ReadPool>, + write: Arc<WritePool<Upgraded>>, +} + +impl Future for HandleRequestFuture { + type Output = Result<Response<Body>, Infallible>; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + + if !fastwebsockets::upgrade::is_upgrade_request(this.req) { + return Poll::Ready(Ok( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("bad request: expected websocket upgrade")) + .unwrap(), + )); + } + + let (resp, upgrade) = fastwebsockets::upgrade::upgrade(this.req).unwrap(); + + let cloned_read = this.read.clone(); + let cloned_write = this.write.clone(); + + tokio::spawn(async move { + let ws = upgrade.await.unwrap().into_inner(); + + let ws = WebSocket::server(ws, MAX_DATAGRAM_SIZE); + let (read, write) = ws.split(); + + cloned_read.push(ReadConnection::new(read)).await; + cloned_write.push(WriteConnection::new(write)).await; + }); + + Poll::Ready(Ok(resp)) + } +} + +// mod app; +struct ConnectionHandler { + read: Arc<ReadPool>, + write: Arc<WritePool<Upgraded>>, +} + +impl Service<Request<Body>> for ConnectionHandler { + type Response = Response<Body>; + type Error = Infallible; + type Future = HandleRequestFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request<Body>) -> Self::Future { + HandleRequestFuture { + req, + read: self.read.clone(), + write: self.write.clone(), + } + } +} #[tokio::main] async fn main() -> anyhow::Result<()> { - let config = ClientCfg::parse(); + let config = ServerCfg::parse(); tracing_subscriber::fmt::init(); - let listener: Box<dyn Listener> = match config.mode { - Mode::Ws => Box::new(WsListener { - addr: config.listen_addr, - }), - Mode::Tcp => Box::new(TcpListener { - addr: config.listen_addr, - }), - }; + let socket = Arc::new(UdpSocket::bind(Into::<SocketAddr>::into(([0, 0, 0, 0], 0))).await?); + socket.connect(&config.upstream).await?; + info!("Connected to upstream udp://{}...", config.upstream); + + let addr = Arc::new(RwLock::new(Some(socket.peer_addr()?))); + let write_pool = Arc::new(WritePool::new(socket.clone(), addr.clone())); + let read_pool = Arc::new(ReadPool::new(socket, addr)); + + let wp = write_pool.clone(); + let rp = read_pool.clone(); + + let make_service = make_service_fn(|_conn| { + let read = rp.clone(); + let write = wp.clone(); + + async move { Ok::<_, Infallible>(ConnectionHandler { read, write }) } + }); + + let server = Server::bind(&config.listen_addr).serve(make_service); - info!("Listening in {}://{}...", config.mode, config.listen_addr); + info!("Listening on {}://{}...", config.mode, config.listen_addr); + + let write_handle: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move { + loop { + write_pool.execute().await?; + } + }); select! { - result = listener.listen(&config.upstream) => { - result?; + result = server => { info!("Socket closed, quitting..."); + result?; }, - result = shutdown_signal() => { + result = write_handle => { + info!("Write pool finished"); + result??; + }, + result = read_pool.join() => { + info!("Read pool finished"); result?; + }, + result = shutdown_signal() => { info!("Termination signal received, quitting..."); + result?; } }