From 4d6f1b21a1e2eb9726c1a5bb57f0ecb625a44316 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 29 Sep 2023 16:33:49 +0200 Subject: [PATCH 01/19] Rewrite --- Cargo.lock | 359 +++++++++++++++---------------- Cargo.toml | 1 - zia-client/Cargo.toml | 15 +- zia-client/src/app.rs | 212 ++++++++++++++++++ zia-client/src/cfg.rs | 2 + zia-client/src/handler.rs | 81 ------- zia-client/src/main.rs | 112 ++++++++-- zia-client/src/upstream/mod.rs | 44 ---- zia-client/src/upstream/tcp.rs | 151 ------------- zia-client/src/upstream/ws.rs | 100 --------- zia-common/Cargo.toml | 17 -- zia-common/src/forwarding/mod.rs | 5 - zia-common/src/forwarding/tcp.rs | 153 ------------- zia-common/src/forwarding/ws.rs | 137 ------------ zia-common/src/lib.rs | 5 - zia-common/src/stream.rs | 173 --------------- zia-server/Cargo.toml | 14 +- zia-server/src/app.rs | 194 +++++++++++++++++ zia-server/src/listener/mod.rs | 10 - zia-server/src/listener/tcp.rs | 56 ----- zia-server/src/listener/ws.rs | 57 ----- zia-server/src/main.rs | 8 +- 22 files changed, 693 insertions(+), 1213 deletions(-) create mode 100644 zia-client/src/app.rs delete mode 100644 zia-client/src/handler.rs delete mode 100644 zia-client/src/upstream/mod.rs delete mode 100644 zia-client/src/upstream/tcp.rs delete mode 100644 zia-client/src/upstream/ws.rs delete mode 100644 zia-common/Cargo.toml delete mode 100644 zia-common/src/forwarding/mod.rs delete mode 100644 zia-common/src/forwarding/tcp.rs delete mode 100644 zia-common/src/forwarding/ws.rs delete mode 100644 zia-common/src/lib.rs delete mode 100644 zia-common/src/stream.rs create mode 100644 zia-server/src/app.rs delete mode 100644 zia-server/src/listener/mod.rs delete mode 100644 zia-server/src/listener/tcp.rs delete mode 100644 zia-server/src/listener/ws.rs diff --git a/Cargo.lock b/Cargo.lock index e8280b1..933026e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4fa78e18c64fce05e902adecd7a5eed15a5e0a3439f7b0e169f0252214865e3" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" dependencies = [ "gimli", ] @@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "anstream" -version = "0.5.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f58811cfac344940f1a400b6e6231ce35171f614f26439e80f8c1465c5cc0c" +checksum = "f6cd65a4b849ace0b7f6daeebcc1a1d111282227ca745458c61dbf670e52a597" dependencies = [ "anstyle", "anstyle-parse", @@ -33,15 +33,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.1" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" [[package]] name = "anstyle-parse" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" +checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" dependencies = [ "utf8parse", ] @@ -57,9 +57,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "2.1.0" +version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58f54d10c6dfa51283a066ceab3ec1ab78d13fae00aa49243a45e4571fb79dfd" +checksum = "0238ca56c96dfa37bdf7c373c8886dd591322500aceeeccdb2216fe06dc2f796" dependencies = [ "anstyle", "windows-sys", @@ -67,9 +67,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.74" +version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c6f84b74db2535ebae81eede2f39b947dcbf01d093ae5f791e5dd414a1bf289" +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" [[package]] name = "async-http-proxy" @@ -77,34 +77,17 @@ version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29faa5d4d308266048bd7505ba55484315a890102f9345b9ff4b87de64201592" dependencies = [ - "base64", + "base64 0.13.1", "httparse", "thiserror", "tokio", ] -[[package]] -name = "async-trait" -version = "0.1.73" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - [[package]] name = "backtrace" -version = "0.3.68" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4319208da049c43661739c5fade2ba182f09d1dc2299b32298d3a31692b17e12" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" dependencies = [ "addr2line", "cc", @@ -121,6 +104,12 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +[[package]] +name = "base64" +version = "0.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" + [[package]] name = "block-buffer" version = "0.10.4" @@ -132,27 +121,21 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" - -[[package]] -name = "byteorder" -version = "1.4.3" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" -version = "1.0.82" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "305fe645edc1442a0fa8b6726ba61d422798d37a52e12eaecf4b022ebbb88f01" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ "libc", ] @@ -165,9 +148,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.4.2" +version = "4.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a13b88d2c62ff462f88e4a121f17a82c1af05693a2f192b5c38d14de73c19f6" +checksum = "d04704f56c2cde07f43e8e2c154b43f216dc5c92fc98ada720177362f953b956" dependencies = [ "clap_builder", "clap_derive", @@ -175,9 +158,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.2" +version = "4.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bb9faaa7c2ef94b2743a21f5a29e6f0010dff4caa69ac8e9d6cf8b6fa74da08" +checksum = "0e231faeaca65ebd1ea3c737966bf858971cd38c3849107aa3ea7de90a804e45" dependencies = [ "anstream", "anstyle", @@ -199,9 +182,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" [[package]] name = "colorchoice" @@ -228,12 +211,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "data-encoding" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" - [[package]] name = "digest" version = "0.10.7" @@ -244,6 +221,21 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "fastwebsockets" +version = "0.4.4" +source = "git+https://github.com/MarcelCoding/fastwebsockets?branch=split#fa0521b583d88e88a74ac1e0b50957a4d3244c45" +dependencies = [ + "base64 0.21.4", + "hyper", + "pin-project", + "rand", + "sha1", + "thiserror", + "tokio", + "utf-8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -260,16 +252,19 @@ dependencies = [ ] [[package]] -name = "futures-core" +name = "futures-channel" version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +dependencies = [ + "futures-core", +] [[package]] -name = "futures-sink" +name = "futures-core" version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" [[package]] name = "futures-task" @@ -284,11 +279,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ "futures-core", - "futures-sink", "futures-task", "pin-project-lite", "pin-utils", - "slab", ] [[package]] @@ -314,9 +307,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.27.3" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" [[package]] name = "heck" @@ -326,9 +319,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "http" @@ -341,12 +334,51 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +dependencies = [ + "bytes", + "http", + "pin-project-lite", +] + [[package]] name = "httparse" version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "0.14.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "idna" version = "0.4.0" @@ -380,9 +412,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.147" +version = "0.2.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" [[package]] name = "log" @@ -392,9 +424,9 @@ checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" [[package]] name = "miniz_oxide" @@ -438,9 +470,9 @@ dependencies = [ [[package]] name = "object" -version = "0.31.1" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bda667d9f2b5051b8833f59f3bf748b28ef54f850f4fcb389a252aa383866d1" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" dependencies = [ "memchr", ] @@ -485,9 +517,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12cc1b0bf1727a77a54b6654e7b5f1af8604923edc8b81885f8ec92f9e3f0a05" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -503,18 +535,18 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.32" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50f3b39ccfb720540debaa0164757101c08ecb8d326b15358ce76a62c7e85965" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] @@ -572,9 +604,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustls" -version = "0.21.6" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb" +checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" dependencies = [ "log", "ring", @@ -584,9 +616,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.101.4" +version = "0.101.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d" +checksum = "3c7d5dece342910d9ba34d259310cae3e0154b873b35408b787b59bce53d34fe" dependencies = [ "ring", "untrusted", @@ -604,18 +636,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.183" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32ac8da02677876d532745a130fc9d8e6edfa81a269b107c5b00829b91d8eb3c" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.183" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aafe972d60b0b9bee71a91b92fee2d4fb3c9d7e8f6b179aa99f27203d99a4816" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", @@ -624,9 +656,9 @@ dependencies = [ [[package]] name = "sha1" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", @@ -635,9 +667,9 @@ dependencies = [ [[package]] name = "sharded-slab" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +checksum = "c1b21f559e07218024e7e9f90f96f601825397de0e25420135f7f952453fed0b" dependencies = [ "lazy_static", ] @@ -651,26 +683,17 @@ dependencies = [ "libc", ] -[[package]] -name = "slab" -version = "0.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" -dependencies = [ - "autocfg", -] - [[package]] name = "smallvec" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" [[package]] name = "socket2" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e" dependencies = [ "libc", "windows-sys", @@ -690,9 +713,9 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" -version = "2.0.28" +version = "2.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" +checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" dependencies = [ "proc-macro2", "quote", @@ -701,18 +724,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.46" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9207952ae1a003f42d3d5e892dac3c6ba42aa6ac0c79a6a91a2b5cb4253e75c" +checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.46" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1728216d3244de4f14f14f8c15c79be1a7c67867d28d69b719690e2a19fb445" +checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" dependencies = [ "proc-macro2", "quote", @@ -784,16 +807,10 @@ dependencies = [ ] [[package]] -name = "tokio-tungstenite" -version = "0.20.0" +name = "tower-service" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2dbec703c26b00d74844519606ef15d09a7d6857860f84ad223dec002ddea2" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite", -] +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" @@ -854,29 +871,16 @@ dependencies = [ ] [[package]] -name = "tungstenite" -version = "0.20.0" +name = "try-lock" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e862a1c4128df0112ab625f55cd5c934bcb4312ba80b39ae4b4835a3fd58e649" -dependencies = [ - "byteorder", - "bytes", - "data-encoding", - "http", - "httparse", - "log", - "rand", - "sha1", - "thiserror", - "url", - "utf-8", -] +checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" @@ -886,9 +890,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -907,9 +911,9 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "url" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" dependencies = [ "form_urlencoded", "idna", @@ -941,6 +945,15 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1050,9 +1063,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1eeca1c172a285ee6c2c84c341ccea837e7c01b12fbb2d0fe3c9e550ce49ec8" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -1065,75 +1078,60 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b10d0c968ba7f6166195e13d593af609ec2e3d24f916f081690695cf5eaffb2f" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_msvc" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "571d8d4e62f26d4932099a9efe89660e8bd5087775a2ab5cdd8b747b811f1058" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_i686_gnu" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2229ad223e178db5fbbc8bd8d3835e51e566b8474bfca58d2e6150c48bb723cd" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_msvc" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "600956e2d840c194eedfc5d18f8242bc2e17c7775b6684488af3a9fff6fe3287" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_x86_64_gnu" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea99ff3f8b49fb7a8e0d305e5aec485bd068c2ba691b6e277d29eaeac945868a" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1a05a1ece9a7a0d5a7ccf30ba2c33e3a61a30e042ffd247567d1de1d94120d" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_msvc" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d419259aba16b663966e29e6d7c6ecfa0bb8425818bb96f6f1f3c3eb71a6e7b9" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "zia-client" version = "0.0.0-git" -dependencies = [ - "anyhow", - "async-trait", - "clap", - "futures-util", - "tokio", - "tokio-tungstenite", - "tracing", - "tracing-subscriber", - "url", - "zia-common", -] - -[[package]] -name = "zia-common" -version = "0.0.0-git" dependencies = [ "anyhow", "async-http-proxy", - "futures-util", + "clap", + "fastwebsockets", + "hyper", "once_cell", - "pin-project", "tokio", "tokio-rustls", - "tokio-tungstenite", "tracing", + "tracing-subscriber", "url", "webpki-roots", ] @@ -1143,12 +1141,9 @@ name = "zia-server" version = "0.0.0-git" dependencies = [ "anyhow", - "async-trait", - "clap", - "futures-util", + "fastwebsockets", + "hyper", + "once_cell", "tokio", - "tokio-tungstenite", - "tracing", - "tracing-subscriber", - "zia-common", + "webpki-roots", ] diff --git a/Cargo.toml b/Cargo.toml index af9ae16..f4f9b43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,6 @@ [workspace] members = [ "zia-client", - "zia-common", "zia-server" ] diff --git a/zia-client/Cargo.toml b/zia-client/Cargo.toml index 86a63ca..b446f0c 100644 --- a/zia-client/Cargo.toml +++ b/zia-client/Cargo.toml @@ -7,20 +7,21 @@ license = "AGPL-3.0" description = "Proxy udp over websocket, useful to use Wireguard in restricted networks." [dependencies] -tokio = { version = "1.32", default-features = false, features = ["macros", "net", "rt-multi-thread", "signal", "sync"] } +fastwebsockets = { git = "https://github.com/MarcelCoding/fastwebsockets", branch = "split", default-features = false, features = ["upgrade"] } +tokio = { version = "1.32", default-features = false, features = ["rt-multi-thread", "macros", "net", "sync", "time", "signal"] } +async-http-proxy = { version = "1.2", default-features = false, features = ["runtime-tokio", "basic-auth"] } +hyper = { version = "0.14", default-features = false, features = [] } tracing-subscriber = { version = "0.3", features = ["tracing-log"] } -tokio-tungstenite = { version = "0.20", default-features = false, features = ["handshake"] } -futures-util = { version = "0.3", default-features = false } clap = { version = "4.4", features = ["derive", "env"] } url = { version = "2.4", features = ["serde"] } -async-trait = "0.1" -tracing = "0.1" +webpki-roots = "0.25" +tokio-rustls = "0.24" +once_cell = "1.18" anyhow = "1.0" +tracing = "0.1" -zia-common = { path = "../zia-common" } [package.metadata.generate-rpm] assets = [ - # { source = "target/release/status-node", dest = "/usr/bin/status-node", mode = "0755" }, { source = "../LICENSE", dest = "/usr/share/doc/zia-client/LICENSE", doc = true, mode = "0644" }, ] diff --git a/zia-client/src/app.rs b/zia-client/src/app.rs new file mode 100644 index 0000000..4dc6ca7 --- /dev/null +++ b/zia-client/src/app.rs @@ -0,0 +1,212 @@ +use std::future::Future; +use std::mem; +use std::net::SocketAddr; +use std::sync::Arc; + +use anyhow::anyhow; +use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth}; +use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; +use hyper::header::{ + CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE, USER_AGENT, +}; +use hyper::upgrade::Upgraded; +use hyper::{Body, Request}; +use once_cell::sync::Lazy; +use tokio::io::{split, ReadHalf, WriteHalf}; +use tokio::net::{TcpStream, UdpSocket}; +use tokio::sync::{Mutex as TokioMutex, Mutex}; +use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; +use tokio_rustls::TlsConnector; +use tracing::info; +use url::Url; + +static TLS_CONNECTOR: Lazy = Lazy::new(|| { + let mut store = RootCertStore::empty(); + store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints) + })); + + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(store) + .with_no_client_auth(); + + TlsConnector::from(Arc::new(config)) +}); + +pub(crate) struct Connection { + finished: Arc>, + buf: Arc>>, + write: Arc>>>, +} + +impl Connection { + pub(crate) async fn new( + upstream: &Url, + proxy: &Option, + ) -> anyhow::Result<(Self, WebSocketRead>)> { + let upstream_host = upstream + .host_str() + .ok_or_else(|| anyhow!("Upstream url is missing host"))?; + let upstream_port = upstream + .port_or_known_default() + .ok_or_else(|| anyhow!("Upstream url is missing port"))?; + + let stream = match proxy { + None => { + let stream = TcpStream::connect((upstream_host, upstream_port)).await?; + stream.set_nodelay(true)?; + + info!("Connected to tcp"); + + stream + } + Some(proxy) => { + assert_eq!(proxy.scheme(), "http"); + + let proxy_host = proxy + .host_str() + .ok_or_else(|| anyhow!("Proxy url is missing host"))?; + let proxy_port = proxy + .port_or_known_default() + .ok_or_else(|| anyhow!("Proxy url is missing port"))?; + + let mut stream = TcpStream::connect((proxy_host, proxy_port)).await?; + stream.set_nodelay(true)?; + + info!("Connected to tcp"); + + match proxy.password() { + Some(password) => { + http_connect_tokio_with_basic_auth( + &mut stream, + upstream_host, + upstream_port, + proxy.username(), + password, + ) + .await? + } + None => http_connect_tokio(&mut stream, upstream_host, upstream_port).await?, + }; + + info!("Proxy handshake"); + + stream + } + }; + + let req = Request::get(upstream.to_string()) + .header(HOST, format!("{}:{}", upstream_host, upstream_port)) + .header(UPGRADE, "websocket") + .header(CONNECTION, "upgrade") + .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) + .header(SEC_WEBSOCKET_VERSION, "13") + .header(USER_AGENT, "zia") + .body(Body::empty())?; + + let (ws, _) = if upstream.scheme() == "wss" { + let domain = ServerName::try_from(upstream_host)?; + let stream = TLS_CONNECTOR.connect(domain, stream).await?; + info!("Upgraded to tls"); + + fastwebsockets::handshake::client(&SpawnExecutor, req, stream).await? + } else { + fastwebsockets::handshake::client(&SpawnExecutor, req, stream).await? + }; + + info!("Finished websocket handshake"); + + let (read, write) = ws.split(|upgraded| split(upgraded)); + + Ok(( + Self { + finished: Arc::new(Mutex::new(true)), + buf: Arc::new(Mutex::new(datagram_buffer())), + write: Arc::new(Mutex::new(write)), + }, + read, + )) + } +} + +pub(crate) struct WritePool { + pub(crate) connections: Vec, + pub(crate) next: usize, +} + +impl WritePool { + pub(crate) async fn write( + &mut self, + socket: &UdpSocket, + old_addr: Arc>>, + ) -> anyhow::Result<()> { + loop { + let connection = self + .connections + .get(self.next % self.connections.len()) + .unwrap(); + + self.next += 1; + + let mut finished = connection.finished.lock().await; + if *finished { + *finished = false; + + let mut buf = connection.buf.lock().await; + let (read, addr) = socket.recv_from(&mut buf[..]).await.unwrap(); + tokio::spawn(async move { + *(old_addr.lock().await) = Some(addr); + }); + + let finished = connection.finished.clone(); + let buf = connection.buf.clone(); + let write = connection.write.clone(); + + tokio::spawn(async move { + let mut buf = buf.lock().await; + + write + .lock() + .await + .write_frame(Frame::new( + true, + OpCode::Binary, + None, + Payload::BorrowedMut(&mut buf[..read]), + )) + .await + .unwrap(); + + *(finished.lock().await) = true; + }); + return Ok(()); + } + } + } +} + +// Tie hyper's executor to tokio runtime +struct SpawnExecutor; + +impl hyper::rt::Executor for SpawnExecutor +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + fn execute(&self, fut: Fut) { + tokio::task::spawn(fut); + } +} + +const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize - mem::size_of::(); + +/// Creates and returns a buffer on the heap with enough space to contain any possible +/// UDP datagram. +/// +/// This is put on the heap and in a separate function to avoid the 64k buffer from ending +/// up on the stack and blowing up the size of the futures using it. +#[inline(never)] +fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { + Box::new([0u8; MAX_DATAGRAM_SIZE]) +} diff --git a/zia-client/src/cfg.rs b/zia-client/src/cfg.rs index 9384287..5d4e0e7 100644 --- a/zia-client/src/cfg.rs +++ b/zia-client/src/cfg.rs @@ -12,4 +12,6 @@ pub(crate) struct ClientCfg { pub(crate) upstream: Url, #[arg(short, long, env = "ZIA_PROXY")] pub(crate) proxy: Option, + #[arg(short, long, env = "ZIA_COUNT")] + pub(crate) count: usize, } diff --git a/zia-client/src/handler.rs b/zia-client/src/handler.rs deleted file mode 100644 index 77d008f..0000000 --- a/zia-client/src/handler.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::Arc; - -use tokio::net::UdpSocket; -use tokio::sync::Mutex; -use tracing::{error, info, warn}; - -use crate::upstream::{Upstream, UpstreamSink, UpstreamStream}; - -pub(crate) struct UdpHandler { - upstream: U, - known: Arc>>, -} - -impl UdpHandler { - pub(crate) fn new(upstream: U) -> Self { - UdpHandler { - upstream, - known: Arc::new(Mutex::new(HashMap::new())), - } - } - - pub(crate) async fn listen(self, socket: Arc) -> anyhow::Result<()> - where - ::Stream: Send + 'static, - ::Sink: Send + 'static, - { - let mut buf = U::buffer(); - - loop { - let (read, addr) = socket.recv_from(&mut buf[U::BUF_SKIP..]).await?; - - let mut known = self.known.lock().await; - let mut upstream = match known.entry(addr) { - Entry::Occupied(occupied) => occupied, - Entry::Vacant(vacant) => { - info!("New socket at {}/udp, opening upstream connection...", addr); - - let (sink, mut stream) = match self.upstream.open().await { - Ok(conn) => conn, - Err(err) => { - warn!("Error while opening upstream connection: {err}"); - continue; - } - }; - - let known = self.known.clone(); - let socket = socket.clone(); - - tokio::spawn(async move { - if let Err(err) = stream.connect(&socket, addr).await { - warn!( - "Unable to read from upstream or write to udp socket: {}", - err - ); - } - - if let Some(mut sink) = known.lock().await.remove(&addr) { - if let Err(err) = sink.close().await { - error!("Unable to close upstream sink, closing...: {}", err); - } - } - }); - - vacant.insert_entry(sink) - } - }; - - let sink = upstream.get_mut(); - if let Err(err) = sink.write(&mut buf[..U::BUF_SKIP + read]).await { - warn!("Unable to write to upstream, closing...: {}", err); - - if let Err(err) = upstream.remove().close().await { - error!("Unable to close upstream sink: {}", err); - } - }; - } - } -} diff --git a/zia-client/src/main.rs b/zia-client/src/main.rs index 18c03cf..6571363 100644 --- a/zia-client/src/main.rs +++ b/zia-client/src/main.rs @@ -1,19 +1,22 @@ -#![feature(entry_insert)] - +use std::convert::Infallible; use std::net::SocketAddr; -use clap::Parser; +use std::sync::Arc; +use crate::app::{Connection, WritePool}; +use clap::Parser; +use fastwebsockets::{Frame, OpCode, Payload}; use tokio::net::UdpSocket; use tokio::select; use tokio::signal::ctrl_c; -use tracing::info; +use tokio::sync::Mutex; +use tokio::task::JoinSet; +use tracing::{error, info}; use url::Url; use crate::cfg::ClientCfg; +mod app; mod cfg; -mod handler; -mod upstream; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -22,13 +25,13 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); select! { - result = tokio::spawn(listen(config.listen_addr, config.upstream, config.proxy)) => { + result = tokio::spawn(listen(config.listen_addr, config.upstream, config.proxy, config.count)) => { result??; info!("Socket closed, quitting..."); }, result = shutdown_signal() => { result?; - info!("Termination signal received, quitting...") + info!("Termination signal received, quitting..."); } } @@ -62,19 +65,92 @@ async fn shutdown_signal() -> anyhow::Result<()> { } } -async fn listen(addr: SocketAddr, upstream: Url, proxy: Option) -> anyhow::Result<()> { - let inbound = UdpSocket::bind(addr).await?; - info!("Listening on {}/udp", inbound.local_addr()?); +async fn listen( + addr: SocketAddr, + upstream: Url, + proxy: Option, + connection_count: usize, +) -> anyhow::Result<()> { + let socket = Arc::new(UdpSocket::bind(addr).await?); + + let upstream = Arc::new(upstream); + let proxy = Arc::new(proxy); + + let mut conns = JoinSet::new(); + for _ in 0..connection_count { + let upstream = upstream.clone(); + let proxy = proxy.clone(); + conns.spawn(async move { Connection::new(&upstream, &proxy).await }); + } - if let Some(proxy) = &proxy { - info!("Using upstream at {} via proxy {}...", upstream, proxy); - } else { - info!("Using upstream at {}...", upstream); + let mut connections = Vec::new(); + let mut read_pool = Vec::new(); + + while let Some(connection) = conns.join_next().await.transpose()? { + let (conn, read) = connection?; + connections.push(conn); + read_pool.push(read); } - upstream::transmit(inbound, &upstream, &proxy).await?; + info!("Connected to upstream"); - info!("Transmission via {} closed", upstream); + let addr = Arc::new(Mutex::new(Option::None)); - Ok(()) + async fn test(_: Frame<'_>) -> Result<(), Infallible> { + todo!() + } + + for mut ws_read in read_pool { + let addr = addr.clone(); + let socket = socket.clone(); + + tokio::spawn(async move { + loop { + let frame = ws_read.read_frame(&mut test).await.unwrap(); + + if !frame.fin { + error!( + "unexpected buffer received, expect full udp frame to be in one websocket message" + ); + continue; + } + + match frame.opcode { + OpCode::Binary => match frame.payload { + Payload::BorrowedMut(payload) => { + socket + .send_to(payload, addr.lock().await.unwrap()) + .await + .unwrap(); + } + Payload::Borrowed(payload) => { + socket + .send_to(payload, addr.lock().await.unwrap()) + .await + .unwrap(); + } + Payload::Owned(payload) => { + socket + .send_to(&payload, addr.lock().await.unwrap()) + .await + .unwrap(); + } + }, + opcode => error!("Unexpected opcode: {:?}", opcode), + } + } + }); + } + + let socket = socket.clone(); + let old_addr = addr.clone(); + + let mut pool = WritePool { + connections, + next: 0, + }; + + loop { + pool.write(&socket, old_addr.clone()).await?; + } } diff --git a/zia-client/src/upstream/mod.rs b/zia-client/src/upstream/mod.rs deleted file mode 100644 index d24d145..0000000 --- a/zia-client/src/upstream/mod.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::net::SocketAddr; - -use anyhow::anyhow; -use async_trait::async_trait; -use tokio::net::UdpSocket; -use url::Url; - -mod tcp; -mod ws; - -pub(crate) async fn transmit( - socket: UdpSocket, - upstream: &Url, - proxy: &Option, -) -> anyhow::Result<()> { - match upstream.scheme().to_lowercase().as_str() { - "tcp" | "tcps" => tcp::transmit(socket, upstream, proxy).await, - "ws" | "wss" => ws::transmit(socket, upstream, proxy).await, - _ => Err(anyhow!("Unsupported upstream scheme {}", upstream.scheme())), - } -} - -#[async_trait] -pub(crate) trait Upstream { - type Sink: UpstreamSink; - type Stream: UpstreamStream; - - const BUF_SIZE: usize; - const BUF_SKIP: usize; - - fn buffer() -> Box<[u8]>; - async fn open(&self) -> anyhow::Result<(Self::Sink, Self::Stream)>; -} - -#[async_trait] -pub(crate) trait UpstreamSink { - async fn write(&mut self, buf: &mut [u8]) -> anyhow::Result<()>; - async fn close(&mut self) -> anyhow::Result<()>; -} - -#[async_trait] -pub(crate) trait UpstreamStream { - async fn connect(&mut self, socket: &UdpSocket, addr: SocketAddr) -> anyhow::Result<()>; -} diff --git a/zia-client/src/upstream/tcp.rs b/zia-client/src/upstream/tcp.rs deleted file mode 100644 index 5e03dc8..0000000 --- a/zia-client/src/upstream/tcp.rs +++ /dev/null @@ -1,151 +0,0 @@ -use std::mem; -use std::net::SocketAddr; -use std::sync::Arc; - -use anyhow::anyhow; -use async_trait::async_trait; -use tokio::io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}; -use tokio::net::{TcpStream, UdpSocket}; -use url::Url; - -use zia_common::Stream; - -use crate::handler::UdpHandler; -use crate::upstream::{Upstream, UpstreamSink, UpstreamStream}; - -pub(crate) async fn transmit( - socket: UdpSocket, - upstream: &Url, - proxy: &Option, -) -> anyhow::Result<()> { - let upstream = TcpUpstream { - url: upstream.clone(), - proxy: proxy.clone(), - }; - - let handler = UdpHandler::new(upstream); - handler.listen(Arc::new(socket)).await -} - -struct TcpUpstream { - url: Url, - proxy: Option, -} - -struct TcpUpstreamSink { - inner: WriteHalf>, -} - -struct TcpUpstreamStream { - inner: ReadHalf>, -} - -#[async_trait] -impl Upstream for TcpUpstream { - type Sink = TcpUpstreamSink; - type Stream = TcpUpstreamStream; - - /// A UDP datagram header has a 16 bit field containing an unsigned integer - /// describing the length of the datagram (including the header itself). - /// The max value is 2^16 = 65536 bytes. But since that includes the - /// UDP header, this constant is 8 bytes more than any UDP socket - /// read operation would ever return. We are going to use that extra space. - const BUF_SIZE: usize = u16::MAX as usize; - const BUF_SKIP: usize = mem::size_of::(); - - fn buffer() -> Box<[u8]> { - Box::new([0_u8; Self::BUF_SIZE]) - } - - async fn open(&self) -> anyhow::Result<(Self::Sink, Self::Stream)> { - let stream = Stream::connect(&self.url, &self.proxy).await?; - let (stream, sink) = split(stream); - - let sink = Self::Sink { inner: sink }; - let stream = Self::Stream { inner: stream }; - - Ok((sink, stream)) - } -} - -#[async_trait] -impl UpstreamSink for TcpUpstreamSink { - async fn write(&mut self, buf: &mut [u8]) -> anyhow::Result<()> { - let len = (buf.len() - TcpUpstream::BUF_SKIP) as u16; - - buf[0..TcpUpstream::BUF_SKIP].copy_from_slice(&len.to_le_bytes()); - - self.inner.write_all(buf).await?; - - Ok(()) - } - - async fn close(&mut self) -> anyhow::Result<()> { - Ok(self.inner.shutdown().await?) - } -} - -#[async_trait] -impl UpstreamStream for TcpUpstreamStream { - async fn connect(&mut self, socket: &UdpSocket, addr: SocketAddr) -> anyhow::Result<()> { - let mut buf = TcpUpstream::buffer(); - let mut unprocessed = 0; - - loop { - let read = self.inner.read(&mut buf[unprocessed..]).await?; - - if read == 0 { - return Err(anyhow!("End of tcp upstream stream.")); - } - - unprocessed += read; - - let processed = forward(socket, addr, &buf[..unprocessed]).await?; - - // discard processed bytes - if unprocessed > processed { - buf.copy_within(processed..unprocessed, 0); - } - - unprocessed -= processed; - } - } -} - -async fn forward(socket: &UdpSocket, addr: SocketAddr, buf: &[u8]) -> anyhow::Result { - let mut start = 0; - - loop { - let body = start + TcpUpstream::BUF_SKIP; - - let len: [u8; TcpUpstream::BUF_SKIP] = match buf.get(start..body) { - Some(header) => header.try_into().unwrap(), - // not enough bytes for a complete header - None => return Ok(start), - }; - - let len = u16::from_le_bytes(len) as usize; - let end = body + len; - - let data = match buf.get(body..end) { - Some(data) => data, - // not enough bytes for a complete dataframe - None => return Ok(start), - }; - - let written = socket.send_to(data, addr).await?; - assert_eq!(len, written, "Did not send entire UDP datagram"); - - start = end; - } -} - -// Creates and returns a buffer on the heap with enough space to contain any possible -// UDP datagram. -// -// This is put on the heap and in a separate function to avoid the 64k buffer from ending -// up on the stack and blowing up the size of the futures using it. -// #[inline(never)] -// fn datagram_buffer() -> Box<[u8; U::BUF_SIZE]> { -// Box::new([0u8; U::BUF_SIZE]) -// } diff --git a/zia-client/src/upstream/ws.rs b/zia-client/src/upstream/ws.rs deleted file mode 100644 index bc52829..0000000 --- a/zia-client/src/upstream/ws.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::mem; -use std::net::SocketAddr; -use std::sync::Arc; - -use async_trait::async_trait; -use futures_util::stream::{SplitSink, SplitStream}; -use futures_util::{SinkExt, StreamExt}; -use tokio::net::{TcpStream, UdpSocket}; -use tokio_tungstenite::tungstenite::Message; -use tokio_tungstenite::{client_async, WebSocketStream}; -use url::Url; - -use zia_common::Stream; - -use crate::handler::UdpHandler; -use crate::upstream::{Upstream, UpstreamSink, UpstreamStream}; - -pub(crate) async fn transmit( - socket: UdpSocket, - upstream: &Url, - proxy: &Option, -) -> anyhow::Result<()> { - let upstream = WsUpstream { - url: upstream.clone(), - proxy: proxy.clone(), - }; - - let handler = UdpHandler::new(upstream); - handler.listen(Arc::new(socket)).await -} - -struct WsUpstream { - url: Url, - proxy: Option, -} - -struct WsUpstreamSink { - inner: SplitSink>, Message>, -} - -struct WsUpstreamStream { - inner: SplitStream>>, -} - -#[async_trait] -impl Upstream for WsUpstream { - type Sink = WsUpstreamSink; - type Stream = WsUpstreamStream; - - /// A UDP datagram header has a 16 bit field containing an unsigned integer - /// describing the length of the datagram (including the header itself). - /// The max value is 2^16 = 65536 bytes. But since that includes the - /// UDP header, this constant is 8 bytes more than any UDP socket - /// read operation would ever return. We are going to save that extra space. - const BUF_SIZE: usize = u16::MAX as usize - mem::size_of::(); - const BUF_SKIP: usize = 0; - - fn buffer() -> Box<[u8]> { - Box::new([0_u8; Self::BUF_SIZE]) - } - - async fn open(&self) -> anyhow::Result<(Self::Sink, Self::Stream)> { - let stream = Stream::connect(&self.url, &self.proxy).await?; - let (stream, _) = client_async(&self.url, stream).await?; - let (sink, stream) = stream.split(); - - let sink = Self::Sink { inner: sink }; - let stream = Self::Stream { inner: stream }; - - Ok((sink, stream)) - } -} - -#[async_trait] -impl UpstreamSink for WsUpstreamSink { - async fn write(&mut self, buf: &mut [u8]) -> anyhow::Result<()> { - Ok(self.inner.send(Message::Binary(buf.to_vec())).await?) - } - - async fn close(&mut self) -> anyhow::Result<()> { - Ok(self.inner.close().await?) - } -} - -#[async_trait] -impl UpstreamStream for WsUpstreamStream { - async fn connect(&mut self, socket: &UdpSocket, addr: SocketAddr) -> anyhow::Result<()> { - loop { - match self.inner.next().await { - Some(Ok(message)) => { - socket.send_to(&message.into_data(), addr).await?; - } - Some(err @ Err(_)) => { - err?; - } - None => {} - } - } - } -} diff --git a/zia-common/Cargo.toml b/zia-common/Cargo.toml deleted file mode 100644 index 50d7adb..0000000 --- a/zia-common/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "zia-common" -version = "0.0.0-git" -edition = "2021" - -[dependencies] -async-http-proxy = { version = "1.2", features = ["runtime-tokio", "basic-auth"] } -tokio = { version = "1.32", default-features = false, features = ["net", "time"] } -tokio-tungstenite = { version = "0.20", default-features = false } -futures-util = { version = "0.3", default-features = false } -tokio-rustls = "0.24" -webpki-roots = "0.25" -pin-project = "1.1" -once_cell = "1.18" -tracing = "0.1" -anyhow = "1.0" -url = "2.4" diff --git a/zia-common/src/forwarding/mod.rs b/zia-common/src/forwarding/mod.rs deleted file mode 100644 index aa409fc..0000000 --- a/zia-common/src/forwarding/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub use crate::forwarding::tcp::*; -pub use crate::forwarding::ws::*; - -mod tcp; -mod ws; diff --git a/zia-common/src/forwarding/tcp.rs b/zia-common/src/forwarding/tcp.rs deleted file mode 100644 index 05c8f87..0000000 --- a/zia-common/src/forwarding/tcp.rs +++ /dev/null @@ -1,153 +0,0 @@ -// see: https://github.com/mullvad/udp-over-tcp/blob/main/src/forward_traffic.rs - -use std::convert::{Infallible, TryFrom}; -use std::mem; -use std::sync::Arc; - -use anyhow::Context; -use futures_util::future::select; -use futures_util::pin_mut; -use tokio::io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}; -use tokio::net::{TcpStream, UdpSocket}; -use tracing::error; - -use crate::{Stream}; - -/// A UDP datagram header has a 16 bit field containing an unsigned integer -/// describing the length of the datagram (including the header itself). -/// The max value is 2^16 = 65536 bytes. But since that includes the -/// UDP header, this constant is 8 bytes more than any UDP socket -/// read operation would ever return. We are going to use that extra space -/// to store our 2 byte udp-over-tcp header. -const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize; -const HEADER_LEN: usize = mem::size_of::(); - -/// Forward traffic between the given UDP and TCP sockets in both directions. -/// This async function runs until one of the sockets are closed or there is an error. -/// Both sockets are closed before returning. -pub async fn process_udp_over_tcp( - udp_socket: UdpSocket, - tcp_stream: Stream, -) { - let udp_in = Arc::new(udp_socket); - let udp_out = udp_in.clone(); - let (tcp_in, tcp_out) = split(tcp_stream); - - let tcp2udp = tokio::spawn(async move { - if let Err(error) = process_tcp2udp(tcp_in, udp_out).await { - error!("Error: {}", error); - } - }); - - let udp2tcp = tokio::spawn(async move { - if let Err(error) = process_udp2tcp(udp_in, tcp_out).await { - error!("Error: {}", error); - } - }); - - pin_mut!(tcp2udp); - pin_mut!(udp2tcp); - - // Wait until the UDP->TCP or TCP->UDP future terminates. - select(tcp2udp, udp2tcp).await; -} - -/// Reads from `tcp_in` and extracts UDP datagrams. Writes the datagrams to `udp_out`. -/// Returns if the TCP socket is closed, or an IO error happens on either socket. -async fn process_tcp2udp( - mut tcp_in: ReadHalf>, - udp_out: Arc -) -> anyhow::Result<()> { - let mut buffer = datagram_buffer(); - // `buffer` has unprocessed data from the TCP socket up until this index. - let mut unprocessed_i = 0; - loop { - let tcp_read_len = tcp_in.read(&mut buffer[unprocessed_i..]) - .await - .context("Failed reading from TCP")?; - if tcp_read_len == 0 { - break; - } - unprocessed_i += tcp_read_len; - - let processed_i = forward_datagrams_in_buffer(&udp_out, &buffer[..unprocessed_i]) - .await - .context("Failed writing to UDP")?; - - // If we have read data that was not forwarded, because it was not a complete datagram, - // move it to the start of the buffer and start over - if unprocessed_i > processed_i { - buffer.copy_within(processed_i..unprocessed_i, 0); - } - unprocessed_i -= processed_i; - } - - Ok(()) -} - -/// Forward all complete datagrams in `buffer` to `udp_out`. -/// Returns the number of processed bytes. -async fn forward_datagrams_in_buffer(udp_out: &UdpSocket, buffer: &[u8]) -> anyhow::Result { - let mut header_start = 0; - loop { - let header_end = header_start + HEADER_LEN; - // "parse" the header - let header = match buffer.get(header_start..header_end) { - Some(header) => <[u8; HEADER_LEN]>::try_from(header).unwrap(), - // Buffer does not contain entire header for next datagram - None => break Ok(header_start), - }; - let datagram_len = usize::from(u16::from_le_bytes(header)); - let datagram_start = header_end; - let datagram_end = datagram_start + datagram_len; - - let datagram_data = match buffer.get(datagram_start..datagram_end) { - Some(datagram_data) => datagram_data, - // The buffer does not contain the entire datagram - None => break Ok(header_start), - }; - - let udp_write_len = udp_out.send(datagram_data).await?; - assert_eq!( - udp_write_len, datagram_len, - "Did not send entire UDP datagram" - ); - - header_start = datagram_end; - } -} - -/// Reads datagrams from `udp_in` and writes them (with the 16 bit header containing the length) -/// to `tcp_out` indefinitely, or until an IO error happens on either socket. -async fn process_udp2tcp( - udp_in: Arc, - mut tcp_out: WriteHalf>, -) -> anyhow::Result { - // A buffer large enough to hold any possible UDP datagram plus its 16 bit length header. - let mut buffer = datagram_buffer(); - loop { - let udp_read_len = udp_in - .recv(&mut buffer[HEADER_LEN..]) - .await - .context("Failed reading from UDP")?; - - // Set the "header" to the length of the datagram. - let datagram_len = u16::try_from(udp_read_len).expect("UDP datagram can't be larger than 2^16"); - buffer[..HEADER_LEN].copy_from_slice(&datagram_len.to_le_bytes()[..]); - - tcp_out - .write_all(&buffer[..HEADER_LEN + udp_read_len]) - .await - .context("Failed writing to TCP")?; - } -} - -/// Creates and returns a buffer on the heap with enough space to contain any possible -/// UDP datagram. -/// -/// This is put on the heap and in a separate function to avoid the 64k buffer from ending -/// up on the stack and blowing up the size of the futures using it. -#[inline(never)] -fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { - Box::new([0u8; MAX_DATAGRAM_SIZE]) -} diff --git a/zia-common/src/forwarding/ws.rs b/zia-common/src/forwarding/ws.rs deleted file mode 100644 index b556360..0000000 --- a/zia-common/src/forwarding/ws.rs +++ /dev/null @@ -1,137 +0,0 @@ -// see: https://github.com/mullvad/udp-over-tcp/blob/main/src/forward_traffic.rs - -use std::convert::Infallible; -use std::mem; -use std::sync::Arc; - -use anyhow::Context; -use futures_util::future::select; -use futures_util::pin_mut; -use futures_util::stream::{SplitSink, SplitStream}; -use futures_util::{SinkExt, StreamExt}; -use tokio::net::{TcpStream, UdpSocket}; -use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal; -use tokio_tungstenite::tungstenite::protocol::CloseFrame; -use tokio_tungstenite::tungstenite::Message; -use tokio_tungstenite::WebSocketStream; -use tracing::error; - -use crate::{Stream}; - -/// A UDP datagram header has a 16 bit field containing an unsigned integer -/// describing the length of the datagram (including the header itself). -/// The max value is 2^16 = 65536 bytes. But since that includes the -/// UDP header, this constant is 8 bytes more than any UDP socket -/// read operation would ever return. We are going to save that extra space. -const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize - mem::size_of::(); - -/// Forward traffic between the given UDP and WS sockets in both directions. -/// This async function runs until one of the sockets are closed or there is an error. -/// Both sockets are closed before returning. -pub async fn process_udp_over_ws( - udp_socket: UdpSocket, - ws_stream: WebSocketStream>, -) { - let (mut ws_out, mut ws_in) = ws_stream.split(); - - { - let udp_in = Arc::new(udp_socket); - let udp_out = udp_in.clone(); - - let ws2udp = async { - if let Err(error) = process_ws2udp(&mut ws_in, udp_out).await { - error!("Error: {}", error); - } - }; - - let udp2ws = async { - if let Err(error) = process_udp2ws(udp_in, &mut ws_out).await { - error!("Error: {}", error); - } - }; - - pin_mut!(ws2udp); - pin_mut!(udp2ws); - - // Wait until the UDP->WS or WS->UDP future terminates. - select(ws2udp, udp2ws).await; - } - - if let Err(err) = ws_in - .reunite(ws_out) - .unwrap() - .close(Some(CloseFrame { - code: Normal, - reason: "Timeout".into(), - })) - .await - { - error!("Unable to close ws stream: {}", err); - } -} - -/// Reads from `ws_in` and extracts UDP datagrams. Writes the datagrams to `udp_out`. -/// Returns if the WS socket is closed, or an IO error happens on either socket. -async fn process_ws2udp( - ws_in: &mut SplitStream>>, - udp_out: Arc, -) -> anyhow::Result<()> { - while let Some(message) = ws_in.next() - .await - .transpose() - .context("Failed reading from WS")? - { - let data = message.into_data(); - - forward_datagrams_in_buffer(&udp_out, &data) - .await - .context("Failed writing to UDP")?; - } - - Ok(()) -} - -/// Forward the datagram in `buffer` to `udp_out`. -/// Returns the number of processed bytes. -async fn forward_datagrams_in_buffer(udp_out: &UdpSocket, buffer: &Vec) -> anyhow::Result<()> { - let udp_write_len = udp_out.send(buffer).await?; - assert_eq!( - udp_write_len, - buffer.len(), - "Did not send entire UDP datagram" - ); - - Ok(()) -} - -/// Reads datagrams from `udp_in` and writes them (with the 16 bit header containing the length) -/// to `ws_out` indefinitely, or until an IO error happens on either socket. -async fn process_udp2ws( - udp_in: Arc, - ws_out: &mut SplitSink>, Message>, -) -> anyhow::Result { - // A buffer large enough to hold any possible UDP datagram plus its 16 bit length header. - let mut buffer = datagram_buffer(); - - loop { - let udp_read_len = udp_in - .recv(&mut buffer[..]) - .await - .context("Failed reading from UDP")?; - - ws_out - .send(Message::Binary(buffer[..udp_read_len].to_vec())) - .await - .context("Failed writing to WS")?; - } -} - -/// Creates and returns a buffer on the heap with enough space to contain any possible -/// UDP datagram. -/// -/// This is put on the heap and in a separate function to avoid the 64k buffer from ending -/// up on the stack and blowing up the size of the futures using it. -#[inline(never)] -fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { - Box::new([0u8; MAX_DATAGRAM_SIZE]) -} diff --git a/zia-common/src/lib.rs b/zia-common/src/lib.rs deleted file mode 100644 index 6619fa8..0000000 --- a/zia-common/src/lib.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub use crate::forwarding::*; -pub use crate::stream::*; - -mod forwarding; -mod stream; diff --git a/zia-common/src/stream.rs b/zia-common/src/stream.rs deleted file mode 100644 index deefca1..0000000 --- a/zia-common/src/stream.rs +++ /dev/null @@ -1,173 +0,0 @@ -use std::io::{Error, IoSlice}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use anyhow::anyhow; -use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth}; -use once_cell::sync::Lazy; -use pin_project::pin_project; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio::net::TcpStream; -use tokio_rustls::client::TlsStream; -use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; -use tokio_rustls::TlsConnector; -use url::Url; - -static TLS_CONNECTOR: Lazy = Lazy::new(|| { - let mut store = RootCertStore::empty(); - store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints) - })); - - let config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(store) - .with_no_client_auth(); - - TlsConnector::from(Arc::new(config)) -}); - -#[pin_project(project = EnumProj)] -pub enum Stream { - Plain(#[pin] IO), - Tls(#[pin] TlsStream), - TlsOverTls(#[pin] TlsStream>), -} - -impl Stream { - pub async fn connect(upstream: &Url, proxy: &Option) -> anyhow::Result { - let upstream_host = upstream - .host_str() - .ok_or_else(|| anyhow!("Upstream url is missing host"))?; - let upstream_port = upstream - .port_or_known_default() - .ok_or_else(|| anyhow!("Upstream url is missing port"))?; - - let mut stream = match proxy { - None => { - let stream = TcpStream::connect((upstream_host, upstream_port)).await?; - stream.set_nodelay(true)?; - - Self::Plain(stream) - } - Some(proxy) => { - let proxy_host = proxy - .host_str() - .ok_or_else(|| anyhow!("Proxy url is missing host"))?; - let proxy_port = proxy - .port_or_known_default() - .ok_or_else(|| anyhow!("Proxy url is missing port"))?; - - let stream = TcpStream::connect((proxy_host, proxy_port)).await?; - stream.set_nodelay(true)?; - - let mut stream = Self::Plain(stream); - - if proxy.scheme() == "https" { - stream = stream.upgrade_to_tls(proxy_host).await?; - }; - - match proxy.password() { - Some(password) => { - http_connect_tokio_with_basic_auth( - &mut stream, - upstream_host, - upstream_port, - proxy.username(), - password, - ) - .await? - } - None => http_connect_tokio(&mut stream, upstream_host, upstream_port).await?, - }; - - stream - } - }; - - if upstream.scheme() == "wss" || upstream.scheme() == "tcps" { - stream = stream.upgrade_to_tls(upstream_host).await?; - } - - Ok(stream) - } -} - -impl Stream { - pub async fn upgrade_to_tls(self, host: &str) -> anyhow::Result { - let domain = ServerName::try_from(host)?; - - let stream = match self { - Self::Plain(stream) => Self::Tls(TLS_CONNECTOR.connect(domain, stream).await?), - Self::Tls(stream) => Self::TlsOverTls(TLS_CONNECTOR.connect(domain, stream).await?), - Self::TlsOverTls(_) => unimplemented!(), - }; - - Ok(stream) - } -} - -impl AsyncRead for Stream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match self.project() { - EnumProj::Plain(stream) => stream.poll_read(cx, buf), - EnumProj::Tls(stream) => stream.poll_read(cx, buf), - EnumProj::TlsOverTls(stream) => stream.poll_read(cx, buf), - } - } -} - -impl AsyncWrite for Stream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.project() { - EnumProj::Plain(stream) => stream.poll_write(cx, buf), - EnumProj::Tls(stream) => stream.poll_write(cx, buf), - EnumProj::TlsOverTls(stream) => stream.poll_write(cx, buf), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { - EnumProj::Plain(stream) => stream.poll_flush(cx), - EnumProj::Tls(stream) => stream.poll_flush(cx), - EnumProj::TlsOverTls(stream) => stream.poll_flush(cx), - } - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { - EnumProj::Plain(stream) => stream.poll_shutdown(cx), - EnumProj::Tls(stream) => stream.poll_shutdown(cx), - EnumProj::TlsOverTls(stream) => stream.poll_shutdown(cx), - } - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - match self.project() { - EnumProj::Plain(stream) => stream.poll_write_vectored(cx, bufs), - EnumProj::Tls(stream) => stream.poll_write_vectored(cx, bufs), - EnumProj::TlsOverTls(stream) => stream.poll_write_vectored(cx, bufs), - } - } - - fn is_write_vectored(&self) -> bool { - match self { - Self::Plain(stream) => stream.is_write_vectored(), - Self::Tls(stream) => stream.is_write_vectored(), - Self::TlsOverTls(stream) => stream.is_write_vectored(), - } - } -} diff --git a/zia-server/Cargo.toml b/zia-server/Cargo.toml index 6861e3e..b11a305 100644 --- a/zia-server/Cargo.toml +++ b/zia-server/Cargo.toml @@ -7,14 +7,11 @@ license = "AGPL-3.0" description = "Proxy udp over websocket, useful to use Wireguard in restricted networks." [dependencies] -tokio = { version = "1.32", default-features = false, features = ["macros", "rt-multi-thread", "net", "signal", "sync"] } -tokio-tungstenite = { version = "0.20", default-features = false, features = ["handshake"] } -tracing-subscriber = { version = "0.3", features = ["tracing-log"] } -futures-util = { version = "0.3", default-features = false } -clap = { version = "4.4", features = ["derive", "env"] } -zia-common = { path = "../zia-common" } -async-trait = "0.1" -tracing = "0.1" +fastwebsockets = { git = "https://github.com/MarcelCoding/fastwebsockets", branch = "split", default-features = false, features = ["upgrade"] } +tokio = { version = "1.32", default-features = false, features = ["rt-multi-thread", "macros", "net","sync", "time"] } +hyper = { version = "0.14", default-features = false, features = [] } +webpki-roots = "0.25" +once_cell = "1.18" anyhow = "1.0" [package.metadata.deb] @@ -27,6 +24,5 @@ assets = [ [package.metadata.generate-rpm] assets = [ - # { source = "target/release/status-node", dest = "/usr/bin/status-node", mode = "0755" }, { source = "../LICENSE", dest = "/usr/share/doc/zia-server/LICENSE", doc = true, mode = "0644" }, ] diff --git a/zia-server/src/app.rs b/zia-server/src/app.rs new file mode 100644 index 0000000..d50f841 --- /dev/null +++ b/zia-server/src/app.rs @@ -0,0 +1,194 @@ +use std::convert::Infallible; +use std::future::Future; +use std::mem; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::str::FromStr; +use std::sync::Arc; +use fastwebsockets::{Frame, OpCode, Payload, Role, WebSocket, WebSocketRead, WebSocketWrite}; +use hyper::{Body, Request}; +use hyper::header::{ + CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE, USER_AGENT, +}; +use hyper::upgrade::Upgraded; +use tokio::io::{ReadHalf, split, WriteHalf}; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use tokio::sync::{Mutex as TokioMutex, Mutex}; + +struct Connection { + finished: Arc>, + buf: Arc>>, + write: Arc>>>, +} + +const UDP_ADDR: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10,99,0,0), 51838)); + +async fn test(_: Frame<'_>) -> Result<(), Infallible> { + todo!() +} + +async fn handle( + socket: TcpStream, +) -> anyhow::Result>> { + let mut ws = WebSocket::after_handshake(socket, Role::Server); + ws.set_writev(false); + ws.set_auto_close(true); + ws.set_auto_pong(true); + + let (mut ws_read, write) = ws.split(|upgraded| split(upgraded)); + + tokio::spawn(async move { + loop { + let frame = ws_read.read_frame(&mut test).await.unwrap(); + + if !frame.fin { + println!("unexpected buffer received, expect full udp frame to be in one websocket message"); + continue; + } + + match frame.opcode { + OpCode::Binary => match frame.payload { + Payload::BorrowedMut(payload) => { + socket.send_to(payload, UDP_ADDR).await.unwrap(); + } + Payload::Borrowed(payload) => { + socket.send_to(payload, UDP_ADDR).await.unwrap(); + } + Payload::Owned(payload) => { + socket.send_to(&payload, UDP_ADDR).await.unwrap(); + } + }, + opcode => eprintln!("Unexpected opcode: {:?}", opcode), + } + } + }); + + Ok(write) +} + +impl Connection { + async fn new() -> anyhow::Result<(Self, WebSocketRead>)> { + // 135.181.77.88:80 + + println!("connected https"); + + let (ws, resp) = fastwebsockets::upgrade::upgrade(&SpawnExecutor, req, listener).await?; + + println!("connected websocket"); + + let (ws_read, ws_write) = ws.split(|upgraded| split(upgraded)); + + Ok(( + Self { + finished: Arc::new(Mutex::new(true)), + buf: Arc::new(Mutex::new(datagram_buffer())), + write: Arc::new(Mutex::new(ws_write)), + }, + ws_read, + )) + } +} + +struct WritePool { + connections: Vec, + next: usize, +} + +impl WritePool { + async fn write( + &mut self, + socket: &UdpSocket, + old_addr: Arc>>, + ) -> anyhow::Result<()> { + loop { + let connection = self + .connections + .get(self.next % self.connections.len()) + .unwrap(); + + self.next += 1; + + let mut finished = connection.finished.lock().await; + if *finished { + *finished = false; + + let mut buf = connection.buf.lock().await; + let (read, addr) = socket.recv_from(&mut buf[..]).await.unwrap(); + tokio::spawn(async move { + *(old_addr.lock().await) = Some(addr); + }); + + let buf = connection.buf.clone(); + let write = connection.write.clone(); + + tokio::spawn(async move { + let mut buf = buf.lock().await; + + write + .lock() + .await + .write_frame(Frame::new( + true, + OpCode::Binary, + None, + Payload::BorrowedMut(&mut buf[..read]), + )) + .await + .unwrap(); + }); + return Ok(()); + } + } + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); + let mut listener = TcpListener::bind("127.0.0.1:3128").await?; + + let mut connections = Arc::new(Mutex::new(Vec::new())); + + println!("connected"); + + + { + let connections = connections.clone(); + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await?; + connections.lock().await.push(handle(stream).await?); + } + }); + } + + loop { + connections.lock(). + } +} + +// Tie hyper's executor to tokio runtime +struct SpawnExecutor; + +impl hyper::rt::Executor for SpawnExecutor + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + fn execute(&self, fut: Fut) { + tokio::task::spawn(fut); + } +} + +const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize - mem::size_of::(); + +/// Creates and returns a buffer on the heap with enough space to contain any possible +/// UDP datagram. +/// +/// This is put on the heap and in a separate function to avoid the 64k buffer from ending +/// up on the stack and blowing up the size of the futures using it. +#[inline(never)] +fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { + Box::new([0u8; MAX_DATAGRAM_SIZE]) +} + + diff --git a/zia-server/src/listener/mod.rs b/zia-server/src/listener/mod.rs deleted file mode 100644 index 4fe6811..0000000 --- a/zia-server/src/listener/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -pub(crate) use self::tcp::*; -pub(crate) use self::ws::*; - -mod tcp; -mod ws; - -#[async_trait::async_trait] -pub(crate) trait Listener { - async fn listen(&self, upstream: &str) -> anyhow::Result<()>; -} diff --git a/zia-server/src/listener/tcp.rs b/zia-server/src/listener/tcp.rs deleted file mode 100644 index a14c7b0..0000000 --- a/zia-server/src/listener/tcp.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::net::SocketAddr; - -use tokio::net::{TcpStream, UdpSocket}; -use tracing::{error, info}; - -use zia_common::process_udp_over_tcp; -use zia_common::Stream::Plain; - -use crate::listener::Listener; - -pub(crate) struct TcpListener { - pub(crate) addr: SocketAddr, -} - -#[async_trait::async_trait] -impl Listener for TcpListener { - async fn listen(&self, upstream: &str) -> anyhow::Result<()> { - let listener = tokio::net::TcpListener::bind(self.addr).await?; - - loop { - let (sock, _) = listener.accept().await?; - let upstream = upstream.to_string(); - - tokio::spawn(async move { - if let Err(err) = Self::handle(sock, upstream).await { - error!("Error while handling connection: {:?}", err); - } - }); - } - } -} - -impl TcpListener { - async fn handle(downstream: TcpStream, upstream_addr: String) -> anyhow::Result<()> { - downstream.set_nodelay(true)?; - let downstream_addr = downstream.peer_addr()?; - info!("New downstream connection: {}", downstream_addr); - - let upstream = UdpSocket::bind("0.0.0.0:0").await?; // TODO: maybe make this configurable - - upstream.connect(upstream_addr).await?; - - info!( - "Connected to udp upstream (local: {}/udp, peer: {}/udp) for downstream {}", - upstream.local_addr()?, - upstream.peer_addr()?, - downstream_addr - ); - - process_udp_over_tcp(upstream, Plain(downstream)).await; - - info!("Connection with downstream {} closed...", downstream_addr); - - Ok(()) - } -} diff --git a/zia-server/src/listener/ws.rs b/zia-server/src/listener/ws.rs deleted file mode 100644 index 02a066e..0000000 --- a/zia-server/src/listener/ws.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::net::SocketAddr; - -use tokio::net::{TcpListener, TcpStream, UdpSocket}; -use tracing::{error, info}; -use zia_common::process_udp_over_ws; -use zia_common::Stream::Plain; - -use crate::listener::Listener; - -pub(crate) struct WsListener { - pub(crate) addr: SocketAddr, -} - -#[async_trait::async_trait] -impl Listener for WsListener { - async fn listen(&self, upstream: &str) -> anyhow::Result<()> { - let listener = TcpListener::bind(self.addr).await?; - info!("Listening on ws://{}...", listener.local_addr()?); - - loop { - let (sock, _) = listener.accept().await?; - let upstream = upstream.to_string(); - - tokio::spawn(async move { - if let Err(err) = Self::handle(sock, upstream).await { - error!("Error while handling connection: {:?}", err); - } - }); - } - } -} - -impl WsListener { - async fn handle(downstream: TcpStream, upstream_addr: String) -> anyhow::Result<()> { - downstream.set_nodelay(true)?; - let downstream_addr = downstream.peer_addr()?; - info!("New downstream connection: {}", downstream_addr); - let downstream = tokio_tungstenite::accept_async(Plain(downstream)).await?; - - let upstream = UdpSocket::bind("0.0.0.0:0").await?; // TODO: maybe make this configurable - - upstream.connect(upstream_addr).await?; - - info!( - "Connected to udp upstream (local: {}/udp, peer: {}/udp) for downstream {}", - upstream.local_addr()?, - upstream.peer_addr()?, - downstream_addr - ); - - process_udp_over_ws(upstream, downstream).await; - - info!("Connection with downstream {} closed...", downstream_addr); - - Ok(()) - } -} diff --git a/zia-server/src/main.rs b/zia-server/src/main.rs index 83b4fe1..dcfcb4e 100644 --- a/zia-server/src/main.rs +++ b/zia-server/src/main.rs @@ -1,13 +1,7 @@ -use clap::Parser; -use tokio::select; -use tokio::signal::ctrl_c; -use tracing::info; - use crate::cfg::{ClientCfg, Mode}; -use crate::listener::{Listener, TcpListener, WsListener}; mod cfg; -mod listener; +mod app; #[tokio::main] async fn main() -> anyhow::Result<()> { From e3aa94fc5abf0b782b1c2aaff111b38d6bc26fb9 Mon Sep 17 00:00:00 2001 From: Marcel Date: Sun, 1 Oct 2023 11:13:24 +0200 Subject: [PATCH 02/19] things --- Cargo.lock | 38 ++++- Cargo.toml | 2 + zia-client/Cargo.toml | 3 +- zia-client/src/app.rs | 249 +++++++++++---------------------- zia-client/src/main.rs | 95 ++++--------- zia-common/Cargo.toml | 14 ++ zia-common/src/lib.rs | 10 ++ zia-common/src/pool.rs | 61 ++++++++ zia-common/src/read.rs | 89 ++++++++++++ zia-common/src/write.rs | 94 +++++++++++++ zia-common/src/ws/README.md | 1 + zia-common/src/ws/frame.rs | 97 +++++++++++++ zia-common/src/ws/mod.rs | 162 ++++++++++++++++++++++ zia-common/src/ws/ws.rs | 270 ++++++++++++++++++++++++++++++++++++ zia-server/Cargo.toml | 10 +- zia-server/src/app.rs | 194 -------------------------- zia-server/src/cfg.rs | 15 +- zia-server/src/main.rs | 142 +++++++++++++++++-- 18 files changed, 1092 insertions(+), 454 deletions(-) create mode 100644 zia-common/Cargo.toml create mode 100644 zia-common/src/lib.rs create mode 100644 zia-common/src/pool.rs create mode 100644 zia-common/src/read.rs create mode 100644 zia-common/src/write.rs create mode 100644 zia-common/src/ws/README.md create mode 100644 zia-common/src/ws/frame.rs create mode 100644 zia-common/src/ws/mod.rs create mode 100644 zia-common/src/ws/ws.rs delete mode 100644 zia-server/src/app.rs diff --git a/Cargo.lock b/Cargo.lock index 933026e..ffb52df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "anstream" -version = "0.6.1" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6cd65a4b849ace0b7f6daeebcc1a1d111282227ca745458c61dbf670e52a597" +checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" dependencies = [ "anstyle", "anstyle-parse", @@ -57,9 +57,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.0" +version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0238ca56c96dfa37bdf7c373c8886dd591322500aceeeccdb2216fe06dc2f796" +checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" dependencies = [ "anstyle", "windows-sys", @@ -373,6 +373,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", + "socket2 0.4.9", "tokio", "tower-service", "tracing", @@ -689,6 +690,16 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +[[package]] +name = "socket2" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "socket2" version = "0.5.4" @@ -780,7 +791,7 @@ dependencies = [ "num_cpus", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.5.4", "tokio-macros", "windows-sys", ] @@ -1134,6 +1145,18 @@ dependencies = [ "tracing-subscriber", "url", "webpki-roots", + "zia-common", +] + +[[package]] +name = "zia-common" +version = "0.0.0-git" +dependencies = [ + "anyhow", + "hyper", + "rand", + "tokio", + "tracing", ] [[package]] @@ -1141,9 +1164,14 @@ name = "zia-server" version = "0.0.0-git" dependencies = [ "anyhow", + "clap", "fastwebsockets", "hyper", "once_cell", + "pin-project", "tokio", + "tracing", + "tracing-subscriber", "webpki-roots", + "zia-common", ] diff --git a/Cargo.toml b/Cargo.toml index f4f9b43..a16ff00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,8 @@ [workspace] +resolver = "2" members = [ "zia-client", + "zia-common", "zia-server" ] diff --git a/zia-client/Cargo.toml b/zia-client/Cargo.toml index b446f0c..c624531 100644 --- a/zia-client/Cargo.toml +++ b/zia-client/Cargo.toml @@ -17,9 +17,10 @@ url = { version = "2.4", features = ["serde"] } webpki-roots = "0.25" tokio-rustls = "0.24" once_cell = "1.18" -anyhow = "1.0" tracing = "0.1" +anyhow = "1.0" +zia-common = { path = '../zia-common' } [package.metadata.generate-rpm] assets = [ diff --git a/zia-client/src/app.rs b/zia-client/src/app.rs index 4dc6ca7..845b2c2 100644 --- a/zia-client/src/app.rs +++ b/zia-client/src/app.rs @@ -1,25 +1,24 @@ use std::future::Future; -use std::mem; -use std::net::SocketAddr; use std::sync::Arc; use anyhow::anyhow; use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth}; -use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; use hyper::header::{ CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE, USER_AGENT, }; use hyper::upgrade::Upgraded; use hyper::{Body, Request}; use once_cell::sync::Lazy; -use tokio::io::{split, ReadHalf, WriteHalf}; -use tokio::net::{TcpStream, UdpSocket}; -use tokio::sync::{Mutex as TokioMutex, Mutex}; +use tokio::io::split; +use tokio::net::TcpStream; use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; use tokio_rustls::TlsConnector; use tracing::info; use url::Url; +use zia_common::ws::{Role, WebSocket}; +use zia_common::{ReadConnection, WriteConnection, MAX_DATAGRAM_SIZE}; + static TLS_CONNECTOR: Lazy = Lazy::new(|| { let mut store = RootCertStore::empty(); store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { @@ -34,156 +33,88 @@ static TLS_CONNECTOR: Lazy = Lazy::new(|| { TlsConnector::from(Arc::new(config)) }); -pub(crate) struct Connection { - finished: Arc>, - buf: Arc>>, - write: Arc>>>, -} +pub(crate) async fn open_connection( + upstream: &Url, + proxy: &Option, +) -> anyhow::Result<(ReadConnection, WriteConnection)> { + let upstream_host = upstream + .host_str() + .ok_or_else(|| anyhow!("Upstream url is missing host"))?; + let upstream_port = upstream + .port_or_known_default() + .ok_or_else(|| anyhow!("Upstream url is missing port"))?; + + let stream = match proxy { + None => { + let stream = TcpStream::connect((upstream_host, upstream_port)).await?; + stream.set_nodelay(true)?; + + info!("Connected to tcp"); + + stream + } + Some(proxy) => { + assert_eq!(proxy.scheme(), "http"); + + let proxy_host = proxy + .host_str() + .ok_or_else(|| anyhow!("Proxy url is missing host"))?; + let proxy_port = proxy + .port_or_known_default() + .ok_or_else(|| anyhow!("Proxy url is missing port"))?; + + let mut stream = TcpStream::connect((proxy_host, proxy_port)).await?; + stream.set_nodelay(true)?; + + info!("Connected to tcp"); + + match proxy.password() { + Some(password) => { + http_connect_tokio_with_basic_auth( + &mut stream, + upstream_host, + upstream_port, + proxy.username(), + password, + ) + .await? + } + None => http_connect_tokio(&mut stream, upstream_host, upstream_port).await?, + }; + + info!("Proxy handshake"); + + stream + } + }; -impl Connection { - pub(crate) async fn new( - upstream: &Url, - proxy: &Option, - ) -> anyhow::Result<(Self, WebSocketRead>)> { - let upstream_host = upstream - .host_str() - .ok_or_else(|| anyhow!("Upstream url is missing host"))?; - let upstream_port = upstream - .port_or_known_default() - .ok_or_else(|| anyhow!("Upstream url is missing port"))?; - - let stream = match proxy { - None => { - let stream = TcpStream::connect((upstream_host, upstream_port)).await?; - stream.set_nodelay(true)?; - - info!("Connected to tcp"); - - stream - } - Some(proxy) => { - assert_eq!(proxy.scheme(), "http"); - - let proxy_host = proxy - .host_str() - .ok_or_else(|| anyhow!("Proxy url is missing host"))?; - let proxy_port = proxy - .port_or_known_default() - .ok_or_else(|| anyhow!("Proxy url is missing port"))?; - - let mut stream = TcpStream::connect((proxy_host, proxy_port)).await?; - stream.set_nodelay(true)?; - - info!("Connected to tcp"); - - match proxy.password() { - Some(password) => { - http_connect_tokio_with_basic_auth( - &mut stream, - upstream_host, - upstream_port, - proxy.username(), - password, - ) - .await? - } - None => http_connect_tokio(&mut stream, upstream_host, upstream_port).await?, - }; - - info!("Proxy handshake"); - - stream - } - }; - - let req = Request::get(upstream.to_string()) - .header(HOST, format!("{}:{}", upstream_host, upstream_port)) - .header(UPGRADE, "websocket") - .header(CONNECTION, "upgrade") - .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) - .header(SEC_WEBSOCKET_VERSION, "13") - .header(USER_AGENT, "zia") - .body(Body::empty())?; - - let (ws, _) = if upstream.scheme() == "wss" { - let domain = ServerName::try_from(upstream_host)?; - let stream = TLS_CONNECTOR.connect(domain, stream).await?; - info!("Upgraded to tls"); - - fastwebsockets::handshake::client(&SpawnExecutor, req, stream).await? - } else { - fastwebsockets::handshake::client(&SpawnExecutor, req, stream).await? - }; - - info!("Finished websocket handshake"); - - let (read, write) = ws.split(|upgraded| split(upgraded)); - - Ok(( - Self { - finished: Arc::new(Mutex::new(true)), - buf: Arc::new(Mutex::new(datagram_buffer())), - write: Arc::new(Mutex::new(write)), - }, - read, - )) - } -} + let req = Request::get(upstream.to_string()) + .header(HOST, format!("{}:{}", upstream_host, upstream_port)) + .header(UPGRADE, "websocket") + .header(CONNECTION, "upgrade") + .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) + .header(SEC_WEBSOCKET_VERSION, "13") + .header(USER_AGENT, "zia") + .body(Body::empty())?; -pub(crate) struct WritePool { - pub(crate) connections: Vec, - pub(crate) next: usize, -} + let (ws, _) = if upstream.scheme() == "wss" { + let domain = ServerName::try_from(upstream_host)?; + let stream = TLS_CONNECTOR.connect(domain, stream).await?; + info!("Upgraded to tls"); -impl WritePool { - pub(crate) async fn write( - &mut self, - socket: &UdpSocket, - old_addr: Arc>>, - ) -> anyhow::Result<()> { - loop { - let connection = self - .connections - .get(self.next % self.connections.len()) - .unwrap(); - - self.next += 1; - - let mut finished = connection.finished.lock().await; - if *finished { - *finished = false; - - let mut buf = connection.buf.lock().await; - let (read, addr) = socket.recv_from(&mut buf[..]).await.unwrap(); - tokio::spawn(async move { - *(old_addr.lock().await) = Some(addr); - }); - - let finished = connection.finished.clone(); - let buf = connection.buf.clone(); - let write = connection.write.clone(); - - tokio::spawn(async move { - let mut buf = buf.lock().await; - - write - .lock() - .await - .write_frame(Frame::new( - true, - OpCode::Binary, - None, - Payload::BorrowedMut(&mut buf[..read]), - )) - .await - .unwrap(); - - *(finished.lock().await) = true; - }); - return Ok(()); - } - } - } + fastwebsockets::handshake::client(&SpawnExecutor, req, stream).await? + } else { + fastwebsockets::handshake::client(&SpawnExecutor, req, stream).await? + }; + + info!("Finished websocket handshake"); + + let (read, write) = split(ws.into_inner()); + + let read = WebSocket::new(read, MAX_DATAGRAM_SIZE, Role::Client); + let write = WebSocket::new(write, MAX_DATAGRAM_SIZE, Role::Client); + + Ok((ReadConnection::new(read), WriteConnection::new(write))) } // Tie hyper's executor to tokio runtime @@ -198,15 +129,3 @@ where tokio::task::spawn(fut); } } - -const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize - mem::size_of::(); - -/// Creates and returns a buffer on the heap with enough space to contain any possible -/// UDP datagram. -/// -/// This is put on the heap and in a separate function to avoid the 64k buffer from ending -/// up on the stack and blowing up the size of the futures using it. -#[inline(never)] -fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { - Box::new([0u8; MAX_DATAGRAM_SIZE]) -} diff --git a/zia-client/src/main.rs b/zia-client/src/main.rs index 6571363..67bb4dc 100644 --- a/zia-client/src/main.rs +++ b/zia-client/src/main.rs @@ -1,18 +1,18 @@ -use std::convert::Infallible; use std::net::SocketAddr; use std::sync::Arc; -use crate::app::{Connection, WritePool}; use clap::Parser; -use fastwebsockets::{Frame, OpCode, Payload}; use tokio::net::UdpSocket; use tokio::select; use tokio::signal::ctrl_c; -use tokio::sync::Mutex; +use tokio::sync::RwLock; use tokio::task::JoinSet; -use tracing::{error, info}; +use tracing::info; use url::Url; +use zia_common::{ReadPool, WritePool}; + +use crate::app::open_connection; use crate::cfg::ClientCfg; mod app; @@ -80,77 +80,36 @@ async fn listen( for _ in 0..connection_count { let upstream = upstream.clone(); let proxy = proxy.clone(); - conns.spawn(async move { Connection::new(&upstream, &proxy).await }); + conns.spawn(async move { open_connection(&upstream, &proxy).await }); } - let mut connections = Vec::new(); - let mut read_pool = Vec::new(); + let addr = Arc::new(RwLock::new(Option::None)); + + let write_pool = WritePool::new(socket.clone(), addr.clone()); + let read_pool = ReadPool::new(socket, addr); while let Some(connection) = conns.join_next().await.transpose()? { - let (conn, read) = connection?; - connections.push(conn); - read_pool.push(read); + let (read, write) = connection?; + read_pool.push(read).await; + write_pool.push(write).await; } info!("Connected to upstream"); - let addr = Arc::new(Mutex::new(Option::None)); - - async fn test(_: Frame<'_>) -> Result<(), Infallible> { - todo!() - } - - for mut ws_read in read_pool { - let addr = addr.clone(); - let socket = socket.clone(); - - tokio::spawn(async move { - loop { - let frame = ws_read.read_frame(&mut test).await.unwrap(); - - if !frame.fin { - error!( - "unexpected buffer received, expect full udp frame to be in one websocket message" - ); - continue; - } - - match frame.opcode { - OpCode::Binary => match frame.payload { - Payload::BorrowedMut(payload) => { - socket - .send_to(payload, addr.lock().await.unwrap()) - .await - .unwrap(); - } - Payload::Borrowed(payload) => { - socket - .send_to(payload, addr.lock().await.unwrap()) - .await - .unwrap(); - } - Payload::Owned(payload) => { - socket - .send_to(&payload, addr.lock().await.unwrap()) - .await - .unwrap(); - } - }, - opcode => error!("Unexpected opcode: {:?}", opcode), - } - } - }); - } - - let socket = socket.clone(); - let old_addr = addr.clone(); - - let mut pool = WritePool { - connections, - next: 0, - }; + let write_handle = tokio::spawn(async move { + loop { + write_pool.execute().await?; + } + }); - loop { - pool.write(&socket, old_addr.clone()).await?; + select! { + result = write_handle => { + info!("Write pool finished"); + result? + }, + result = read_pool.join() => { + info!("Read pool finished"); + result + }, } } diff --git a/zia-common/Cargo.toml b/zia-common/Cargo.toml new file mode 100644 index 0000000..ad09daa --- /dev/null +++ b/zia-common/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "zia-common" +version = "0.0.0-git" +edition = "2021" +authors = ["Marcel "] +license = "AGPL-3.0" +description = "Proxy udp over websocket, useful to use Wireguard in restricted networks." + +[dependencies] +tokio = { version = "1.32", default-features = false, features = ["net", "sync"] } +hyper = { version = "0.14", default-features = false, features = [] } +anyhow = { version = "1.0", features = [] } +rand = {version = "0.8" ,default-features = false} +tracing = "0.1" diff --git a/zia-common/src/lib.rs b/zia-common/src/lib.rs new file mode 100644 index 0000000..411c423 --- /dev/null +++ b/zia-common/src/lib.rs @@ -0,0 +1,10 @@ +pub use read::*; +use std::mem; +pub use write::*; + +mod pool; +mod read; +mod write; +pub mod ws; + +pub const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize - mem::size_of::(); diff --git a/zia-common/src/pool.rs b/zia-common/src/pool.rs new file mode 100644 index 0000000..01bdb7d --- /dev/null +++ b/zia-common/src/pool.rs @@ -0,0 +1,61 @@ +use std::ops::{Deref, DerefMut}; + +use tokio::sync::{mpsc, Mutex}; + +pub struct PoolGuard { + inner: Option, + back: mpsc::UnboundedSender, +} + +unsafe impl Send for PoolGuard {} + +impl Drop for PoolGuard { + fn drop(&mut self) { + if let Err(err) = self.back.send(self.inner.take().unwrap()) { + panic!("Could not put PoolGuard back to pool: {:?}", err); + } + } +} + +impl Deref for PoolGuard { + type Target = T; + fn deref(&self) -> &Self::Target { + self.inner.as_ref().unwrap() + } +} + +impl DerefMut for PoolGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + self.inner.as_mut().unwrap() + } +} + +pub struct Pool { + tx: mpsc::UnboundedSender, + rx: Mutex>, +} + +impl Pool { + pub fn new() -> Self { + let (tx, rx) = mpsc::unbounded_channel(); + Self { + tx, + rx: Mutex::new(rx), + } + } + + pub async fn acquire(&self) -> PoolGuard { + let inner = self.rx.lock().await.recv().await.unwrap(); + + PoolGuard { + inner: Some(inner), + back: self.tx.clone(), + } + } + + pub fn push(&self, inner: T) { + if let Err(err) = self.tx.send(inner) { + panic!("Could not put Inner into to pool: {:?}", err); + } + } +} diff --git a/zia-common/src/read.rs b/zia-common/src/read.rs new file mode 100644 index 0000000..b631aad --- /dev/null +++ b/zia-common/src/read.rs @@ -0,0 +1,89 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use tokio::io::{AsyncRead, ReadHalf}; +use tokio::net::UdpSocket; +use tokio::select; +use tokio::sync::{Mutex, RwLock}; +use tokio::task::{JoinError, JoinSet}; +use tracing::error; + +use crate::ws::{Event, WebSocket}; + +pub struct ReadConnection { + read: WebSocket>, +} + +impl ReadConnection { + pub fn new(read: WebSocket>) -> Self { + Self { read } + } + + async fn handle_frame( + &mut self, + socket: &UdpSocket, + addr: &RwLock>, + ) -> anyhow::Result<()> { + let event = self.read.recv().await?; + + match event { + Event::Data(data) => { + let addr = addr.read().await.unwrap(); + socket.send_to(&data, addr).await?; + } + Event::Close { .. } => {} + } + + Ok(()) + } +} + +pub struct ReadPool { + socket: Arc, + addr: Arc>>, + tasks: Mutex>>, +} + +impl ReadPool { + pub fn new(socket: Arc, addr: Arc>>) -> Self { + Self { + socket, + addr, + tasks: Mutex::new(JoinSet::new()), + } + } + async fn wait(&self) -> Option, JoinError>> { + let mut set = self.tasks.lock().await; + select! { + result = set.join_next() => result, + _result = tokio::time::sleep(Duration::from_millis(200)) => Some(Ok(Ok(()))), + } + } + + pub async fn join(&self) -> anyhow::Result<()> { + // hack + loop { + while let Some(result) = self.wait().await { + if let Err(err) = result? { + error!("Error while handling websocket frame: {}", err); + // TODO: close and remove from write pool + } + } + + // hack + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + + pub async fn push(&self, mut conn: ReadConnection) { + let socket = self.socket.clone(); + let addr = self.addr.clone(); + + self.tasks.lock().await.spawn(async move { + loop { + conn.handle_frame(&socket, &addr).await?; + } + }); + } +} diff --git a/zia-common/src/write.rs b/zia-common/src/write.rs new file mode 100644 index 0000000..1cf52fe --- /dev/null +++ b/zia-common/src/write.rs @@ -0,0 +1,94 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use crate::MAX_DATAGRAM_SIZE; +use tokio::io::{AsyncWrite, WriteHalf}; +use tokio::net::UdpSocket; +use tokio::sync::RwLock; +use tracing::error; + +use crate::pool::Pool; +use crate::ws::WebSocket; + +pub struct WriteConnection { + write: WebSocket>, + buf: Box<[u8; MAX_DATAGRAM_SIZE]>, +} + +impl WriteConnection { + pub fn new(write: WebSocket>) -> Self { + Self { + buf: datagram_buffer(), + write, + } + } + + async fn flush(&mut self, size: usize) -> anyhow::Result<()> { + assert!(size <= MAX_DATAGRAM_SIZE); + + self.write.send(&self.buf[..size]).await?; + + Ok(()) + } +} + +pub struct WritePool { + socket: Arc, + pool: Pool>, + addr: Arc>>, +} + +impl WritePool { + pub fn new(socket: Arc, addr: Arc>>) -> Self { + Self { + socket, + pool: Pool::new(), + addr, + } + } + + async fn update_addr(&self, addr: SocketAddr) { + let is_outdated = self + .addr + .read() + .await + .map(|last_addr| last_addr != addr) + .unwrap_or(true); + + if is_outdated { + *(self.addr.write().await) = Some(addr); + } + } + + pub async fn push(&self, conn: WriteConnection) { + self.pool.push(conn); + } + + pub async fn execute(&self) -> anyhow::Result<()> { + loop { + let mut conn = self.pool.acquire().await; + + // read from udp socket and save to buf of selected conn + let (read, addr) = self.socket.recv_from(conn.buf.as_mut()).await.unwrap(); + + self.update_addr(addr).await; + + // flush buf of conn asynchronously to read again from udp socket in parallel + tokio::spawn(async move { + if let Err(err) = conn.flush(read).await { + error!("Unable to flush websocket buf: {:?}", err); + } + }); + } + } +} + +/// Creates and returns a buffer on the heap with enough space to contain any possible +/// UDP datagram. +/// +/// This is put on the heap and in a separate function to avoid the 64k buffer from ending +/// up on the stack and blowing up the size of the futures using it. +#[inline(never)] +fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { + Box::new([0u8; MAX_DATAGRAM_SIZE]) +} diff --git a/zia-common/src/ws/README.md b/zia-common/src/ws/README.md new file mode 100644 index 0000000..79b9bdb --- /dev/null +++ b/zia-common/src/ws/README.md @@ -0,0 +1 @@ +based on https://github.com/nurmohammed840/websocket.rs diff --git a/zia-common/src/ws/frame.rs b/zia-common/src/ws/frame.rs new file mode 100644 index 0000000..4e4cb2a --- /dev/null +++ b/zia-common/src/ws/frame.rs @@ -0,0 +1,97 @@ +#![doc(hidden)] + +pub struct Frame<'a> { + pub fin: bool, + pub opcode: u8, + pub data: &'a [u8], +} + +impl<'a> Frame<'a> { + #[inline] + pub fn encode_without_mask(self) -> Vec { + let mut buf = Vec::::with_capacity(10 + self.data.len()); + unsafe { + let dist = buf.as_mut_ptr(); + let head_len = self.encode_header_unchecked(dist, 0); + std::ptr::copy_nonoverlapping(self.data.as_ptr(), dist.add(head_len), self.data.len()); + buf.set_len(head_len + self.data.len()); + } + buf + } + + #[inline] + pub fn encode_with(self, mask: [u8; 4]) -> Vec { + let mut buf = Vec::::with_capacity(14 + self.data.len()); + unsafe { + let dist = buf.as_mut_ptr(); + let head_len = self.encode_header_unchecked(dist, 0x80); + + let [a, b, c, d] = mask; + dist.add(head_len).write(a); + dist.add(head_len + 1).write(b); + dist.add(head_len + 2).write(c); + dist.add(head_len + 3).write(d); + + let dist = dist.add(head_len + 4); + // TODO: Use SIMD wherever possible for best performance + for i in 0..self.data.len() { + dist + .add(i) + .write(self.data.get_unchecked(i) ^ mask.get_unchecked(i & 3)); + } + buf.set_len(head_len + 4 + self.data.len()); + } + buf + } + + /// # SAFETY + /// + /// - `dist` must be valid for writes of 10 bytes. + pub(crate) unsafe fn encode_header_unchecked(&self, dist: *mut u8, mask_bit: u8) -> usize { + dist.write(((self.fin as u8) << 7) | self.opcode); + if self.data.len() < 126 { + dist.add(1).write(mask_bit | self.data.len() as u8); + 2 + } else if self.data.len() < 65536 { + let [b2, b3] = (self.data.len() as u16).to_be_bytes(); + dist.add(1).write(mask_bit | 126); + dist.add(2).write(b2); + dist.add(3).write(b3); + 4 + } else { + let [b2, b3, b4, b5, b6, b7, b8, b9] = (self.data.len() as u64).to_be_bytes(); + dist.add(1).write(mask_bit | 127); + dist.add(2).write(b2); + dist.add(3).write(b3); + dist.add(4).write(b4); + dist.add(5).write(b5); + dist.add(6).write(b6); + dist.add(7).write(b7); + dist.add(8).write(b8); + dist.add(9).write(b9); + 10 + } + } +} + +impl<'a> From<&'a str> for Frame<'a> { + #[inline] + fn from(string: &'a str) -> Self { + Self { + fin: true, + opcode: 1, + data: string.as_bytes(), + } + } +} + +impl<'a> From<&'a [u8]> for Frame<'a> { + #[inline] + fn from(data: &'a [u8]) -> Self { + Self { + fin: true, + opcode: 2, + data, + } + } +} diff --git a/zia-common/src/ws/mod.rs b/zia-common/src/ws/mod.rs new file mode 100644 index 0000000..b5521b6 --- /dev/null +++ b/zia-common/src/ws/mod.rs @@ -0,0 +1,162 @@ +pub use frame::Frame; +pub use ws::WebSocket; + +mod frame; +mod ws; + +pub enum Role { + Server, + Client, +} + +/// It represent the type of data that is being sent over the WebSocket connection. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MessageType { + /// `Text` data is represented as a sequence of Unicode characters encoded using UTF-8 encoding. + Text = 1, + /// `Binary` data can be any sequence of bytes and is typically used for sending non-textual data, such as images, audio files etc... + Binary = 2, +} + +impl MessageType { + /// Returns `true` if message type is text + #[inline] + pub fn is_text(&self) -> bool { + matches!(self, MessageType::Text) + } + + /// Returns `true` if message type is binary + #[inline] + pub fn is_binary(&self) -> bool { + matches!(self, MessageType::Binary) + } +} + +#[derive(Debug)] +pub enum Event { + Data(Box<[u8]>), + + Close { code: CloseCode, reason: Box }, +} + +/// When closing an established connection an endpoint MAY indicate a reason for closure. +#[derive(Debug, Clone, Copy)] +pub enum CloseCode { + /// The purpose for which the connection was established has been fulfilled + Normal = 1000, + /// Server going down or a browser having navigated away from a page + Away = 1001, + /// An endpoint is terminating the connection due to a protocol error. + ProtocolError = 1002, + /// It has received a type of data it cannot accept + Unsupported = 1003, + + // reserved 1004 + /// MUST NOT be set as a status code in a Close control frame by an endpoint. + /// + /// No status code was actually present. + NoStatusRcvd = 1005, + /// MUST NOT be set as a status code in a Close control frame by an endpoint. + /// + /// Connection was closed abnormally. + Abnormal = 1006, + /// Application has received data within a message that was not consistent with the type of the message. + InvalidPayload = 1007, + /// This is a generic status code that can be returned when there is no other more suitable status code. + PolicyViolation = 1008, + /// Message that is too big for it to process. + MessageTooBig = 1009, + /// It has expected the server to negotiate one or more extension. + MandatoryExt = 1010, + /// The server has encountered an unexpected condition that prevented it from fulfilling the request. + InternalError = 1011, + /// MUST NOT be set as a status code in a Close control frame by an endpoint. + /// + /// The connection was closed due to a failure to perform a TLS handshake. + TLSHandshake = 1015, +} + +impl From for u16 { + #[inline] + fn from(code: CloseCode) -> Self { + code as u16 + } +} + +impl From for CloseCode { + #[inline] + fn from(value: u16) -> Self { + match value { + 1000 => CloseCode::Normal, + 1001 => CloseCode::Away, + 1002 => CloseCode::ProtocolError, + 1003 => CloseCode::Unsupported, + 1005 => CloseCode::NoStatusRcvd, + 1006 => CloseCode::Abnormal, + 1007 => CloseCode::InvalidPayload, + 1009 => CloseCode::MessageTooBig, + 1010 => CloseCode::MandatoryExt, + 1011 => CloseCode::InternalError, + 1015 => CloseCode::TLSHandshake, + _ => CloseCode::PolicyViolation, + } + } +} + +impl PartialEq for CloseCode { + #[inline] + fn eq(&self, other: &u16) -> bool { + (*self as u16) == *other + } +} + +/// This trait is responsible for encoding websocket closed frame. +pub trait CloseReason { + /// Encoded close reason as bytes + type Bytes; + /// Encode websocket close frame. + fn to_bytes(self) -> Self::Bytes; +} + +impl CloseReason for () { + type Bytes = [u8; 0]; + fn to_bytes(self) -> Self::Bytes { + [0; 0] + } +} + +impl CloseReason for u16 { + type Bytes = [u8; 2]; + fn to_bytes(self) -> Self::Bytes { + self.to_be_bytes() + } +} + +impl CloseReason for CloseCode { + type Bytes = [u8; 2]; + fn to_bytes(self) -> Self::Bytes { + (self as u16).to_be_bytes() + } +} + +impl CloseReason for &str { + type Bytes = Vec; + fn to_bytes(self) -> Self::Bytes { + CloseReason::to_bytes((CloseCode::Normal, self)) + } +} + +impl CloseReason for (Code, Msg) +where + Code: Into, + Msg: AsRef<[u8]>, +{ + type Bytes = Vec; + fn to_bytes(self) -> Self::Bytes { + let (code, reason) = (self.0.into(), self.1.as_ref()); + let mut data = Vec::with_capacity(2 + reason.len()); + data.extend_from_slice(&code.to_be_bytes()); + data.extend_from_slice(reason); + data + } +} diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs new file mode 100644 index 0000000..d87d4f9 --- /dev/null +++ b/zia-common/src/ws/ws.rs @@ -0,0 +1,270 @@ +use std::io::{IoSlice, Result}; + +use anyhow::anyhow; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::ws::{CloseReason, Event, Frame, Role}; + +/// WebSocket implementation for both client and server +pub struct WebSocket { + /// it is a low-level abstraction that represents the underlying byte stream over which WebSocket messages are exchanged. + pub stream: Stream, + + /// Maximum allowed payload length in bytes. + pub max_payload_len: usize, + + role: Role, + is_closed: bool, +} + +impl WebSocket { + #[inline] + pub fn new(stream: IO, max_payload_len: usize, role: Role) -> Self { + Self { + stream, + max_payload_len, + role, + is_closed: false, + } + } +} + +impl WebSocket +where + W: Unpin + AsyncWrite, +{ + #[doc(hidden)] + pub async fn send_raw(&mut self, frame: Frame<'_>) -> Result<()> { + let buf = match self.role { + Role::Server => { + if self.stream.is_write_vectored() { + let mut head = [0; 10]; + let head_len = unsafe { frame.encode_header_unchecked(head.as_mut_ptr(), 0) }; + let total_len = head_len + frame.data.len(); + + let mut bufs = [IoSlice::new(&head[..head_len]), IoSlice::new(frame.data)]; + let mut amt = self.stream.write_vectored(&bufs).await?; + if amt == total_len { + return Ok(()); + } + while amt < head_len { + bufs[0] = IoSlice::new(&head[amt..head_len]); + amt += self.stream.write_vectored(&bufs).await?; + } + if amt < total_len { + self.stream.write_all(&frame.data[amt - head_len..]).await?; + } + return Ok(()); + } + frame.encode_without_mask() + } + Role::Client => frame.encode_with(rand::random::().to_ne_bytes()), + }; + self.stream.write_all(&buf).await + } + + /// Send message to a endpoint. + pub async fn send(&mut self, data: impl Into>) -> Result<()> { + self.send_raw(data.into()).await + } + + /// - The Close frame MAY contain a body that indicates a reason for closing. + pub async fn close(mut self, reason: T) -> Result<()> + where + T: CloseReason, + T::Bytes: AsRef<[u8]>, + { + self + .send_raw(Frame { + fin: true, + opcode: 8, + data: reason.to_bytes().as_ref(), + }) + .await?; + self.stream.flush().await + } + + /// Flushes this output stream, ensuring that all intermediately buffered contents reach their destination. + pub async fn flash(&mut self) -> Result<()> { + self.stream.flush().await + } +} + +// ------------------------------------------------------------------------ + +macro_rules! err { [$msg: expr] => { return Err(anyhow!($msg)) }; } + +impl WebSocket +where + R: Unpin + AsyncRead, +{ + /// reads [Event] from websocket stream. + pub async fn recv(&mut self) -> anyhow::Result { + if self.is_closed { + return Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "read after close", + ))?; + } + let event = self.recv_event().await; + if let Ok(Event::Close { .. }) | Err(..) = event { + self.is_closed = true; + } + event + } + + // ### WebSocket Frame Header + // + // ```txt + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // +-------------------------------- - - - - - - - - - - - - - - - + + // : Payload Data continued ... : + // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + // | Payload Data continued ... | + // +---------------------------------------------------------------+ + // ``` + /// reads [Event] from websocket stream. + pub async fn recv_event(&mut self) -> anyhow::Result { + let mut buf = [0u8; 2]; + self.stream.read_exact(&mut buf).await?; + + let [b1, b2] = buf; + + let fin = b1 & 0b1000_0000 != 0; + let rsv = b1 & 0b111_0000; + let opcode = b1 & 0b1111; + let len = (b2 & 0b111_1111) as usize; + + // Defines whether the "Payload data" is masked. If set to 1, a + // masking key is present in masking-key, and this is used to unmask + // the "Payload data" as per [Section 5.3](https://datatracker.ietf.org/doc/html/rfc6455#section-5.3). All frames sent from + // client to server have this bit set to 1. + let is_masked = b2 & 0b_1000_0000 != 0; + + if rsv != 0 { + // MUST be `0` unless an extension is negotiated that defines meanings + // for non-zero values. If a nonzero value is received and none of + // the negotiated extensions defines the meaning of such a nonzero + // value, the receiving endpoint MUST _Fail the WebSocket Connection_. + err!("reserve bit must be `0`"); + } + + // A client MUST mask all frames that it sends to the server. (Note + // that masking is done whether or not the WebSocket Protocol is running + // over TLS.) The server MUST close the connection upon receiving a + // frame that is not masked. + // + // A server MUST NOT mask any frames that it sends to the client. + if let Role::Server = self.role { + if !is_masked { + err!("expected masked frame"); + } + } else if is_masked { + err!("expected unmasked frame"); + } + + // 3-7 are reserved for further non-control frames. + if opcode >= 8 { + if !fin { + err!("control frame must not be fragmented"); + } + if len > 125 { + err!("control frame must have a payload length of 125 bytes or less"); + } + let msg = self.read_payload(len).await?; + match opcode { + 8 => on_close(&msg), + // 9 => Ok(Event::Ping(msg)), + // 10 => Ok(Event::Pong(msg)), + // 11-15 are reserved for further control frames + _ => err!("unknown opcode"), + } + } else { + match (opcode, fin) { + (2, true) => {} + _ => err!("invalid data frame"), + }; + let len = match len { + 126 => self.stream.read_u16().await? as usize, + 127 => self.stream.read_u64().await? as usize, + len => len, + }; + if len > self.max_payload_len { + err!("payload too large"); + } + let data = self.read_payload(len).await?; + Ok(Event::Data(data)) + } + } + + async fn read_payload(&mut self, len: usize) -> Result> { + let mut data = vec![0; len].into_boxed_slice(); + match self.role { + Role::Server => { + let mut mask = [0u8; 4]; + self.stream.read_exact(&mut mask).await?; + self.stream.read_exact(&mut data).await?; + // TODO: Use SIMD wherever possible for best performance + for i in 0..data.len() { + data[i] ^= mask[i & 3]; + } + } + Role::Client => { + self.stream.read_exact(&mut data).await?; + } + } + Ok(data) + } +} + +/// - If there is a body, the first two bytes of the body MUST be a 2-byte unsigned integer (in network byte order: Big Endian) +/// representing a status code with value /code/ defined in [Section 7.4](https:///datatracker.ietf.org/doc/html/rfc6455#section-7.4). +/// Following the 2-byte integer, +/// +/// - The application MUST NOT send any more data frames after sending a `Close` frame. +/// +/// - If an endpoint receives a Close frame and did not previously send a +/// Close frame, the endpoint MUST send a Close frame in response. (When +/// sending a Close frame in response, the endpoint typically echos the +/// status code it received.) It SHOULD do so as soon as practical. An +/// endpoint MAY delay sending a Close frame until its current message is +/// sent +/// +/// - After both sending and receiving a Close message, an endpoint +/// considers the WebSocket connection closed and MUST close the +/// underlying TCP connection. +fn on_close(msg: &[u8]) -> anyhow::Result { + let code = msg + .get(..2) + .map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]])) + .unwrap_or(1000); + + match code { + 1000..=1003 | 1007..=1011 | 1015 | 3000..=3999 | 4000..=4999 => { + match msg.get(2..).map(|data| String::from_utf8(data.to_vec())) { + Some(Ok(msg)) => Ok(Event::Close { + code: code.into(), + reason: msg.into_boxed_str(), + }), + None => Ok(Event::Close { + code: code.into(), + reason: "".into(), + }), + Some(Err(_)) => Err(anyhow!("invalid utf-8 payload")), + } + } + _ => Err(anyhow!("invalid close code")), + } +} diff --git a/zia-server/Cargo.toml b/zia-server/Cargo.toml index b11a305..8b122e3 100644 --- a/zia-server/Cargo.toml +++ b/zia-server/Cargo.toml @@ -7,13 +7,19 @@ license = "AGPL-3.0" description = "Proxy udp over websocket, useful to use Wireguard in restricted networks." [dependencies] +tokio = { version = "1.32", default-features = false, features = ["rt-multi-thread", "macros", "net","sync", "time", "signal"] } fastwebsockets = { git = "https://github.com/MarcelCoding/fastwebsockets", branch = "split", default-features = false, features = ["upgrade"] } -tokio = { version = "1.32", default-features = false, features = ["rt-multi-thread", "macros", "net","sync", "time"] } -hyper = { version = "0.14", default-features = false, features = [] } +hyper = { version = "0.14", default-features = false, features = ["tcp"] } +tracing-subscriber = { version = "0.3", features = ["tracing-log"] } +clap = { version = "4.4", features = ["derive", "env"] } webpki-roots = "0.25" +pin-project = "1.1" once_cell = "1.18" +tracing = "0.1" anyhow = "1.0" +zia-common = { path = '../zia-common' } + [package.metadata.deb] maintainer-scripts = "debian/" systemd-units = { enable = false } diff --git a/zia-server/src/app.rs b/zia-server/src/app.rs deleted file mode 100644 index d50f841..0000000 --- a/zia-server/src/app.rs +++ /dev/null @@ -1,194 +0,0 @@ -use std::convert::Infallible; -use std::future::Future; -use std::mem; -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; -use std::str::FromStr; -use std::sync::Arc; -use fastwebsockets::{Frame, OpCode, Payload, Role, WebSocket, WebSocketRead, WebSocketWrite}; -use hyper::{Body, Request}; -use hyper::header::{ - CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE, USER_AGENT, -}; -use hyper::upgrade::Upgraded; -use tokio::io::{ReadHalf, split, WriteHalf}; -use tokio::net::{TcpListener, TcpStream, UdpSocket}; -use tokio::sync::{Mutex as TokioMutex, Mutex}; - -struct Connection { - finished: Arc>, - buf: Arc>>, - write: Arc>>>, -} - -const UDP_ADDR: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10,99,0,0), 51838)); - -async fn test(_: Frame<'_>) -> Result<(), Infallible> { - todo!() -} - -async fn handle( - socket: TcpStream, -) -> anyhow::Result>> { - let mut ws = WebSocket::after_handshake(socket, Role::Server); - ws.set_writev(false); - ws.set_auto_close(true); - ws.set_auto_pong(true); - - let (mut ws_read, write) = ws.split(|upgraded| split(upgraded)); - - tokio::spawn(async move { - loop { - let frame = ws_read.read_frame(&mut test).await.unwrap(); - - if !frame.fin { - println!("unexpected buffer received, expect full udp frame to be in one websocket message"); - continue; - } - - match frame.opcode { - OpCode::Binary => match frame.payload { - Payload::BorrowedMut(payload) => { - socket.send_to(payload, UDP_ADDR).await.unwrap(); - } - Payload::Borrowed(payload) => { - socket.send_to(payload, UDP_ADDR).await.unwrap(); - } - Payload::Owned(payload) => { - socket.send_to(&payload, UDP_ADDR).await.unwrap(); - } - }, - opcode => eprintln!("Unexpected opcode: {:?}", opcode), - } - } - }); - - Ok(write) -} - -impl Connection { - async fn new() -> anyhow::Result<(Self, WebSocketRead>)> { - // 135.181.77.88:80 - - println!("connected https"); - - let (ws, resp) = fastwebsockets::upgrade::upgrade(&SpawnExecutor, req, listener).await?; - - println!("connected websocket"); - - let (ws_read, ws_write) = ws.split(|upgraded| split(upgraded)); - - Ok(( - Self { - finished: Arc::new(Mutex::new(true)), - buf: Arc::new(Mutex::new(datagram_buffer())), - write: Arc::new(Mutex::new(ws_write)), - }, - ws_read, - )) - } -} - -struct WritePool { - connections: Vec, - next: usize, -} - -impl WritePool { - async fn write( - &mut self, - socket: &UdpSocket, - old_addr: Arc>>, - ) -> anyhow::Result<()> { - loop { - let connection = self - .connections - .get(self.next % self.connections.len()) - .unwrap(); - - self.next += 1; - - let mut finished = connection.finished.lock().await; - if *finished { - *finished = false; - - let mut buf = connection.buf.lock().await; - let (read, addr) = socket.recv_from(&mut buf[..]).await.unwrap(); - tokio::spawn(async move { - *(old_addr.lock().await) = Some(addr); - }); - - let buf = connection.buf.clone(); - let write = connection.write.clone(); - - tokio::spawn(async move { - let mut buf = buf.lock().await; - - write - .lock() - .await - .write_frame(Frame::new( - true, - OpCode::Binary, - None, - Payload::BorrowedMut(&mut buf[..read]), - )) - .await - .unwrap(); - }); - return Ok(()); - } - } - } -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let mut listener = TcpListener::bind("127.0.0.1:3128").await?; - - let mut connections = Arc::new(Mutex::new(Vec::new())); - - println!("connected"); - - - { - let connections = connections.clone(); - tokio::spawn(async move { - loop { - let (stream, _) = listener.accept().await?; - connections.lock().await.push(handle(stream).await?); - } - }); - } - - loop { - connections.lock(). - } -} - -// Tie hyper's executor to tokio runtime -struct SpawnExecutor; - -impl hyper::rt::Executor for SpawnExecutor - where - Fut: Future + Send + 'static, - Fut::Output: Send + 'static, -{ - fn execute(&self, fut: Fut) { - tokio::task::spawn(fut); - } -} - -const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize - mem::size_of::(); - -/// Creates and returns a buffer on the heap with enough space to contain any possible -/// UDP datagram. -/// -/// This is put on the heap and in a separate function to avoid the 64k buffer from ending -/// up on the stack and blowing up the size of the futures using it. -#[inline(never)] -fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { - Box::new([0u8; MAX_DATAGRAM_SIZE]) -} - - diff --git a/zia-server/src/cfg.rs b/zia-server/src/cfg.rs index 975e0a6..5f21587 100644 --- a/zia-server/src/cfg.rs +++ b/zia-server/src/cfg.rs @@ -5,26 +5,33 @@ use clap::{Parser, ValueEnum}; #[derive(Parser)] #[clap(version)] -pub(crate) struct ClientCfg { +pub(crate) struct ServerCfg { #[arg(short, long, env = "ZIA_LISTEN_ADDR", default_value = "0.0.0.0:1234")] pub(crate) listen_addr: SocketAddr, #[arg(short, long, env = "ZIA_UPSTREAM")] pub(crate) upstream: String, - #[arg(short, long, env = "ZIA_MODE", default_value = "WS", value_enum, ignore_case(true))] + #[arg( + short, + long, + env = "ZIA_MODE", + default_value = "WS", + value_enum, + ignore_case(true) + )] pub(crate) mode: Mode, } #[derive(ValueEnum, Clone)] pub(crate) enum Mode { Ws, - Tcp, + // Tcp, } impl Display for Mode { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::Ws => f.write_str("ws"), - Self::Tcp => f.write_str("tcp"), + // Self::Tcp => f.write_str("tcp"), } } } diff --git a/zia-server/src/main.rs b/zia-server/src/main.rs index dcfcb4e..22ce2a6 100644 --- a/zia-server/src/main.rs +++ b/zia-server/src/main.rs @@ -1,33 +1,145 @@ -use crate::cfg::{ClientCfg, Mode}; +use std::convert::Infallible; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use clap::Parser; +use hyper::{Body, Request, Response, Server, StatusCode}; +use hyper::service::{make_service_fn, Service}; +use hyper::upgrade::Upgraded; +use tokio::io::split; +use tokio::net::UdpSocket; +use tokio::select; +use tokio::signal::ctrl_c; +use tokio::sync::RwLock; +use tokio::task::JoinHandle; +use tracing::info; + +use zia_common::{MAX_DATAGRAM_SIZE, ReadConnection, ReadPool, WriteConnection, WritePool}; +use zia_common::ws::{Role, WebSocket}; + +use crate::cfg::ServerCfg; mod cfg; -mod app; + +#[pin_project::pin_project] +struct FutA { + req: Request, + read: Arc, + write: Arc>, +} + +impl Future for FutA { + type Output = Result, Infallible>; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if !fastwebsockets::upgrade::is_upgrade_request(this.req) { + return Poll::Ready(Ok( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("bad request: expected websocket upgrade")) + .unwrap(), + )); + } + + let (resp, upgrade) = fastwebsockets::upgrade::upgrade(this.req).unwrap(); + + let wread = this.read.clone(); + let wwrite = this.write.clone(); + + tokio::spawn(async move { + let ws = upgrade.await.unwrap().into_inner(); + let (read, write) = split(ws); + + let read = WebSocket::new(read, MAX_DATAGRAM_SIZE, Role::Server); + let write = WebSocket::new(write, MAX_DATAGRAM_SIZE, Role::Server); + + wread.push(ReadConnection::new(read)).await; + wwrite.push(WriteConnection::new(write)).await; + }); + + Poll::Ready(Ok(resp)) + } +} + +// mod app; +struct Handler { + read: Arc, + write: Arc>, +} + +impl Service> for Handler { + type Response = Response; + type Error = Infallible; + type Future = FutA; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + FutA { + req, + read: self.read.clone(), + write: self.write.clone(), + } + } +} #[tokio::main] async fn main() -> anyhow::Result<()> { - let config = ClientCfg::parse(); + let config = ServerCfg::parse(); tracing_subscriber::fmt::init(); - let listener: Box = match config.mode { - Mode::Ws => Box::new(WsListener { - addr: config.listen_addr, - }), - Mode::Tcp => Box::new(TcpListener { - addr: config.listen_addr, - }), - }; + let socket = Arc::new(UdpSocket::bind(Into::::into(([0, 0, 0, 0], 0))).await?); + socket.connect(&config.upstream).await?; + info!("Connected to upstream udp://{}...", config.upstream); + + let addr = Arc::new(RwLock::new(Some(socket.peer_addr()?))); + let write_pool = Arc::new(WritePool::new(socket.clone(), addr.clone())); + let read_pool = Arc::new(ReadPool::new(socket, addr)); + + let wp = write_pool.clone(); + let rp = read_pool.clone(); + + let make_service = make_service_fn(|_conn| { + let read = rp.clone(); + let write = wp.clone(); + + async move { Ok::<_, Infallible>(Handler { read, write }) } + }); - info!("Listening in {}://{}...", config.mode, config.listen_addr); + let server = Server::bind(&config.listen_addr).serve(make_service); + + info!("Listening on {}://{}...", config.mode, config.listen_addr); + + let write_handle: JoinHandle> = tokio::spawn(async move { + loop { + write_pool.execute().await?; + } + }); select! { - result = listener.listen(&config.upstream) => { - result?; + result = server => { info!("Socket closed, quitting..."); + result?; }, - result = shutdown_signal() => { + result = write_handle => { + info!("Write pool finished"); + result??; + }, + result = read_pool.join() => { + info!("Read pool finished"); result?; + }, + result = shutdown_signal() => { info!("Termination signal received, quitting..."); + result?; } } From 0a5a6eab2ae9500b8d93ad7fbc58228f349f0bff Mon Sep 17 00:00:00 2001 From: Marcel Date: Sun, 1 Oct 2023 15:27:37 +0200 Subject: [PATCH 03/19] Added option to disable masking removed additional buffers --- .editorconfig | 3 + zia-client/src/app.rs | 17 ++++- zia-client/src/cfg.rs | 7 ++ zia-client/src/main.rs | 6 +- zia-common/Cargo.toml | 4 +- zia-common/src/write.rs | 101 ++++++++++++++-------------- zia-common/src/ws/frame.rs | 115 ++++++++++++------------------- zia-common/src/ws/mod.rs | 2 +- zia-common/src/ws/ws.rs | 134 ++++++++++++++++--------------------- zia-server/src/main.rs | 4 +- 10 files changed, 188 insertions(+), 205 deletions(-) create mode 100644 .editorconfig diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..0da8f80 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,3 @@ +[*] +indent_size = 2 +indent_style = space diff --git a/zia-client/src/app.rs b/zia-client/src/app.rs index 845b2c2..b8eb3c8 100644 --- a/zia-client/src/app.rs +++ b/zia-client/src/app.rs @@ -36,6 +36,7 @@ static TLS_CONNECTOR: Lazy = Lazy::new(|| { pub(crate) async fn open_connection( upstream: &Url, proxy: &Option, + websocket_masking: bool, ) -> anyhow::Result<(ReadConnection, WriteConnection)> { let upstream_host = upstream .host_str() @@ -111,8 +112,20 @@ pub(crate) async fn open_connection( let (read, write) = split(ws.into_inner()); - let read = WebSocket::new(read, MAX_DATAGRAM_SIZE, Role::Client); - let write = WebSocket::new(write, MAX_DATAGRAM_SIZE, Role::Client); + let read = WebSocket::new( + read, + MAX_DATAGRAM_SIZE, + Role::Client { + masking: websocket_masking, + }, + ); + let write = WebSocket::new( + write, + MAX_DATAGRAM_SIZE, + Role::Client { + masking: websocket_masking, + }, + ); Ok((ReadConnection::new(read), WriteConnection::new(write))) } diff --git a/zia-client/src/cfg.rs b/zia-client/src/cfg.rs index 5d4e0e7..83f9555 100644 --- a/zia-client/src/cfg.rs +++ b/zia-client/src/cfg.rs @@ -14,4 +14,11 @@ pub(crate) struct ClientCfg { pub(crate) proxy: Option, #[arg(short, long, env = "ZIA_COUNT")] pub(crate) count: usize, + #[arg( + short = 'm', + long, + env = "ZIA_WEBSOCKET_MASKING", + default_value = "false" + )] + pub(crate) websocket_masking: bool, } diff --git a/zia-client/src/main.rs b/zia-client/src/main.rs index 67bb4dc..2b54f5e 100644 --- a/zia-client/src/main.rs +++ b/zia-client/src/main.rs @@ -25,7 +25,7 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); select! { - result = tokio::spawn(listen(config.listen_addr, config.upstream, config.proxy, config.count)) => { + result = tokio::spawn(listen(config.listen_addr, config.upstream, config.proxy, config.count, config.websocket_masking)) => { result??; info!("Socket closed, quitting..."); }, @@ -70,6 +70,7 @@ async fn listen( upstream: Url, proxy: Option, connection_count: usize, + websocket_masking: bool, ) -> anyhow::Result<()> { let socket = Arc::new(UdpSocket::bind(addr).await?); @@ -80,7 +81,8 @@ async fn listen( for _ in 0..connection_count { let upstream = upstream.clone(); let proxy = proxy.clone(); - conns.spawn(async move { open_connection(&upstream, &proxy).await }); + let websocket_masking = websocket_masking; + conns.spawn(async move { open_connection(&upstream, &proxy, websocket_masking).await }); } let addr = Arc::new(RwLock::new(Option::None)); diff --git a/zia-common/Cargo.toml b/zia-common/Cargo.toml index ad09daa..4705f60 100644 --- a/zia-common/Cargo.toml +++ b/zia-common/Cargo.toml @@ -9,6 +9,6 @@ description = "Proxy udp over websocket, useful to use Wireguard in restricted n [dependencies] tokio = { version = "1.32", default-features = false, features = ["net", "sync"] } hyper = { version = "0.14", default-features = false, features = [] } -anyhow = { version = "1.0", features = [] } -rand = {version = "0.8" ,default-features = false} +rand = { version = "0.8", default-features = false } tracing = "0.1" +anyhow = "1.0" diff --git a/zia-common/src/write.rs b/zia-common/src/write.rs index 1cf52fe..f555c8a 100644 --- a/zia-common/src/write.rs +++ b/zia-common/src/write.rs @@ -8,79 +8,80 @@ use tokio::sync::RwLock; use tracing::error; use crate::pool::Pool; -use crate::ws::WebSocket; +use crate::ws::{Frame, WebSocket}; pub struct WriteConnection { - write: WebSocket>, - buf: Box<[u8; MAX_DATAGRAM_SIZE]>, + write: WebSocket>, + buf: Box<[u8; MAX_DATAGRAM_SIZE]>, } impl WriteConnection { - pub fn new(write: WebSocket>) -> Self { - Self { - buf: datagram_buffer(), - write, - } + pub fn new(write: WebSocket>) -> Self { + Self { + buf: datagram_buffer(), + write, } + } - async fn flush(&mut self, size: usize) -> anyhow::Result<()> { - assert!(size <= MAX_DATAGRAM_SIZE); + async fn flush(&mut self, size: usize) -> anyhow::Result<()> { + assert!(size <= MAX_DATAGRAM_SIZE); - self.write.send(&self.buf[..size]).await?; + let frame = Frame::binary(&self.buf[..size]); + self.write.send(frame).await?; - Ok(()) - } + Ok(()) + } } pub struct WritePool { - socket: Arc, - pool: Pool>, - addr: Arc>>, + socket: Arc, + pool: Pool>, + addr: Arc>>, } impl WritePool { - pub fn new(socket: Arc, addr: Arc>>) -> Self { - Self { - socket, - pool: Pool::new(), - addr, - } + pub fn new(socket: Arc, addr: Arc>>) -> Self { + Self { + socket, + pool: Pool::new(), + addr, } - - async fn update_addr(&self, addr: SocketAddr) { - let is_outdated = self - .addr - .read() - .await - .map(|last_addr| last_addr != addr) - .unwrap_or(true); - - if is_outdated { - *(self.addr.write().await) = Some(addr); - } + } + + async fn update_addr(&self, addr: SocketAddr) { + let is_outdated = self + .addr + .read() + .await + .map(|last_addr| last_addr != addr) + .unwrap_or(true); + + if is_outdated { + *(self.addr.write().await) = Some(addr); } + } - pub async fn push(&self, conn: WriteConnection) { - self.pool.push(conn); - } + pub async fn push(&self, conn: WriteConnection) { + self.pool.push(conn); + } - pub async fn execute(&self) -> anyhow::Result<()> { - loop { - let mut conn = self.pool.acquire().await; + pub async fn execute(&self) -> anyhow::Result<()> { + loop { + let mut conn = self.pool.acquire().await; - // read from udp socket and save to buf of selected conn - let (read, addr) = self.socket.recv_from(conn.buf.as_mut()).await.unwrap(); + // read from udp socket and save to buf of selected conn + let (read, addr) = self.socket.recv_from(conn.buf.as_mut()).await.unwrap(); - self.update_addr(addr).await; + self.update_addr(addr).await; - // flush buf of conn asynchronously to read again from udp socket in parallel - tokio::spawn(async move { - if let Err(err) = conn.flush(read).await { - error!("Unable to flush websocket buf: {:?}", err); - } - }); + // flush buf of conn asynchronously to read again from udp socket in parallel + tokio::spawn(async move { + if let Err(err) = conn.flush(read).await { + error!("Unable to flush websocket buf: {:?}", err); } + }); } + } } /// Creates and returns a buffer on the heap with enough space to contain any possible @@ -90,5 +91,5 @@ impl WritePool { /// up on the stack and blowing up the size of the futures using it. #[inline(never)] fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { - Box::new([0u8; MAX_DATAGRAM_SIZE]) + Box::new([0u8; MAX_DATAGRAM_SIZE]) } diff --git a/zia-common/src/ws/frame.rs b/zia-common/src/ws/frame.rs index 4e4cb2a..2c61f75 100644 --- a/zia-common/src/ws/frame.rs +++ b/zia-common/src/ws/frame.rs @@ -1,4 +1,4 @@ -#![doc(hidden)] +use tokio::io::{AsyncWrite, AsyncWriteExt}; pub struct Frame<'a> { pub fin: bool, @@ -8,90 +8,63 @@ pub struct Frame<'a> { impl<'a> Frame<'a> { #[inline] - pub fn encode_without_mask(self) -> Vec { - let mut buf = Vec::::with_capacity(10 + self.data.len()); - unsafe { - let dist = buf.as_mut_ptr(); - let head_len = self.encode_header_unchecked(dist, 0); - std::ptr::copy_nonoverlapping(self.data.as_ptr(), dist.add(head_len), self.data.len()); - buf.set_len(head_len + self.data.len()); + pub fn binary(data: &'a [u8]) -> Self { + Self { + fin: true, + opcode: 2, + data, } - buf } #[inline] - pub fn encode_with(self, mask: [u8; 4]) -> Vec { - let mut buf = Vec::::with_capacity(14 + self.data.len()); - unsafe { - let dist = buf.as_mut_ptr(); - let head_len = self.encode_header_unchecked(dist, 0x80); + pub async fn write_without_mask( + self, + write: &mut W, + ) -> anyhow::Result<()> { + self.write_header(write, 0).await?; + write.write_all(self.data).await?; + + Ok(()) + } - let [a, b, c, d] = mask; - dist.add(head_len).write(a); - dist.add(head_len + 1).write(b); - dist.add(head_len + 2).write(c); - dist.add(head_len + 3).write(d); + #[inline] + pub async fn write_with_mask( + self, + write: &mut W, + mask: [u8; 4], + ) -> anyhow::Result<()> { + self.write_header(write, 0x80).await?; + write.write_all(&mask).await?; - let dist = dist.add(head_len + 4); + for i in 0..self.data.len() { // TODO: Use SIMD wherever possible for best performance - for i in 0..self.data.len() { - dist - .add(i) - .write(self.data.get_unchecked(i) ^ mask.get_unchecked(i & 3)); - } - buf.set_len(head_len + 4 + self.data.len()); + write + .write_u8(unsafe { self.data.get_unchecked(i) ^ mask.get_unchecked(i & 3) }) + .await? } - buf + + Ok(()) } - /// # SAFETY - /// - /// - `dist` must be valid for writes of 10 bytes. - pub(crate) unsafe fn encode_header_unchecked(&self, dist: *mut u8, mask_bit: u8) -> usize { - dist.write(((self.fin as u8) << 7) | self.opcode); + async fn write_header( + &self, + write: &mut W, + mask_bit: u8, + ) -> anyhow::Result<()> { + write + .write_u8(((self.fin as u8) << 7) | self.opcode) + .await?; + if self.data.len() < 126 { - dist.add(1).write(mask_bit | self.data.len() as u8); - 2 + write.write_u8(mask_bit | self.data.len() as u8).await?; } else if self.data.len() < 65536 { - let [b2, b3] = (self.data.len() as u16).to_be_bytes(); - dist.add(1).write(mask_bit | 126); - dist.add(2).write(b2); - dist.add(3).write(b3); - 4 + write.write_u8(mask_bit | 126).await?; + write.write_u16(self.data.len() as u16).await?; } else { - let [b2, b3, b4, b5, b6, b7, b8, b9] = (self.data.len() as u64).to_be_bytes(); - dist.add(1).write(mask_bit | 127); - dist.add(2).write(b2); - dist.add(3).write(b3); - dist.add(4).write(b4); - dist.add(5).write(b5); - dist.add(6).write(b6); - dist.add(7).write(b7); - dist.add(8).write(b8); - dist.add(9).write(b9); - 10 - } - } -} - -impl<'a> From<&'a str> for Frame<'a> { - #[inline] - fn from(string: &'a str) -> Self { - Self { - fin: true, - opcode: 1, - data: string.as_bytes(), + write.write_u8(mask_bit | 127).await?; + write.write_u64(self.data.len() as u64).await?; } - } -} -impl<'a> From<&'a [u8]> for Frame<'a> { - #[inline] - fn from(data: &'a [u8]) -> Self { - Self { - fin: true, - opcode: 2, - data, - } + Ok(()) } } diff --git a/zia-common/src/ws/mod.rs b/zia-common/src/ws/mod.rs index b5521b6..3dea435 100644 --- a/zia-common/src/ws/mod.rs +++ b/zia-common/src/ws/mod.rs @@ -6,7 +6,7 @@ mod ws; pub enum Role { Server, - Client, + Client { masking: bool }, } /// It represent the type of data that is being sent over the WebSocket connection. diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs index d87d4f9..4457f9d 100644 --- a/zia-common/src/ws/ws.rs +++ b/zia-common/src/ws/ws.rs @@ -1,14 +1,14 @@ -use std::io::{IoSlice, Result}; +use std::io::Result; use anyhow::anyhow; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::ws::{CloseReason, Event, Frame, Role}; +use crate::ws::{Event, Frame, Role}; /// WebSocket implementation for both client and server -pub struct WebSocket { +pub struct WebSocket { /// it is a low-level abstraction that represents the underlying byte stream over which WebSocket messages are exchanged. - pub stream: Stream, + pub io: IO, /// Maximum allowed payload length in bytes. pub max_payload_len: usize, @@ -21,7 +21,7 @@ impl WebSocket { #[inline] pub fn new(stream: IO, max_payload_len: usize, role: Role) -> Self { Self { - stream, + io: stream, max_payload_len, role, is_closed: false, @@ -29,64 +29,43 @@ impl WebSocket { } } -impl WebSocket -where - W: Unpin + AsyncWrite, -{ - #[doc(hidden)] - pub async fn send_raw(&mut self, frame: Frame<'_>) -> Result<()> { - let buf = match self.role { - Role::Server => { - if self.stream.is_write_vectored() { - let mut head = [0; 10]; - let head_len = unsafe { frame.encode_header_unchecked(head.as_mut_ptr(), 0) }; - let total_len = head_len + frame.data.len(); - - let mut bufs = [IoSlice::new(&head[..head_len]), IoSlice::new(frame.data)]; - let mut amt = self.stream.write_vectored(&bufs).await?; - if amt == total_len { - return Ok(()); - } - while amt < head_len { - bufs[0] = IoSlice::new(&head[amt..head_len]); - amt += self.stream.write_vectored(&bufs).await?; - } - if amt < total_len { - self.stream.write_all(&frame.data[amt - head_len..]).await?; - } - return Ok(()); +impl WebSocket { + pub async fn send(&mut self, frame: Frame<'_>) -> anyhow::Result<()> { + match self.role { + Role::Server => frame.write_without_mask(&mut self.io).await?, + Role::Client { masking } => { + if masking { + let mask = rand::random::().to_ne_bytes(); + frame.write_with_mask(&mut self.io, mask).await?; + } else { + frame.write_without_mask(&mut self.io).await?; } - frame.encode_without_mask() } - Role::Client => frame.encode_with(rand::random::().to_ne_bytes()), - }; - self.stream.write_all(&buf).await - } - - /// Send message to a endpoint. - pub async fn send(&mut self, data: impl Into>) -> Result<()> { - self.send_raw(data.into()).await - } + } - /// - The Close frame MAY contain a body that indicates a reason for closing. - pub async fn close(mut self, reason: T) -> Result<()> - where - T: CloseReason, - T::Bytes: AsRef<[u8]>, - { - self - .send_raw(Frame { - fin: true, - opcode: 8, - data: reason.to_bytes().as_ref(), - }) - .await?; - self.stream.flush().await + Ok(()) } - /// Flushes this output stream, ensuring that all intermediately buffered contents reach their destination. - pub async fn flash(&mut self) -> Result<()> { - self.stream.flush().await + // TODO: implement close + // pub async fn close(mut self, reason: T) -> anyhow::Result<()> + // where + // T: CloseReason, + // T::Bytes: AsRef<[u8]>, + // { + // let frame = Frame { + // fin: true, + // opcode: 8, + // data: reason.to_bytes().as_ref(), + // }; + // + // self.send(frame).await?; + // self.flush().await?; + // Ok(()) + // } + + pub async fn flush(&mut self) -> anyhow::Result<()> { + self.io.flush().await?; + Ok(()) } } @@ -138,7 +117,7 @@ where /// reads [Event] from websocket stream. pub async fn recv_event(&mut self) -> anyhow::Result { let mut buf = [0u8; 2]; - self.stream.read_exact(&mut buf).await?; + self.io.read_exact(&mut buf).await?; let [b1, b2] = buf; @@ -168,9 +147,10 @@ where // // A server MUST NOT mask any frames that it sends to the client. if let Role::Server = self.role { - if !is_masked { - err!("expected masked frame"); - } + // TODO: disabled, to allow unmasked client frames + // if !is_masked { + // err!("expected masked frame"); + // } } else if is_masked { err!("expected unmasked frame"); } @@ -183,7 +163,7 @@ where if len > 125 { err!("control frame must have a payload length of 125 bytes or less"); } - let msg = self.read_payload(len).await?; + let msg = self.read_payload(is_masked, len).await?; match opcode { 8 => on_close(&msg), // 9 => Ok(Event::Ping(msg)), @@ -197,32 +177,36 @@ where _ => err!("invalid data frame"), }; let len = match len { - 126 => self.stream.read_u16().await? as usize, - 127 => self.stream.read_u64().await? as usize, + 126 => self.io.read_u16().await? as usize, + 127 => self.io.read_u64().await? as usize, len => len, }; if len > self.max_payload_len { err!("payload too large"); } - let data = self.read_payload(len).await?; + let data = self.read_payload(is_masked, len).await?; Ok(Event::Data(data)) } } - async fn read_payload(&mut self, len: usize) -> Result> { + async fn read_payload(&mut self, masked: bool, len: usize) -> Result> { let mut data = vec![0; len].into_boxed_slice(); match self.role { Role::Server => { - let mut mask = [0u8; 4]; - self.stream.read_exact(&mut mask).await?; - self.stream.read_exact(&mut data).await?; - // TODO: Use SIMD wherever possible for best performance - for i in 0..data.len() { - data[i] ^= mask[i & 3]; + if masked { + let mut mask = [0u8; 4]; + self.io.read_exact(&mut mask).await?; + self.io.read_exact(&mut data).await?; + // TODO: Use SIMD wherever possible for best performance + for i in 0..data.len() { + data[i] ^= mask[i & 3]; + } + } else { + self.io.read_exact(&mut data).await?; } } - Role::Client => { - self.stream.read_exact(&mut data).await?; + Role::Client { .. } => { + self.io.read_exact(&mut data).await?; } } Ok(data) diff --git a/zia-server/src/main.rs b/zia-server/src/main.rs index 22ce2a6..21b5306 100644 --- a/zia-server/src/main.rs +++ b/zia-server/src/main.rs @@ -6,9 +6,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use clap::Parser; -use hyper::{Body, Request, Response, Server, StatusCode}; use hyper::service::{make_service_fn, Service}; use hyper::upgrade::Upgraded; +use hyper::{Body, Request, Response, Server, StatusCode}; use tokio::io::split; use tokio::net::UdpSocket; use tokio::select; @@ -17,8 +17,8 @@ use tokio::sync::RwLock; use tokio::task::JoinHandle; use tracing::info; -use zia_common::{MAX_DATAGRAM_SIZE, ReadConnection, ReadPool, WriteConnection, WritePool}; use zia_common::ws::{Role, WebSocket}; +use zia_common::{ReadConnection, ReadPool, WriteConnection, WritePool, MAX_DATAGRAM_SIZE}; use crate::cfg::ServerCfg; From bd631d9a40fc58c13677c993d3dea5ee6ec8c864 Mon Sep 17 00:00:00 2001 From: Marcel Date: Sun, 1 Oct 2023 20:25:21 +0200 Subject: [PATCH 04/19] Improved websocket frame parsing, removed dynamic allocations of payload buffers --- zia-client/src/app.rs | 4 +- zia-common/src/lib.rs | 10 +++ zia-common/src/read.rs | 10 ++- zia-common/src/write.rs | 12 +--- zia-common/src/ws/frame.rs | 126 +++++++++++++++++++++++++++++++--- zia-common/src/ws/mod.rs | 31 +-------- zia-common/src/ws/ws.rs | 135 ++++++------------------------------- 7 files changed, 158 insertions(+), 170 deletions(-) diff --git a/zia-client/src/app.rs b/zia-client/src/app.rs index b8eb3c8..7f3d45c 100644 --- a/zia-client/src/app.rs +++ b/zia-client/src/app.rs @@ -9,7 +9,7 @@ use hyper::header::{ use hyper::upgrade::Upgraded; use hyper::{Body, Request}; use once_cell::sync::Lazy; -use tokio::io::split; +use tokio::io::{split, BufStream}; use tokio::net::TcpStream; use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; use tokio_rustls::TlsConnector; @@ -89,6 +89,8 @@ pub(crate) async fn open_connection( } }; + let stream = BufStream::new(stream); + let req = Request::get(upstream.to_string()) .header(HOST, format!("{}:{}", upstream_host, upstream_port)) .header(UPGRADE, "websocket") diff --git a/zia-common/src/lib.rs b/zia-common/src/lib.rs index 411c423..a30bd5f 100644 --- a/zia-common/src/lib.rs +++ b/zia-common/src/lib.rs @@ -8,3 +8,13 @@ mod write; pub mod ws; pub const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize - mem::size_of::(); + +/// Creates and returns a buffer on the heap with enough space to contain any possible +/// UDP datagram. +/// +/// This is put on the heap and in a separate function to avoid the 64k buffer from ending +/// up on the stack and blowing up the size of the futures using it. +#[inline(never)] +pub fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { + Box::new([0u8; MAX_DATAGRAM_SIZE]) +} diff --git a/zia-common/src/read.rs b/zia-common/src/read.rs index b631aad..65a8f64 100644 --- a/zia-common/src/read.rs +++ b/zia-common/src/read.rs @@ -8,6 +8,7 @@ use tokio::select; use tokio::sync::{Mutex, RwLock}; use tokio::task::{JoinError, JoinSet}; use tracing::error; +use crate::datagram_buffer; use crate::ws::{Event, WebSocket}; @@ -24,13 +25,14 @@ impl ReadConnection { &mut self, socket: &UdpSocket, addr: &RwLock>, + buf: &mut [u8], ) -> anyhow::Result<()> { - let event = self.read.recv().await?; + let event = self.read.recv(buf).await?; match event { Event::Data(data) => { let addr = addr.read().await.unwrap(); - socket.send_to(&data, addr).await?; + socket.send_to(data, addr).await?; } Event::Close { .. } => {} } @@ -53,6 +55,7 @@ impl ReadPool { tasks: Mutex::new(JoinSet::new()), } } + async fn wait(&self) -> Option, JoinError>> { let mut set = self.tasks.lock().await; select! { @@ -81,8 +84,9 @@ impl ReadPool { let addr = self.addr.clone(); self.tasks.lock().await.spawn(async move { + let mut buf = datagram_buffer(); loop { - conn.handle_frame(&socket, &addr).await?; + conn.handle_frame(&socket, &addr, buf.as_mut()).await?; } }); } diff --git a/zia-common/src/write.rs b/zia-common/src/write.rs index f555c8a..806f792 100644 --- a/zia-common/src/write.rs +++ b/zia-common/src/write.rs @@ -1,7 +1,7 @@ use std::net::SocketAddr; use std::sync::Arc; -use crate::MAX_DATAGRAM_SIZE; +use crate::{datagram_buffer, MAX_DATAGRAM_SIZE}; use tokio::io::{AsyncWrite, WriteHalf}; use tokio::net::UdpSocket; use tokio::sync::RwLock; @@ -83,13 +83,3 @@ impl WritePool { } } } - -/// Creates and returns a buffer on the heap with enough space to contain any possible -/// UDP datagram. -/// -/// This is put on the heap and in a separate function to avoid the 64k buffer from ending -/// up on the stack and blowing up the size of the futures using it. -#[inline(never)] -fn datagram_buffer() -> Box<[u8; MAX_DATAGRAM_SIZE]> { - Box::new([0u8; MAX_DATAGRAM_SIZE]) -} diff --git a/zia-common/src/ws/frame.rs b/zia-common/src/ws/frame.rs index 2c61f75..57f3bb1 100644 --- a/zia-common/src/ws/frame.rs +++ b/zia-common/src/ws/frame.rs @@ -1,8 +1,36 @@ -use tokio::io::{AsyncWrite, AsyncWriteExt}; +use anyhow::anyhow; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +#[repr(u8)] +#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +pub enum OpCode { + Continuation = 0x0, + Text = 0x1, + Binary = 0x2, + Close = 0x8, + Ping = 0x9, + Pong = 0xA, +} + +impl TryFrom for OpCode { + type Error = anyhow::Error; + + fn try_from(value: u8) -> Result { + match value { + 0x0 => Ok(Self::Continuation), + 0x1 => Ok(Self::Text), + 0x2 => Ok(Self::Binary), + 0x8 => Ok(Self::Close), + 0x9 => Ok(Self::Ping), + 0xA => Ok(Self::Pong), + value => Err(anyhow!("unimplemented opcode: {}", value)), + } + } +} pub struct Frame<'a> { pub fin: bool, - pub opcode: u8, + pub opcode: OpCode, pub data: &'a [u8], } @@ -11,12 +39,67 @@ impl<'a> Frame<'a> { pub fn binary(data: &'a [u8]) -> Self { Self { fin: true, - opcode: 2, + opcode: OpCode::Binary, data, } } - #[inline] + pub async fn read( + read: &mut R, + buf: &'a mut [u8], + max_payload_len: usize, + ) -> anyhow::Result> { + let [b1, b2] = { + let mut header = [0u8; 2]; + read.read_exact(&mut header).await?; + header + }; + + let fin = b1 & 0b1000_0000 != 0; + let rsv = b1 & 0b0111_0000; + let opcode = OpCode::try_from(b1 & 0b0000_1111)?; + + let len = (b2 & 0b0111_1111) as usize; + let masked = b2 & 0b_1000_0000 != 0; + + if rsv != 0 { + return Err(anyhow!("reserve bit must be `0`")); + } + + let len = match opcode { + OpCode::Continuation | OpCode::Text | OpCode::Binary => match len { + 126 => read.read_u16().await? as usize, + 127 => read.read_u64().await? as usize, + len => len, + }, + OpCode::Close | OpCode::Ping | OpCode::Pong => { + if !fin { + return Err(anyhow!("control frame must not be fragmented")); + } + + if len > 125 { + return Err(anyhow!( + "control frame must have a payload length of 125 bytes or less" + )); + } + + len + } + }; + + if len > max_payload_len { + return Err(anyhow!("payload too large")); + } + + read_payload(read, &mut buf[..len], masked).await?; + + Ok(Self { + fin, + opcode, + data: &buf[..len], + }) + } + pub async fn write_without_mask( self, write: &mut W, @@ -27,7 +110,6 @@ impl<'a> Frame<'a> { Ok(()) } - #[inline] pub async fn write_with_mask( self, write: &mut W, @@ -52,19 +134,41 @@ impl<'a> Frame<'a> { mask_bit: u8, ) -> anyhow::Result<()> { write - .write_u8(((self.fin as u8) << 7) | self.opcode) + .write_u8(((self.fin as u8) << 7) | self.opcode as u8) .await?; - if self.data.len() < 126 { - write.write_u8(mask_bit | self.data.len() as u8).await?; - } else if self.data.len() < 65536 { + let len = self.data.len(); + + if len < 126 { + write.write_u8(mask_bit | len as u8).await?; + } else if len < 65536 { write.write_u8(mask_bit | 126).await?; - write.write_u16(self.data.len() as u16).await?; + write.write_u16(len as u16).await?; } else { write.write_u8(mask_bit | 127).await?; - write.write_u64(self.data.len() as u64).await?; + write.write_u64(len as u64).await?; } Ok(()) } } + +async fn read_payload( + read: &mut R, + buf: &mut [u8], + masked: bool, +) -> anyhow::Result<()> { + if masked { + let mut mask = [0u8; 4]; + read.read_exact(&mut mask).await?; + read.read_exact(buf).await?; + // TODO: Use SIMD wherever possible for best performance + for i in 0..buf.len() { + buf[i] ^= mask[i & 3]; + } + } else { + read.read_exact(buf).await?; + } + + Ok(()) +} diff --git a/zia-common/src/ws/mod.rs b/zia-common/src/ws/mod.rs index 3dea435..efb1cba 100644 --- a/zia-common/src/ws/mod.rs +++ b/zia-common/src/ws/mod.rs @@ -9,34 +9,9 @@ pub enum Role { Client { masking: bool }, } -/// It represent the type of data that is being sent over the WebSocket connection. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MessageType { - /// `Text` data is represented as a sequence of Unicode characters encoded using UTF-8 encoding. - Text = 1, - /// `Binary` data can be any sequence of bytes and is typically used for sending non-textual data, such as images, audio files etc... - Binary = 2, -} - -impl MessageType { - /// Returns `true` if message type is text - #[inline] - pub fn is_text(&self) -> bool { - matches!(self, MessageType::Text) - } - - /// Returns `true` if message type is binary - #[inline] - pub fn is_binary(&self) -> bool { - matches!(self, MessageType::Binary) - } -} - -#[derive(Debug)] -pub enum Event { - Data(Box<[u8]>), - - Close { code: CloseCode, reason: Box }, +pub enum Event<'a> { + Data(&'a [u8]), + Close { code: CloseCode, reason: String }, } /// When closing an established connection an endpoint MAY indicate a reason for closure. diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs index 4457f9d..05b53e3 100644 --- a/zia-common/src/ws/ws.rs +++ b/zia-common/src/ws/ws.rs @@ -1,8 +1,7 @@ -use std::io::Result; - use anyhow::anyhow; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use crate::ws::frame::OpCode; use crate::ws::{Event, Frame, Role}; /// WebSocket implementation for both client and server @@ -43,6 +42,8 @@ impl WebSocket { } } + self.io.flush().await?; + Ok(()) } @@ -71,21 +72,19 @@ impl WebSocket { // ------------------------------------------------------------------------ -macro_rules! err { [$msg: expr] => { return Err(anyhow!($msg)) }; } - impl WebSocket where R: Unpin + AsyncRead, { /// reads [Event] from websocket stream. - pub async fn recv(&mut self) -> anyhow::Result { + pub async fn recv<'a>(&mut self, buf: &'a mut [u8]) -> anyhow::Result> { if self.is_closed { return Err(std::io::Error::new( std::io::ErrorKind::NotConnected, "read after close", ))?; } - let event = self.recv_event().await; + let event = self.recv_event(buf).await; if let Ok(Event::Close { .. }) | Err(..) = event { self.is_closed = true; } @@ -115,121 +114,25 @@ where // +---------------------------------------------------------------+ // ``` /// reads [Event] from websocket stream. - pub async fn recv_event(&mut self) -> anyhow::Result { - let mut buf = [0u8; 2]; - self.io.read_exact(&mut buf).await?; - - let [b1, b2] = buf; - - let fin = b1 & 0b1000_0000 != 0; - let rsv = b1 & 0b111_0000; - let opcode = b1 & 0b1111; - let len = (b2 & 0b111_1111) as usize; - - // Defines whether the "Payload data" is masked. If set to 1, a - // masking key is present in masking-key, and this is used to unmask - // the "Payload data" as per [Section 5.3](https://datatracker.ietf.org/doc/html/rfc6455#section-5.3). All frames sent from - // client to server have this bit set to 1. - let is_masked = b2 & 0b_1000_0000 != 0; - - if rsv != 0 { - // MUST be `0` unless an extension is negotiated that defines meanings - // for non-zero values. If a nonzero value is received and none of - // the negotiated extensions defines the meaning of such a nonzero - // value, the receiving endpoint MUST _Fail the WebSocket Connection_. - err!("reserve bit must be `0`"); - } + pub async fn recv_event<'a>(&mut self, buf: &'a mut [u8]) -> anyhow::Result> { + let frame = Frame::read(&mut self.io, buf, self.max_payload_len).await?; - // A client MUST mask all frames that it sends to the server. (Note - // that masking is done whether or not the WebSocket Protocol is running - // over TLS.) The server MUST close the connection upon receiving a - // frame that is not masked. - // - // A server MUST NOT mask any frames that it sends to the client. - if let Role::Server = self.role { - // TODO: disabled, to allow unmasked client frames - // if !is_masked { - // err!("expected masked frame"); - // } - } else if is_masked { - err!("expected unmasked frame"); + if !frame.fin { + return Err(anyhow!("framed messages are not supported")); } - // 3-7 are reserved for further non-control frames. - if opcode >= 8 { - if !fin { - err!("control frame must not be fragmented"); - } - if len > 125 { - err!("control frame must have a payload length of 125 bytes or less"); - } - let msg = self.read_payload(is_masked, len).await?; - match opcode { - 8 => on_close(&msg), - // 9 => Ok(Event::Ping(msg)), - // 10 => Ok(Event::Pong(msg)), - // 11-15 are reserved for further control frames - _ => err!("unknown opcode"), - } - } else { - match (opcode, fin) { - (2, true) => {} - _ => err!("invalid data frame"), - }; - let len = match len { - 126 => self.io.read_u16().await? as usize, - 127 => self.io.read_u64().await? as usize, - len => len, - }; - if len > self.max_payload_len { - err!("payload too large"); - } - let data = self.read_payload(is_masked, len).await?; - Ok(Event::Data(data)) - } - } - - async fn read_payload(&mut self, masked: bool, len: usize) -> Result> { - let mut data = vec![0; len].into_boxed_slice(); - match self.role { - Role::Server => { - if masked { - let mut mask = [0u8; 4]; - self.io.read_exact(&mut mask).await?; - self.io.read_exact(&mut data).await?; - // TODO: Use SIMD wherever possible for best performance - for i in 0..data.len() { - data[i] ^= mask[i & 3]; - } - } else { - self.io.read_exact(&mut data).await?; - } - } - Role::Client { .. } => { - self.io.read_exact(&mut data).await?; - } + match frame.opcode { + OpCode::Continuation => Err(anyhow!("framed messages are not supported")), + OpCode::Text => Err(anyhow!("text framed are not supported")), + OpCode::Binary => Ok(Event::Data(frame.data)), + OpCode::Close => Ok(parse_close_body(frame.data)?), + OpCode::Ping => Err(anyhow!("ping frames are not supported")), + OpCode::Pong => Err(anyhow!("pong framed are not supported")), } - Ok(data) } } -/// - If there is a body, the first two bytes of the body MUST be a 2-byte unsigned integer (in network byte order: Big Endian) -/// representing a status code with value /code/ defined in [Section 7.4](https:///datatracker.ietf.org/doc/html/rfc6455#section-7.4). -/// Following the 2-byte integer, -/// -/// - The application MUST NOT send any more data frames after sending a `Close` frame. -/// -/// - If an endpoint receives a Close frame and did not previously send a -/// Close frame, the endpoint MUST send a Close frame in response. (When -/// sending a Close frame in response, the endpoint typically echos the -/// status code it received.) It SHOULD do so as soon as practical. An -/// endpoint MAY delay sending a Close frame until its current message is -/// sent -/// -/// - After both sending and receiving a Close message, an endpoint -/// considers the WebSocket connection closed and MUST close the -/// underlying TCP connection. -fn on_close(msg: &[u8]) -> anyhow::Result { +fn parse_close_body(msg: &[u8]) -> anyhow::Result { let code = msg .get(..2) .map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]])) @@ -240,7 +143,7 @@ fn on_close(msg: &[u8]) -> anyhow::Result { match msg.get(2..).map(|data| String::from_utf8(data.to_vec())) { Some(Ok(msg)) => Ok(Event::Close { code: code.into(), - reason: msg.into_boxed_str(), + reason: msg, }), None => Ok(Event::Close { code: code.into(), From 6ee4a3c84ff58b65b2dde30b7f7fbac2e96bc0d9 Mon Sep 17 00:00:00 2001 From: Marcel Date: Sun, 1 Oct 2023 20:42:35 +0200 Subject: [PATCH 05/19] fixed typo --- zia-common/src/ws/ws.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs index 05b53e3..baafd32 100644 --- a/zia-common/src/ws/ws.rs +++ b/zia-common/src/ws/ws.rs @@ -123,11 +123,11 @@ where match frame.opcode { OpCode::Continuation => Err(anyhow!("framed messages are not supported")), - OpCode::Text => Err(anyhow!("text framed are not supported")), + OpCode::Text => Err(anyhow!("text frames are not supported")), OpCode::Binary => Ok(Event::Data(frame.data)), OpCode::Close => Ok(parse_close_body(frame.data)?), OpCode::Ping => Err(anyhow!("ping frames are not supported")), - OpCode::Pong => Err(anyhow!("pong framed are not supported")), + OpCode::Pong => Err(anyhow!("pong frames are not supported")), } } } From a3834c6eaf3791625101100b71dc1ad6a861970d Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 2 Oct 2023 09:40:32 +0200 Subject: [PATCH 06/19] added messages, added splitting, basic close handling --- zia-client/src/app.rs | 17 ++-- zia-common/src/read.rs | 10 +-- zia-common/src/write.rs | 6 +- zia-common/src/ws/frame.rs | 45 +++++++--- zia-common/src/ws/mod.rs | 62 ++----------- zia-common/src/ws/ws.rs | 173 +++++++++++++++++++------------------ zia-server/src/main.rs | 6 +- 7 files changed, 144 insertions(+), 175 deletions(-) diff --git a/zia-client/src/app.rs b/zia-client/src/app.rs index 7f3d45c..bc69f63 100644 --- a/zia-client/src/app.rs +++ b/zia-client/src/app.rs @@ -9,7 +9,7 @@ use hyper::header::{ use hyper::upgrade::Upgraded; use hyper::{Body, Request}; use once_cell::sync::Lazy; -use tokio::io::{split, BufStream}; +use tokio::io::BufStream; use tokio::net::TcpStream; use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; use tokio_rustls::TlsConnector; @@ -112,23 +112,16 @@ pub(crate) async fn open_connection( info!("Finished websocket handshake"); - let (read, write) = split(ws.into_inner()); - - let read = WebSocket::new( - read, - MAX_DATAGRAM_SIZE, - Role::Client { - masking: websocket_masking, - }, - ); - let write = WebSocket::new( - write, + let ws = WebSocket::new( + ws.into_inner(), MAX_DATAGRAM_SIZE, Role::Client { masking: websocket_masking, }, ); + let (read, write) = ws.split(); + Ok((ReadConnection::new(read), WriteConnection::new(write))) } diff --git a/zia-common/src/read.rs b/zia-common/src/read.rs index 65a8f64..655e399 100644 --- a/zia-common/src/read.rs +++ b/zia-common/src/read.rs @@ -2,15 +2,15 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; +use crate::datagram_buffer; use tokio::io::{AsyncRead, ReadHalf}; use tokio::net::UdpSocket; use tokio::select; use tokio::sync::{Mutex, RwLock}; use tokio::task::{JoinError, JoinSet}; use tracing::error; -use crate::datagram_buffer; -use crate::ws::{Event, WebSocket}; +use crate::ws::{Message, WebSocket}; pub struct ReadConnection { read: WebSocket>, @@ -30,11 +30,11 @@ impl ReadConnection { let event = self.read.recv(buf).await?; match event { - Event::Data(data) => { + Message::Binary(data) => { let addr = addr.read().await.unwrap(); socket.send_to(data, addr).await?; } - Event::Close { .. } => {} + Message::Close { .. } => {} } Ok(()) @@ -86,7 +86,7 @@ impl ReadPool { self.tasks.lock().await.spawn(async move { let mut buf = datagram_buffer(); loop { - conn.handle_frame(&socket, &addr, buf.as_mut()).await?; + conn.handle_frame(&socket, &addr, buf.as_mut()).await?; } }); } diff --git a/zia-common/src/write.rs b/zia-common/src/write.rs index 806f792..9e9273f 100644 --- a/zia-common/src/write.rs +++ b/zia-common/src/write.rs @@ -8,7 +8,7 @@ use tokio::sync::RwLock; use tracing::error; use crate::pool::Pool; -use crate::ws::{Frame, WebSocket}; +use crate::ws::{Message, WebSocket}; pub struct WriteConnection { write: WebSocket>, @@ -26,8 +26,8 @@ impl WriteConnection { async fn flush(&mut self, size: usize) -> anyhow::Result<()> { assert!(size <= MAX_DATAGRAM_SIZE); - let frame = Frame::binary(&self.buf[..size]); - self.write.send(frame).await?; + let message = Message::Binary(&self.buf[..size]); + self.write.send(message).await?; Ok(()) } diff --git a/zia-common/src/ws/frame.rs b/zia-common/src/ws/frame.rs index 57f3bb1..7c889ed 100644 --- a/zia-common/src/ws/frame.rs +++ b/zia-common/src/ws/frame.rs @@ -3,7 +3,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[repr(u8)] #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] -pub enum OpCode { +pub(crate) enum OpCode { Continuation = 0x0, Text = 0x1, Binary = 0x2, @@ -28,23 +28,42 @@ impl TryFrom for OpCode { } } -pub struct Frame<'a> { - pub fin: bool, - pub opcode: OpCode, - pub data: &'a [u8], +pub(crate) struct Frame<'a> { + pub(crate) fin: bool, + pub(crate) opcode: OpCode, + pub(crate) data: &'a [u8], } impl<'a> Frame<'a> { #[inline] - pub fn binary(data: &'a [u8]) -> Self { - Self { - fin: true, - opcode: OpCode::Binary, - data, - } + pub(crate) fn new(fin: bool, opcode: OpCode, data: &'a [u8]) -> Self { + Self { fin, opcode, data } } - pub async fn read( + /// ### WebSocket Frame Header + /// + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-------+-+-------------+-------------------------------+ + /// |F|R|R|R| opcode|M| Payload len | Extended payload length | + /// |I|S|S|S| (4) |A| (7) | (16/64) | + /// |N|V|V|V| |S| | (if payload len==126/127) | + /// | |1|2|3| |K| | | + /// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + /// | Extended payload length continued, if payload len == 127 | + /// + - - - - - - - - - - - - - - - +-------------------------------+ + /// | |Masking-key, if MASK set to 1 | + /// +-------------------------------+-------------------------------+ + /// | Masking-key (continued) | Payload Data | + /// +-------------------------------- - - - - - - - - - - - - - - - + + /// : Payload Data continued ... : + /// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + /// | Payload Data continued ... | + /// +---------------------------------------------------------------+ + /// ``` + pub(crate) async fn read( read: &mut R, buf: &'a mut [u8], max_payload_len: usize, @@ -100,7 +119,7 @@ impl<'a> Frame<'a> { }) } - pub async fn write_without_mask( + pub(crate) async fn write_without_mask( self, write: &mut W, ) -> anyhow::Result<()> { diff --git a/zia-common/src/ws/mod.rs b/zia-common/src/ws/mod.rs index efb1cba..3dc9863 100644 --- a/zia-common/src/ws/mod.rs +++ b/zia-common/src/ws/mod.rs @@ -1,17 +1,20 @@ -pub use frame::Frame; pub use ws::WebSocket; mod frame; mod ws; +#[derive(Copy, Clone)] pub enum Role { Server, Client { masking: bool }, } -pub enum Event<'a> { - Data(&'a [u8]), - Close { code: CloseCode, reason: String }, +pub enum Message<'a> { + Binary(&'a [u8]), + Close { + code: CloseCode, + reason: Option<&'a str>, + }, } /// When closing an established connection an endpoint MAY indicate a reason for closure. @@ -84,54 +87,3 @@ impl PartialEq for CloseCode { (*self as u16) == *other } } - -/// This trait is responsible for encoding websocket closed frame. -pub trait CloseReason { - /// Encoded close reason as bytes - type Bytes; - /// Encode websocket close frame. - fn to_bytes(self) -> Self::Bytes; -} - -impl CloseReason for () { - type Bytes = [u8; 0]; - fn to_bytes(self) -> Self::Bytes { - [0; 0] - } -} - -impl CloseReason for u16 { - type Bytes = [u8; 2]; - fn to_bytes(self) -> Self::Bytes { - self.to_be_bytes() - } -} - -impl CloseReason for CloseCode { - type Bytes = [u8; 2]; - fn to_bytes(self) -> Self::Bytes { - (self as u16).to_be_bytes() - } -} - -impl CloseReason for &str { - type Bytes = Vec; - fn to_bytes(self) -> Self::Bytes { - CloseReason::to_bytes((CloseCode::Normal, self)) - } -} - -impl CloseReason for (Code, Msg) -where - Code: Into, - Msg: AsRef<[u8]>, -{ - type Bytes = Vec; - fn to_bytes(self) -> Self::Bytes { - let (code, reason) = (self.0.into(), self.1.as_ref()); - let mut data = Vec::with_capacity(2 + reason.len()); - data.extend_from_slice(&code.to_be_bytes()); - data.extend_from_slice(reason); - data - } -} diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs index baafd32..6eff77c 100644 --- a/zia-common/src/ws/ws.rs +++ b/zia-common/src/ws/ws.rs @@ -1,19 +1,17 @@ +use std::sync::Arc; + use anyhow::anyhow; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; +use tokio::sync::RwLock; -use crate::ws::frame::OpCode; -use crate::ws::{Event, Frame, Role}; +use crate::ws::frame::{Frame, OpCode}; +use crate::ws::{CloseCode, Message, Role}; -/// WebSocket implementation for both client and server pub struct WebSocket { - /// it is a low-level abstraction that represents the underlying byte stream over which WebSocket messages are exchanged. - pub io: IO, - - /// Maximum allowed payload length in bytes. - pub max_payload_len: usize, - + io: IO, + max_payload_len: usize, role: Role, - is_closed: bool, + closed: Arc>, } impl WebSocket { @@ -23,13 +21,59 @@ impl WebSocket { io: stream, max_payload_len, role, - is_closed: false, + closed: Arc::new(RwLock::new(false)), } } } +impl WebSocket { + pub fn split(self) -> (WebSocket>, WebSocket>) { + let (read, write) = split(self.io); + ( + WebSocket { + io: read, + max_payload_len: self.max_payload_len, + role: self.role, + closed: self.closed.clone(), + }, + WebSocket { + io: write, + max_payload_len: self.max_payload_len, + role: self.role, + closed: self.closed, + }, + ) + } +} + impl WebSocket { - pub async fn send(&mut self, frame: Frame<'_>) -> anyhow::Result<()> { + pub async fn send(&mut self, message: Message<'_>) -> anyhow::Result<()> { + if *self.closed.read().await { + return Err(anyhow!("connection closed"))?; + } + + let res = match message { + Message::Binary(data) => { + let frame = Frame::new(true, OpCode::Binary, data); + self.send_frame(frame).await + } + Message::Close { code, reason } => { + let buf = encode_close_body(code, reason); + let frame = Frame::new(true, OpCode::Close, &buf); + let res = self.send_frame(frame).await; + *(self.closed.write().await) = true; + res + } + }; + + if res.is_err() { + *(self.closed.write().await) = true; + } + + res + } + + async fn send_frame(&mut self, frame: Frame<'_>) -> anyhow::Result<()> { match self.role { Role::Server => frame.write_without_mask(&mut self.io).await?, Role::Client { masking } => { @@ -47,74 +91,29 @@ impl WebSocket { Ok(()) } - // TODO: implement close - // pub async fn close(mut self, reason: T) -> anyhow::Result<()> - // where - // T: CloseReason, - // T::Bytes: AsRef<[u8]>, - // { - // let frame = Frame { - // fin: true, - // opcode: 8, - // data: reason.to_bytes().as_ref(), - // }; - // - // self.send(frame).await?; - // self.flush().await?; - // Ok(()) - // } - pub async fn flush(&mut self) -> anyhow::Result<()> { self.io.flush().await?; Ok(()) } } -// ------------------------------------------------------------------------ - -impl WebSocket -where - R: Unpin + AsyncRead, -{ - /// reads [Event] from websocket stream. - pub async fn recv<'a>(&mut self, buf: &'a mut [u8]) -> anyhow::Result> { - if self.is_closed { - return Err(std::io::Error::new( - std::io::ErrorKind::NotConnected, - "read after close", - ))?; +impl WebSocket { + pub async fn recv<'a>(&mut self, buf: &'a mut [u8]) -> anyhow::Result> { + if *self.closed.read().await { + return Err(anyhow!("connection closed"))?; } - let event = self.recv_event(buf).await; - if let Ok(Event::Close { .. }) | Err(..) = event { - self.is_closed = true; + + let event = self.recv_message(buf).await; + + // set connection to closed + if let Ok(Message::Close { .. }) | Err(..) = event { + *(self.closed.write().await) = true; } + event } - // ### WebSocket Frame Header - // - // ```txt - // 0 1 2 3 - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-------+-+-------------+-------------------------------+ - // |F|R|R|R| opcode|M| Payload len | Extended payload length | - // |I|S|S|S| (4) |A| (7) | (16/64) | - // |N|V|V|V| |S| | (if payload len==126/127) | - // | |1|2|3| |K| | | - // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - // | Extended payload length continued, if payload len == 127 | - // + - - - - - - - - - - - - - - - +-------------------------------+ - // | |Masking-key, if MASK set to 1 | - // +-------------------------------+-------------------------------+ - // | Masking-key (continued) | Payload Data | - // +-------------------------------- - - - - - - - - - - - - - - - + - // : Payload Data continued ... : - // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - // | Payload Data continued ... | - // +---------------------------------------------------------------+ - // ``` - /// reads [Event] from websocket stream. - pub async fn recv_event<'a>(&mut self, buf: &'a mut [u8]) -> anyhow::Result> { + async fn recv_message<'a>(&mut self, buf: &'a mut [u8]) -> anyhow::Result> { let frame = Frame::read(&mut self.io, buf, self.max_payload_len).await?; if !frame.fin { @@ -124,7 +123,7 @@ where match frame.opcode { OpCode::Continuation => Err(anyhow!("framed messages are not supported")), OpCode::Text => Err(anyhow!("text frames are not supported")), - OpCode::Binary => Ok(Event::Data(frame.data)), + OpCode::Binary => Ok(Message::Binary(frame.data)), OpCode::Close => Ok(parse_close_body(frame.data)?), OpCode::Ping => Err(anyhow!("ping frames are not supported")), OpCode::Pong => Err(anyhow!("pong frames are not supported")), @@ -132,7 +131,20 @@ where } } -fn parse_close_body(msg: &[u8]) -> anyhow::Result { +fn encode_close_body(code: CloseCode, reason: Option<&str>) -> Vec { + if let Some(reason) = reason { + let mut buf = Vec::with_capacity(2 + reason.len()); + buf.copy_from_slice(&(code as u16).to_be_bytes()); + buf.copy_from_slice(reason.as_ref()); + buf + } else { + let mut buf = Vec::with_capacity(2); + buf.copy_from_slice(&(code as u16).to_be_bytes()); + buf + } +} + +fn parse_close_body(msg: &[u8]) -> anyhow::Result { let code = msg .get(..2) .map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]])) @@ -140,17 +152,12 @@ fn parse_close_body(msg: &[u8]) -> anyhow::Result { match code { 1000..=1003 | 1007..=1011 | 1015 | 3000..=3999 | 4000..=4999 => { - match msg.get(2..).map(|data| String::from_utf8(data.to_vec())) { - Some(Ok(msg)) => Ok(Event::Close { - code: code.into(), - reason: msg, - }), - None => Ok(Event::Close { - code: code.into(), - reason: "".into(), - }), - Some(Err(_)) => Err(anyhow!("invalid utf-8 payload")), - } + let msg = msg.get(2..).map(std::str::from_utf8).transpose()?; + + Ok(Message::Close { + code: code.into(), + reason: msg, + }) } _ => Err(anyhow!("invalid close code")), } diff --git a/zia-server/src/main.rs b/zia-server/src/main.rs index 21b5306..d5aa895 100644 --- a/zia-server/src/main.rs +++ b/zia-server/src/main.rs @@ -9,7 +9,6 @@ use clap::Parser; use hyper::service::{make_service_fn, Service}; use hyper::upgrade::Upgraded; use hyper::{Body, Request, Response, Server, StatusCode}; -use tokio::io::split; use tokio::net::UdpSocket; use tokio::select; use tokio::signal::ctrl_c; @@ -53,10 +52,9 @@ impl Future for FutA { tokio::spawn(async move { let ws = upgrade.await.unwrap().into_inner(); - let (read, write) = split(ws); - let read = WebSocket::new(read, MAX_DATAGRAM_SIZE, Role::Server); - let write = WebSocket::new(write, MAX_DATAGRAM_SIZE, Role::Server); + let ws = WebSocket::new(ws, MAX_DATAGRAM_SIZE, Role::Server); + let (read, write) = ws.split(); wread.push(ReadConnection::new(read)).await; wwrite.push(WriteConnection::new(write)).await; From 10ca03a0b1edc01d31878ba6e4167f51c5b9a751 Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 2 Oct 2023 12:43:57 +0200 Subject: [PATCH 07/19] improved errors --- Cargo.lock | 1 + zia-common/Cargo.toml | 1 + zia-common/src/ws/error.rs | 34 ++++++++++++++++++++++++++++++++++ zia-common/src/ws/frame.rs | 27 +++++++++++++-------------- zia-common/src/ws/mod.rs | 13 ++++--------- zia-common/src/ws/ws.rs | 31 +++++++++++++++---------------- 6 files changed, 68 insertions(+), 39 deletions(-) create mode 100644 zia-common/src/ws/error.rs diff --git a/Cargo.lock b/Cargo.lock index ffb52df..2c3a29b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1155,6 +1155,7 @@ dependencies = [ "anyhow", "hyper", "rand", + "thiserror", "tokio", "tracing", ] diff --git a/zia-common/Cargo.toml b/zia-common/Cargo.toml index 4705f60..9871e25 100644 --- a/zia-common/Cargo.toml +++ b/zia-common/Cargo.toml @@ -10,5 +10,6 @@ description = "Proxy udp over websocket, useful to use Wireguard in restricted n tokio = { version = "1.32", default-features = false, features = ["net", "sync"] } hyper = { version = "0.14", default-features = false, features = [] } rand = { version = "0.8", default-features = false } +thiserror = "1.0" tracing = "0.1" anyhow = "1.0" diff --git a/zia-common/src/ws/error.rs b/zia-common/src/ws/error.rs new file mode 100644 index 0000000..048c5a5 --- /dev/null +++ b/zia-common/src/ws/error.rs @@ -0,0 +1,34 @@ +use std::io; +use std::str::Utf8Error; + +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum WebsocketError { + #[error("unknown opcode `{0}`")] + UnknownOpCode(u8), + #[error("reserve bit must be `0`")] + ReserveBitMustBeNull, + #[error("control frame must not be fragmented")] + ControlFrameMustNotBeFragmented, + #[error("control frame must have a payload length of 125 bytes or less")] + ControlFrameMustHaveAPayloadLengthOf125BytesOrLess, + #[error("payload too large")] + PayloadTooLarge, + #[error("io error")] + Io(#[from] io::Error), + #[error("not connected")] + NotConnected, + #[error("framed messages are not supported")] + FramedMessagesAreNotSupported, + #[error("text frames are not supported")] + TextFramesAreNotSupported, + #[error("ping frames are not supported")] + PingFramesAreNotSupported, + #[error("pong frames are not supported")] + PongFramesAreNotSupported, + #[error("invalid utf8")] + InvalidUtf8(#[from] Utf8Error), + #[error("invalid close close `{0}`")] + InvalidCloseCode(u16), +} diff --git a/zia-common/src/ws/frame.rs b/zia-common/src/ws/frame.rs index 7c889ed..de6b619 100644 --- a/zia-common/src/ws/frame.rs +++ b/zia-common/src/ws/frame.rs @@ -1,6 +1,7 @@ -use anyhow::anyhow; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use crate::ws::WebsocketError; + #[repr(u8)] #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] pub(crate) enum OpCode { @@ -13,7 +14,7 @@ pub(crate) enum OpCode { } impl TryFrom for OpCode { - type Error = anyhow::Error; + type Error = WebsocketError; fn try_from(value: u8) -> Result { match value { @@ -23,7 +24,7 @@ impl TryFrom for OpCode { 0x8 => Ok(Self::Close), 0x9 => Ok(Self::Ping), 0xA => Ok(Self::Pong), - value => Err(anyhow!("unimplemented opcode: {}", value)), + value => Err(WebsocketError::UnknownOpCode(value)), } } } @@ -67,7 +68,7 @@ impl<'a> Frame<'a> { read: &mut R, buf: &'a mut [u8], max_payload_len: usize, - ) -> anyhow::Result> { + ) -> Result, WebsocketError> { let [b1, b2] = { let mut header = [0u8; 2]; read.read_exact(&mut header).await?; @@ -82,7 +83,7 @@ impl<'a> Frame<'a> { let masked = b2 & 0b_1000_0000 != 0; if rsv != 0 { - return Err(anyhow!("reserve bit must be `0`")); + return Err(WebsocketError::ReserveBitMustBeNull); } let len = match opcode { @@ -93,13 +94,11 @@ impl<'a> Frame<'a> { }, OpCode::Close | OpCode::Ping | OpCode::Pong => { if !fin { - return Err(anyhow!("control frame must not be fragmented")); + return Err(WebsocketError::ControlFrameMustNotBeFragmented); } if len > 125 { - return Err(anyhow!( - "control frame must have a payload length of 125 bytes or less" - )); + return Err(WebsocketError::ControlFrameMustHaveAPayloadLengthOf125BytesOrLess); } len @@ -107,7 +106,7 @@ impl<'a> Frame<'a> { }; if len > max_payload_len { - return Err(anyhow!("payload too large")); + return Err(WebsocketError::PayloadTooLarge); } read_payload(read, &mut buf[..len], masked).await?; @@ -122,7 +121,7 @@ impl<'a> Frame<'a> { pub(crate) async fn write_without_mask( self, write: &mut W, - ) -> anyhow::Result<()> { + ) -> Result<(), WebsocketError> { self.write_header(write, 0).await?; write.write_all(self.data).await?; @@ -133,7 +132,7 @@ impl<'a> Frame<'a> { self, write: &mut W, mask: [u8; 4], - ) -> anyhow::Result<()> { + ) -> Result<(), WebsocketError> { self.write_header(write, 0x80).await?; write.write_all(&mask).await?; @@ -151,7 +150,7 @@ impl<'a> Frame<'a> { &self, write: &mut W, mask_bit: u8, - ) -> anyhow::Result<()> { + ) -> Result<(), WebsocketError> { write .write_u8(((self.fin as u8) << 7) | self.opcode as u8) .await?; @@ -176,7 +175,7 @@ async fn read_payload( read: &mut R, buf: &mut [u8], masked: bool, -) -> anyhow::Result<()> { +) -> Result<(), WebsocketError> { if masked { let mut mask = [0u8; 4]; read.read_exact(&mut mask).await?; diff --git a/zia-common/src/ws/mod.rs b/zia-common/src/ws/mod.rs index 3dc9863..3eb83e4 100644 --- a/zia-common/src/ws/mod.rs +++ b/zia-common/src/ws/mod.rs @@ -1,5 +1,7 @@ +pub use error::WebsocketError; pub use ws::WebSocket; +mod error; mod frame; mod ws; @@ -51,7 +53,7 @@ pub enum CloseCode { /// MUST NOT be set as a status code in a Close control frame by an endpoint. /// /// The connection was closed due to a failure to perform a TLS handshake. - TLSHandshake = 1015, + TlsHandshake = 1015, } impl From for u16 { @@ -75,15 +77,8 @@ impl From for CloseCode { 1009 => CloseCode::MessageTooBig, 1010 => CloseCode::MandatoryExt, 1011 => CloseCode::InternalError, - 1015 => CloseCode::TLSHandshake, + 1015 => CloseCode::TlsHandshake, _ => CloseCode::PolicyViolation, } } } - -impl PartialEq for CloseCode { - #[inline] - fn eq(&self, other: &u16) -> bool { - (*self as u16) == *other - } -} diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs index 6eff77c..404cdad 100644 --- a/zia-common/src/ws/ws.rs +++ b/zia-common/src/ws/ws.rs @@ -1,11 +1,10 @@ use std::sync::Arc; -use anyhow::anyhow; use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::sync::RwLock; use crate::ws::frame::{Frame, OpCode}; -use crate::ws::{CloseCode, Message, Role}; +use crate::ws::{CloseCode, Message, Role, WebsocketError}; pub struct WebSocket { io: IO, @@ -47,9 +46,9 @@ impl WebSocket { } impl WebSocket { - pub async fn send(&mut self, message: Message<'_>) -> anyhow::Result<()> { + pub async fn send(&mut self, message: Message<'_>) -> Result<(), WebsocketError> { if *self.closed.read().await { - return Err(anyhow!("connection closed"))?; + return Err(WebsocketError::NotConnected)?; } let res = match message { @@ -73,7 +72,7 @@ impl WebSocket { res } - async fn send_frame(&mut self, frame: Frame<'_>) -> anyhow::Result<()> { + async fn send_frame(&mut self, frame: Frame<'_>) -> Result<(), WebsocketError> { match self.role { Role::Server => frame.write_without_mask(&mut self.io).await?, Role::Client { masking } => { @@ -91,16 +90,16 @@ impl WebSocket { Ok(()) } - pub async fn flush(&mut self) -> anyhow::Result<()> { + pub async fn flush(&mut self) -> Result<(), WebsocketError> { self.io.flush().await?; Ok(()) } } impl WebSocket { - pub async fn recv<'a>(&mut self, buf: &'a mut [u8]) -> anyhow::Result> { + pub async fn recv<'a>(&mut self, buf: &'a mut [u8]) -> Result, WebsocketError> { if *self.closed.read().await { - return Err(anyhow!("connection closed"))?; + return Err(WebsocketError::NotConnected)?; } let event = self.recv_message(buf).await; @@ -113,20 +112,20 @@ impl WebSocket { event } - async fn recv_message<'a>(&mut self, buf: &'a mut [u8]) -> anyhow::Result> { + async fn recv_message<'a>(&mut self, buf: &'a mut [u8]) -> Result, WebsocketError> { let frame = Frame::read(&mut self.io, buf, self.max_payload_len).await?; if !frame.fin { - return Err(anyhow!("framed messages are not supported")); + return Err(WebsocketError::FramedMessagesAreNotSupported); } match frame.opcode { - OpCode::Continuation => Err(anyhow!("framed messages are not supported")), - OpCode::Text => Err(anyhow!("text frames are not supported")), + OpCode::Continuation => Err(WebsocketError::FramedMessagesAreNotSupported), + OpCode::Text => Err(WebsocketError::TextFramesAreNotSupported), OpCode::Binary => Ok(Message::Binary(frame.data)), OpCode::Close => Ok(parse_close_body(frame.data)?), - OpCode::Ping => Err(anyhow!("ping frames are not supported")), - OpCode::Pong => Err(anyhow!("pong frames are not supported")), + OpCode::Ping => Err( WebsocketError::PingFramesAreNotSupported), + OpCode::Pong => Err(WebsocketError::PongFramesAreNotSupported), } } } @@ -144,7 +143,7 @@ fn encode_close_body(code: CloseCode, reason: Option<&str>) -> Vec { } } -fn parse_close_body(msg: &[u8]) -> anyhow::Result { +fn parse_close_body(msg: &[u8]) -> Result { let code = msg .get(..2) .map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]])) @@ -159,6 +158,6 @@ fn parse_close_body(msg: &[u8]) -> anyhow::Result { reason: msg, }) } - _ => Err(anyhow!("invalid close code")), + code => Err(WebsocketError::InvalidCloseCode(code)), } } From fd4670018084f53e6c18c92594f21f06be8437bb Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 2 Oct 2023 12:47:44 +0200 Subject: [PATCH 08/19] added missing payload len check --- zia-common/src/ws/ws.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs index 404cdad..b8d0808 100644 --- a/zia-common/src/ws/ws.rs +++ b/zia-common/src/ws/ws.rs @@ -4,6 +4,7 @@ use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf use tokio::sync::RwLock; use crate::ws::frame::{Frame, OpCode}; +use crate::ws::WebsocketError::PayloadTooLarge; use crate::ws::{CloseCode, Message, Role, WebsocketError}; pub struct WebSocket { @@ -73,6 +74,10 @@ impl WebSocket { } async fn send_frame(&mut self, frame: Frame<'_>) -> Result<(), WebsocketError> { + if frame.data.len() > self.max_payload_len { + return Err(PayloadTooLarge); + } + match self.role { Role::Server => frame.write_without_mask(&mut self.io).await?, Role::Client { masking } => { @@ -124,7 +129,7 @@ impl WebSocket { OpCode::Text => Err(WebsocketError::TextFramesAreNotSupported), OpCode::Binary => Ok(Message::Binary(frame.data)), OpCode::Close => Ok(parse_close_body(frame.data)?), - OpCode::Ping => Err( WebsocketError::PingFramesAreNotSupported), + OpCode::Ping => Err(WebsocketError::PingFramesAreNotSupported), OpCode::Pong => Err(WebsocketError::PongFramesAreNotSupported), } } From 3ff1183008fec4476c49cf8f06aa59e781b75fcf Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 2 Oct 2023 14:03:53 +0200 Subject: [PATCH 09/19] added connection cleanup to server --- zia-common/src/pool.rs | 45 ++++++++++++++++++++++++++++++----------- zia-common/src/read.rs | 9 +++++++-- zia-common/src/write.rs | 27 ++++++++++++++++++++++--- zia-common/src/ws/ws.rs | 35 +++++++++++++++++++++----------- 4 files changed, 87 insertions(+), 29 deletions(-) diff --git a/zia-common/src/pool.rs b/zia-common/src/pool.rs index 01bdb7d..3a37911 100644 --- a/zia-common/src/pool.rs +++ b/zia-common/src/pool.rs @@ -1,59 +1,80 @@ use std::ops::{Deref, DerefMut}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; -pub struct PoolGuard { +pub trait PoolEntry { + fn is_closed(&self) -> bool; +} + +pub struct PoolGuard { inner: Option, back: mpsc::UnboundedSender, + pool_size: Arc, } -unsafe impl Send for PoolGuard {} +unsafe impl Send for PoolGuard {} -impl Drop for PoolGuard { +impl Drop for PoolGuard { fn drop(&mut self) { - if let Err(err) = self.back.send(self.inner.take().unwrap()) { - panic!("Could not put PoolGuard back to pool: {:?}", err); + if let Some(inner) = self.inner.take() { + if inner.is_closed() { + self.pool_size.fetch_sub(1, Ordering::Relaxed); + } else if let Err(err) = self.back.send(inner) { + panic!("Could not put PoolGuard back to pool: {:?}", err); + } } } } -impl Deref for PoolGuard { +impl Deref for PoolGuard { type Target = T; fn deref(&self) -> &Self::Target { self.inner.as_ref().unwrap() } } -impl DerefMut for PoolGuard { +impl DerefMut for PoolGuard { fn deref_mut(&mut self) -> &mut Self::Target { self.inner.as_mut().unwrap() } } -pub struct Pool { +pub struct Pool { + size: Arc, tx: mpsc::UnboundedSender, rx: Mutex>, } -impl Pool { +impl Pool { pub fn new() -> Self { let (tx, rx) = mpsc::unbounded_channel(); Self { + size: Arc::new(AtomicUsize::new(0)), tx, rx: Mutex::new(rx), } } - pub async fn acquire(&self) -> PoolGuard { + pub async fn acquire(&self) -> Option> { + if self.size.load(Ordering::Relaxed) == 0 { + return None + } + let inner = self.rx.lock().await.recv().await.unwrap(); - PoolGuard { + Some(PoolGuard { inner: Some(inner), back: self.tx.clone(), - } + pool_size: self.size.clone(), + }) } +} +impl Pool { pub fn push(&self, inner: T) { + self.size.fetch_add(1, Ordering::Relaxed); if let Err(err) = self.tx.send(inner) { panic!("Could not put Inner into to pool: {:?}", err); } diff --git a/zia-common/src/read.rs b/zia-common/src/read.rs index 655e399..6f2fc01 100644 --- a/zia-common/src/read.rs +++ b/zia-common/src/read.rs @@ -2,14 +2,14 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use crate::datagram_buffer; use tokio::io::{AsyncRead, ReadHalf}; use tokio::net::UdpSocket; use tokio::select; use tokio::sync::{Mutex, RwLock}; use tokio::task::{JoinError, JoinSet}; -use tracing::error; +use tracing::{error, warn}; +use crate::datagram_buffer; use crate::ws::{Message, WebSocket}; pub struct ReadConnection { @@ -86,6 +86,11 @@ impl ReadPool { self.tasks.lock().await.spawn(async move { let mut buf = datagram_buffer(); loop { + if conn.read.is_closed() { + warn!("Read connection closed"); + // TODO: open new connection on client + return Ok(()); + } conn.handle_frame(&socket, &addr, buf.as_mut()).await?; } }); diff --git a/zia-common/src/write.rs b/zia-common/src/write.rs index 9e9273f..ae0e244 100644 --- a/zia-common/src/write.rs +++ b/zia-common/src/write.rs @@ -1,13 +1,14 @@ use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use crate::{datagram_buffer, MAX_DATAGRAM_SIZE}; use tokio::io::{AsyncWrite, WriteHalf}; use tokio::net::UdpSocket; use tokio::sync::RwLock; -use tracing::error; +use tracing::{error, warn}; -use crate::pool::Pool; +use crate::pool::{Pool, PoolEntry}; use crate::ws::{Message, WebSocket}; pub struct WriteConnection { @@ -33,6 +34,13 @@ impl WriteConnection { } } +impl PoolEntry for WriteConnection { + fn is_closed(&self) -> bool { + self.write.is_closed() + // TODO: open new connection on client - maybe fancy login in "abstract" pool + } +} + pub struct WritePool { socket: Arc, pool: Pool>, @@ -67,7 +75,20 @@ impl WritePool { pub async fn execute(&self) -> anyhow::Result<()> { loop { - let mut conn = self.pool.acquire().await; + let conn = self.pool.acquire().await; + + let mut conn = match conn { + Some(conn) => conn, + None => { + warn!("Write pool is empty, waiting 1s"); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + }; + + if conn.is_closed() { + continue; + } // read from udp socket and save to buf of selected conn let (read, addr) = self.socket.recv_from(conn.buf.as_mut()).await.unwrap(); diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs index b8d0808..ced1e37 100644 --- a/zia-common/src/ws/ws.rs +++ b/zia-common/src/ws/ws.rs @@ -1,29 +1,37 @@ +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; -use tokio::sync::RwLock; +use tracing::info; use crate::ws::frame::{Frame, OpCode}; -use crate::ws::WebsocketError::PayloadTooLarge; use crate::ws::{CloseCode, Message, Role, WebsocketError}; pub struct WebSocket { io: IO, max_payload_len: usize, role: Role, - closed: Arc>, + closed: Arc, } impl WebSocket { #[inline] - pub fn new(stream: IO, max_payload_len: usize, role: Role) -> Self { + pub fn new(io: IO, max_payload_len: usize, role: Role) -> Self { Self { - io: stream, + io, max_payload_len, role, - closed: Arc::new(RwLock::new(false)), + closed: Arc::new(AtomicBool::new(false)), } } + + pub fn is_closed(&self) -> bool { + self.closed.load(Ordering::Relaxed) + } + + fn set_closed(&self) { + self.closed.store(true, Ordering::Relaxed) + } } impl WebSocket { @@ -48,7 +56,7 @@ impl WebSocket { impl WebSocket { pub async fn send(&mut self, message: Message<'_>) -> Result<(), WebsocketError> { - if *self.closed.read().await { + if self.is_closed() { return Err(WebsocketError::NotConnected)?; } @@ -61,13 +69,15 @@ impl WebSocket { let buf = encode_close_body(code, reason); let frame = Frame::new(true, OpCode::Close, &buf); let res = self.send_frame(frame).await; - *(self.closed.write().await) = true; + self.set_closed(); + info!("Marking write channel as closed"); res } }; if res.is_err() { - *(self.closed.write().await) = true; + self.set_closed(); + info!("Marking write channel as closed"); } res @@ -75,7 +85,7 @@ impl WebSocket { async fn send_frame(&mut self, frame: Frame<'_>) -> Result<(), WebsocketError> { if frame.data.len() > self.max_payload_len { - return Err(PayloadTooLarge); + return Err(WebsocketError::PayloadTooLarge); } match self.role { @@ -103,7 +113,7 @@ impl WebSocket { impl WebSocket { pub async fn recv<'a>(&mut self, buf: &'a mut [u8]) -> Result, WebsocketError> { - if *self.closed.read().await { + if self.is_closed() { return Err(WebsocketError::NotConnected)?; } @@ -111,7 +121,8 @@ impl WebSocket { // set connection to closed if let Ok(Message::Close { .. }) | Err(..) = event { - *(self.closed.write().await) = true; + info!("marking read channel as closed"); + self.set_closed(); } event From c6a534e8591f7485118507a4d18f05ba78364cad Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 2 Oct 2023 14:06:11 +0200 Subject: [PATCH 10/19] added comment --- zia-common/src/write.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/zia-common/src/write.rs b/zia-common/src/write.rs index ae0e244..7fa457d 100644 --- a/zia-common/src/write.rs +++ b/zia-common/src/write.rs @@ -77,6 +77,11 @@ impl WritePool { loop { let conn = self.pool.acquire().await; + // TODO: + // maybe just block until it is not empty anymore + // .revc() in self.pool.acquire() would be blocking + // until a connection becomes available, therefore + // this would be appropriate let mut conn = match conn { Some(conn) => conn, None => { From 284cc5c17310c62e2607742f67fb05c1b3200169 Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 2 Oct 2023 14:07:28 +0200 Subject: [PATCH 11/19] improved comment --- zia-common/src/write.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/zia-common/src/write.rs b/zia-common/src/write.rs index 7fa457d..d98b208 100644 --- a/zia-common/src/write.rs +++ b/zia-common/src/write.rs @@ -82,6 +82,9 @@ impl WritePool { // .revc() in self.pool.acquire() would be blocking // until a connection becomes available, therefore // this would be appropriate + // - better would be an option to enable the blocking + // only on the server and on the client a `None` + // returned would open a new connection let mut conn = match conn { Some(conn) => conn, None => { From aac886ef205798a3586305954ab07095c2e1de77 Mon Sep 17 00:00:00 2001 From: Marcel Date: Tue, 3 Oct 2023 08:04:33 +0200 Subject: [PATCH 12/19] sending close frane, error (not io) occurs while sending --- Cargo.lock | 4 ++-- zia-common/src/pool.rs | 2 +- zia-common/src/ws/ws.rs | 16 ++++++++++++++-- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2c3a29b..0912823 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -425,9 +425,9 @@ checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "memchr" -version = "2.6.3" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "miniz_oxide" diff --git a/zia-common/src/pool.rs b/zia-common/src/pool.rs index 3a37911..19ac6c0 100644 --- a/zia-common/src/pool.rs +++ b/zia-common/src/pool.rs @@ -59,7 +59,7 @@ impl Pool { pub async fn acquire(&self) -> Option> { if self.size.load(Ordering::Relaxed) == 0 { - return None + return None; } let inner = self.rx.lock().await.recv().await.unwrap(); diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs index ced1e37..d364075 100644 --- a/zia-common/src/ws/ws.rs +++ b/zia-common/src/ws/ws.rs @@ -1,8 +1,9 @@ +use std::ops::Deref; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; -use tracing::info; +use tracing::{error, info}; use crate::ws::frame::{Frame, OpCode}; use crate::ws::{CloseCode, Message, Role, WebsocketError}; @@ -75,7 +76,18 @@ impl WebSocket { } }; - if res.is_err() { + // set stream as closed and send close frame, if error wan't a io error + if let Err(err) = &res { + match err { + WebsocketError::Io(_) => {} + _ => { + let buf = encode_close_body(CloseCode::InternalError, None); + let frame = Frame::new(true, OpCode::Close, &buf); + if let Err(err) = self.send_frame(frame).await { + error!("Failed to send close frame: {:?}", err); + } + } + } self.set_closed(); info!("Marking write channel as closed"); } From 3e401716cfd0b65d433eacd1dcffe58a4de8fc84 Mon Sep 17 00:00:00 2001 From: Marcel <34819524+MarcelCoding@users.noreply.github.com> Date: Tue, 3 Oct 2023 08:36:26 +0200 Subject: [PATCH 13/19] Update README.md (#106) * Update README.md * Update README.md --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f25d85b..dd33514 100644 --- a/README.md +++ b/README.md @@ -27,10 +27,10 @@ graph LR | Name | Description | |-----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Websocket | The UDP datagrams are wrapped inside websocket frames. These frames are then transmitted to the server and there unwrapped. | -| TCP | The UDP datagrams are prefixed with a 8 bit length of the datagram and then transmitted to the server in TCP packages. At the server these packages are unwrapped and forwarded to the actual UDP upstream. | +| TCP | The UDP datagrams are prefixed with a 16 bit length of the datagram and then transmitted to the server in TCP packages. At the server these packages are unwrapped and forwarded to the actual UDP upstream. | -The client is capable of doing an TLSv2 or TLSv3 handshake, the server isn't. The client is also able to do a TLS -handshake between the HTTPS proxy and the Server. In a case where an end to end (zia-client <-> zia-server) TLS +The client is capable of doing a TLSv2 or TLSv3 handshake, the server isn't able to handle TLS requests. The client is also able to do a TLS +handshake for a HTTPS proxy. In a case where an end to end (zia-client <-> zia-server) TLS encryption should happen, you have to proxy the traffic for the server using a reverse proxy. ## Client From 6b38bd8bdad446678bf5a683bbe9fd73f9b37109 Mon Sep 17 00:00:00 2001 From: Marcel Date: Tue, 3 Oct 2023 08:44:33 +0200 Subject: [PATCH 14/19] improved errors --- zia-common/src/ws/error.rs | 4 ++-- zia-common/src/ws/ws.rs | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/zia-common/src/ws/error.rs b/zia-common/src/ws/error.rs index 048c5a5..47758fb 100644 --- a/zia-common/src/ws/error.rs +++ b/zia-common/src/ws/error.rs @@ -15,7 +15,7 @@ pub enum WebsocketError { ControlFrameMustHaveAPayloadLengthOf125BytesOrLess, #[error("payload too large")] PayloadTooLarge, - #[error("io error")] + #[error(transparent)] Io(#[from] io::Error), #[error("not connected")] NotConnected, @@ -27,7 +27,7 @@ pub enum WebsocketError { PingFramesAreNotSupported, #[error("pong frames are not supported")] PongFramesAreNotSupported, - #[error("invalid utf8")] + #[error(transparent)] InvalidUtf8(#[from] Utf8Error), #[error("invalid close close `{0}`")] InvalidCloseCode(u16), diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs index d364075..eee7e57 100644 --- a/zia-common/src/ws/ws.rs +++ b/zia-common/src/ws/ws.rs @@ -1,4 +1,3 @@ -use std::ops::Deref; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; From 259840ed25ae7954a764fafc03e3d8073248c93f Mon Sep 17 00:00:00 2001 From: Marcel Date: Tue, 3 Oct 2023 08:53:35 +0200 Subject: [PATCH 15/19] updated readme --- README.md | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index dd33514..1d2b9d6 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,7 @@ graph LR C ---|UDP| D[Wireguard Server] ``` -The benefit is that Websocket uses Http. If you are in a restricted network where you can only access external services, -and you can only use a provided Http proxy you can proxy your Wireguard Udp traffic over Websocket. +The benefit is that Websocket uses Http. If you are in a restricted network where you can only access external services, and you can only use a provided Http proxy you can proxy your Wireguard Udp traffic over Websocket. ```mermaid graph LR @@ -24,14 +23,12 @@ graph LR ## Mode -| Name | Description | -|-----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Websocket | The UDP datagrams are wrapped inside websocket frames. These frames are then transmitted to the server and there unwrapped. | +| Name | Description | +|-----------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Websocket | The UDP datagrams are wrapped inside websocket frames. These frames are then transmitted to the server and there unwrapped. | | TCP | The UDP datagrams are prefixed with a 16 bit length of the datagram and then transmitted to the server in TCP packages. At the server these packages are unwrapped and forwarded to the actual UDP upstream. | -The client is capable of doing a TLSv2 or TLSv3 handshake, the server isn't able to handle TLS requests. The client is also able to do a TLS -handshake for a HTTPS proxy. In a case where an end to end (zia-client <-> zia-server) TLS -encryption should happen, you have to proxy the traffic for the server using a reverse proxy. +The client is capable of doing a TLSv2 or TLSv3 handshake, the server isn't able to handle TLS requests. In a case where an end to end (zia-client <-> zia-server) TLS encryption should happen, you have to proxy the traffic for the server using a reverse proxy. ## Client From 4386006887105c69158e66df7a9f49bc919a65bd Mon Sep 17 00:00:00 2001 From: Marcel Date: Tue, 3 Oct 2023 20:46:07 +0200 Subject: [PATCH 16/19] moved websocket impl into its own lib `wsocket` --- Cargo.lock | 17 +++- zia-client/Cargo.toml | 3 +- zia-client/src/app.rs | 10 +- zia-common/Cargo.toml | 3 +- zia-common/src/lib.rs | 4 +- zia-common/src/read.rs | 2 +- zia-common/src/write.rs | 4 +- zia-common/src/ws/README.md | 1 - zia-common/src/ws/error.rs | 34 ------- zia-common/src/ws/frame.rs | 192 ------------------------------------ zia-common/src/ws/mod.rs | 84 ---------------- zia-common/src/ws/ws.rs | 190 ----------------------------------- zia-server/Cargo.toml | 1 + zia-server/src/main.rs | 4 +- 14 files changed, 28 insertions(+), 521 deletions(-) delete mode 100644 zia-common/src/ws/README.md delete mode 100644 zia-common/src/ws/error.rs delete mode 100644 zia-common/src/ws/frame.rs delete mode 100644 zia-common/src/ws/mod.rs delete mode 100644 zia-common/src/ws/ws.rs diff --git a/Cargo.lock b/Cargo.lock index 0912823..8f6de34 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1129,6 +1129,18 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "wsocket" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9292da6c1a19221e1438774e8718db7163191c845f948b8b16675765396bdc7" +dependencies = [ + "rand", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "zia-client" version = "0.0.0-git" @@ -1145,6 +1157,7 @@ dependencies = [ "tracing-subscriber", "url", "webpki-roots", + "wsocket", "zia-common", ] @@ -1154,10 +1167,9 @@ version = "0.0.0-git" dependencies = [ "anyhow", "hyper", - "rand", - "thiserror", "tokio", "tracing", + "wsocket", ] [[package]] @@ -1174,5 +1186,6 @@ dependencies = [ "tracing", "tracing-subscriber", "webpki-roots", + "wsocket", "zia-common", ] diff --git a/zia-client/Cargo.toml b/zia-client/Cargo.toml index c624531..52deb6d 100644 --- a/zia-client/Cargo.toml +++ b/zia-client/Cargo.toml @@ -13,6 +13,7 @@ async-http-proxy = { version = "1.2", default-features = false, features = ["run hyper = { version = "0.14", default-features = false, features = [] } tracing-subscriber = { version = "0.3", features = ["tracing-log"] } clap = { version = "4.4", features = ["derive", "env"] } +wsocket = { version = "0.1", features = ["client"] } url = { version = "2.4", features = ["serde"] } webpki-roots = "0.25" tokio-rustls = "0.24" @@ -24,5 +25,5 @@ zia-common = { path = '../zia-common' } [package.metadata.generate-rpm] assets = [ - { source = "../LICENSE", dest = "/usr/share/doc/zia-client/LICENSE", doc = true, mode = "0644" }, + { source = "../LICENSE", dest = "/usr/share/doc/zia-client/LICENSE", doc = true, mode = "0644" }, ] diff --git a/zia-client/src/app.rs b/zia-client/src/app.rs index bc69f63..08cacd9 100644 --- a/zia-client/src/app.rs +++ b/zia-client/src/app.rs @@ -15,8 +15,8 @@ use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, Server use tokio_rustls::TlsConnector; use tracing::info; use url::Url; +use wsocket::WebSocket; -use zia_common::ws::{Role, WebSocket}; use zia_common::{ReadConnection, WriteConnection, MAX_DATAGRAM_SIZE}; static TLS_CONNECTOR: Lazy = Lazy::new(|| { @@ -112,13 +112,7 @@ pub(crate) async fn open_connection( info!("Finished websocket handshake"); - let ws = WebSocket::new( - ws.into_inner(), - MAX_DATAGRAM_SIZE, - Role::Client { - masking: websocket_masking, - }, - ); + let ws = WebSocket::client(ws.into_inner(), MAX_DATAGRAM_SIZE, websocket_masking); let (read, write) = ws.split(); diff --git a/zia-common/Cargo.toml b/zia-common/Cargo.toml index 9871e25..435c65a 100644 --- a/zia-common/Cargo.toml +++ b/zia-common/Cargo.toml @@ -9,7 +9,6 @@ description = "Proxy udp over websocket, useful to use Wireguard in restricted n [dependencies] tokio = { version = "1.32", default-features = false, features = ["net", "sync"] } hyper = { version = "0.14", default-features = false, features = [] } -rand = { version = "0.8", default-features = false } -thiserror = "1.0" +wsocket = "0.1" tracing = "0.1" anyhow = "1.0" diff --git a/zia-common/src/lib.rs b/zia-common/src/lib.rs index a30bd5f..c064096 100644 --- a/zia-common/src/lib.rs +++ b/zia-common/src/lib.rs @@ -1,11 +1,11 @@ -pub use read::*; use std::mem; + +pub use read::*; pub use write::*; mod pool; mod read; mod write; -pub mod ws; pub const MAX_DATAGRAM_SIZE: usize = u16::MAX as usize - mem::size_of::(); diff --git a/zia-common/src/read.rs b/zia-common/src/read.rs index 6f2fc01..a40c2cb 100644 --- a/zia-common/src/read.rs +++ b/zia-common/src/read.rs @@ -8,9 +8,9 @@ use tokio::select; use tokio::sync::{Mutex, RwLock}; use tokio::task::{JoinError, JoinSet}; use tracing::{error, warn}; +use wsocket::{Message, WebSocket}; use crate::datagram_buffer; -use crate::ws::{Message, WebSocket}; pub struct ReadConnection { read: WebSocket>, diff --git a/zia-common/src/write.rs b/zia-common/src/write.rs index d98b208..a405af2 100644 --- a/zia-common/src/write.rs +++ b/zia-common/src/write.rs @@ -2,14 +2,14 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use crate::{datagram_buffer, MAX_DATAGRAM_SIZE}; use tokio::io::{AsyncWrite, WriteHalf}; use tokio::net::UdpSocket; use tokio::sync::RwLock; use tracing::{error, warn}; +use wsocket::{Message, WebSocket}; use crate::pool::{Pool, PoolEntry}; -use crate::ws::{Message, WebSocket}; +use crate::{datagram_buffer, MAX_DATAGRAM_SIZE}; pub struct WriteConnection { write: WebSocket>, diff --git a/zia-common/src/ws/README.md b/zia-common/src/ws/README.md deleted file mode 100644 index 79b9bdb..0000000 --- a/zia-common/src/ws/README.md +++ /dev/null @@ -1 +0,0 @@ -based on https://github.com/nurmohammed840/websocket.rs diff --git a/zia-common/src/ws/error.rs b/zia-common/src/ws/error.rs deleted file mode 100644 index 47758fb..0000000 --- a/zia-common/src/ws/error.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::io; -use std::str::Utf8Error; - -use thiserror::Error; - -#[derive(Error, Debug)] -pub enum WebsocketError { - #[error("unknown opcode `{0}`")] - UnknownOpCode(u8), - #[error("reserve bit must be `0`")] - ReserveBitMustBeNull, - #[error("control frame must not be fragmented")] - ControlFrameMustNotBeFragmented, - #[error("control frame must have a payload length of 125 bytes or less")] - ControlFrameMustHaveAPayloadLengthOf125BytesOrLess, - #[error("payload too large")] - PayloadTooLarge, - #[error(transparent)] - Io(#[from] io::Error), - #[error("not connected")] - NotConnected, - #[error("framed messages are not supported")] - FramedMessagesAreNotSupported, - #[error("text frames are not supported")] - TextFramesAreNotSupported, - #[error("ping frames are not supported")] - PingFramesAreNotSupported, - #[error("pong frames are not supported")] - PongFramesAreNotSupported, - #[error(transparent)] - InvalidUtf8(#[from] Utf8Error), - #[error("invalid close close `{0}`")] - InvalidCloseCode(u16), -} diff --git a/zia-common/src/ws/frame.rs b/zia-common/src/ws/frame.rs deleted file mode 100644 index de6b619..0000000 --- a/zia-common/src/ws/frame.rs +++ /dev/null @@ -1,192 +0,0 @@ -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; - -use crate::ws::WebsocketError; - -#[repr(u8)] -#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] -pub(crate) enum OpCode { - Continuation = 0x0, - Text = 0x1, - Binary = 0x2, - Close = 0x8, - Ping = 0x9, - Pong = 0xA, -} - -impl TryFrom for OpCode { - type Error = WebsocketError; - - fn try_from(value: u8) -> Result { - match value { - 0x0 => Ok(Self::Continuation), - 0x1 => Ok(Self::Text), - 0x2 => Ok(Self::Binary), - 0x8 => Ok(Self::Close), - 0x9 => Ok(Self::Ping), - 0xA => Ok(Self::Pong), - value => Err(WebsocketError::UnknownOpCode(value)), - } - } -} - -pub(crate) struct Frame<'a> { - pub(crate) fin: bool, - pub(crate) opcode: OpCode, - pub(crate) data: &'a [u8], -} - -impl<'a> Frame<'a> { - #[inline] - pub(crate) fn new(fin: bool, opcode: OpCode, data: &'a [u8]) -> Self { - Self { fin, opcode, data } - } - - /// ### WebSocket Frame Header - /// - /// - /// ```txt - /// 0 1 2 3 - /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - /// +-+-+-+-+-------+-+-------------+-------------------------------+ - /// |F|R|R|R| opcode|M| Payload len | Extended payload length | - /// |I|S|S|S| (4) |A| (7) | (16/64) | - /// |N|V|V|V| |S| | (if payload len==126/127) | - /// | |1|2|3| |K| | | - /// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - /// | Extended payload length continued, if payload len == 127 | - /// + - - - - - - - - - - - - - - - +-------------------------------+ - /// | |Masking-key, if MASK set to 1 | - /// +-------------------------------+-------------------------------+ - /// | Masking-key (continued) | Payload Data | - /// +-------------------------------- - - - - - - - - - - - - - - - + - /// : Payload Data continued ... : - /// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - /// | Payload Data continued ... | - /// +---------------------------------------------------------------+ - /// ``` - pub(crate) async fn read( - read: &mut R, - buf: &'a mut [u8], - max_payload_len: usize, - ) -> Result, WebsocketError> { - let [b1, b2] = { - let mut header = [0u8; 2]; - read.read_exact(&mut header).await?; - header - }; - - let fin = b1 & 0b1000_0000 != 0; - let rsv = b1 & 0b0111_0000; - let opcode = OpCode::try_from(b1 & 0b0000_1111)?; - - let len = (b2 & 0b0111_1111) as usize; - let masked = b2 & 0b_1000_0000 != 0; - - if rsv != 0 { - return Err(WebsocketError::ReserveBitMustBeNull); - } - - let len = match opcode { - OpCode::Continuation | OpCode::Text | OpCode::Binary => match len { - 126 => read.read_u16().await? as usize, - 127 => read.read_u64().await? as usize, - len => len, - }, - OpCode::Close | OpCode::Ping | OpCode::Pong => { - if !fin { - return Err(WebsocketError::ControlFrameMustNotBeFragmented); - } - - if len > 125 { - return Err(WebsocketError::ControlFrameMustHaveAPayloadLengthOf125BytesOrLess); - } - - len - } - }; - - if len > max_payload_len { - return Err(WebsocketError::PayloadTooLarge); - } - - read_payload(read, &mut buf[..len], masked).await?; - - Ok(Self { - fin, - opcode, - data: &buf[..len], - }) - } - - pub(crate) async fn write_without_mask( - self, - write: &mut W, - ) -> Result<(), WebsocketError> { - self.write_header(write, 0).await?; - write.write_all(self.data).await?; - - Ok(()) - } - - pub async fn write_with_mask( - self, - write: &mut W, - mask: [u8; 4], - ) -> Result<(), WebsocketError> { - self.write_header(write, 0x80).await?; - write.write_all(&mask).await?; - - for i in 0..self.data.len() { - // TODO: Use SIMD wherever possible for best performance - write - .write_u8(unsafe { self.data.get_unchecked(i) ^ mask.get_unchecked(i & 3) }) - .await? - } - - Ok(()) - } - - async fn write_header( - &self, - write: &mut W, - mask_bit: u8, - ) -> Result<(), WebsocketError> { - write - .write_u8(((self.fin as u8) << 7) | self.opcode as u8) - .await?; - - let len = self.data.len(); - - if len < 126 { - write.write_u8(mask_bit | len as u8).await?; - } else if len < 65536 { - write.write_u8(mask_bit | 126).await?; - write.write_u16(len as u16).await?; - } else { - write.write_u8(mask_bit | 127).await?; - write.write_u64(len as u64).await?; - } - - Ok(()) - } -} - -async fn read_payload( - read: &mut R, - buf: &mut [u8], - masked: bool, -) -> Result<(), WebsocketError> { - if masked { - let mut mask = [0u8; 4]; - read.read_exact(&mut mask).await?; - read.read_exact(buf).await?; - // TODO: Use SIMD wherever possible for best performance - for i in 0..buf.len() { - buf[i] ^= mask[i & 3]; - } - } else { - read.read_exact(buf).await?; - } - - Ok(()) -} diff --git a/zia-common/src/ws/mod.rs b/zia-common/src/ws/mod.rs deleted file mode 100644 index 3eb83e4..0000000 --- a/zia-common/src/ws/mod.rs +++ /dev/null @@ -1,84 +0,0 @@ -pub use error::WebsocketError; -pub use ws::WebSocket; - -mod error; -mod frame; -mod ws; - -#[derive(Copy, Clone)] -pub enum Role { - Server, - Client { masking: bool }, -} - -pub enum Message<'a> { - Binary(&'a [u8]), - Close { - code: CloseCode, - reason: Option<&'a str>, - }, -} - -/// When closing an established connection an endpoint MAY indicate a reason for closure. -#[derive(Debug, Clone, Copy)] -pub enum CloseCode { - /// The purpose for which the connection was established has been fulfilled - Normal = 1000, - /// Server going down or a browser having navigated away from a page - Away = 1001, - /// An endpoint is terminating the connection due to a protocol error. - ProtocolError = 1002, - /// It has received a type of data it cannot accept - Unsupported = 1003, - - // reserved 1004 - /// MUST NOT be set as a status code in a Close control frame by an endpoint. - /// - /// No status code was actually present. - NoStatusRcvd = 1005, - /// MUST NOT be set as a status code in a Close control frame by an endpoint. - /// - /// Connection was closed abnormally. - Abnormal = 1006, - /// Application has received data within a message that was not consistent with the type of the message. - InvalidPayload = 1007, - /// This is a generic status code that can be returned when there is no other more suitable status code. - PolicyViolation = 1008, - /// Message that is too big for it to process. - MessageTooBig = 1009, - /// It has expected the server to negotiate one or more extension. - MandatoryExt = 1010, - /// The server has encountered an unexpected condition that prevented it from fulfilling the request. - InternalError = 1011, - /// MUST NOT be set as a status code in a Close control frame by an endpoint. - /// - /// The connection was closed due to a failure to perform a TLS handshake. - TlsHandshake = 1015, -} - -impl From for u16 { - #[inline] - fn from(code: CloseCode) -> Self { - code as u16 - } -} - -impl From for CloseCode { - #[inline] - fn from(value: u16) -> Self { - match value { - 1000 => CloseCode::Normal, - 1001 => CloseCode::Away, - 1002 => CloseCode::ProtocolError, - 1003 => CloseCode::Unsupported, - 1005 => CloseCode::NoStatusRcvd, - 1006 => CloseCode::Abnormal, - 1007 => CloseCode::InvalidPayload, - 1009 => CloseCode::MessageTooBig, - 1010 => CloseCode::MandatoryExt, - 1011 => CloseCode::InternalError, - 1015 => CloseCode::TlsHandshake, - _ => CloseCode::PolicyViolation, - } - } -} diff --git a/zia-common/src/ws/ws.rs b/zia-common/src/ws/ws.rs deleted file mode 100644 index eee7e57..0000000 --- a/zia-common/src/ws/ws.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; - -use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; -use tracing::{error, info}; - -use crate::ws::frame::{Frame, OpCode}; -use crate::ws::{CloseCode, Message, Role, WebsocketError}; - -pub struct WebSocket { - io: IO, - max_payload_len: usize, - role: Role, - closed: Arc, -} - -impl WebSocket { - #[inline] - pub fn new(io: IO, max_payload_len: usize, role: Role) -> Self { - Self { - io, - max_payload_len, - role, - closed: Arc::new(AtomicBool::new(false)), - } - } - - pub fn is_closed(&self) -> bool { - self.closed.load(Ordering::Relaxed) - } - - fn set_closed(&self) { - self.closed.store(true, Ordering::Relaxed) - } -} - -impl WebSocket { - pub fn split(self) -> (WebSocket>, WebSocket>) { - let (read, write) = split(self.io); - ( - WebSocket { - io: read, - max_payload_len: self.max_payload_len, - role: self.role, - closed: self.closed.clone(), - }, - WebSocket { - io: write, - max_payload_len: self.max_payload_len, - role: self.role, - closed: self.closed, - }, - ) - } -} - -impl WebSocket { - pub async fn send(&mut self, message: Message<'_>) -> Result<(), WebsocketError> { - if self.is_closed() { - return Err(WebsocketError::NotConnected)?; - } - - let res = match message { - Message::Binary(data) => { - let frame = Frame::new(true, OpCode::Binary, data); - self.send_frame(frame).await - } - Message::Close { code, reason } => { - let buf = encode_close_body(code, reason); - let frame = Frame::new(true, OpCode::Close, &buf); - let res = self.send_frame(frame).await; - self.set_closed(); - info!("Marking write channel as closed"); - res - } - }; - - // set stream as closed and send close frame, if error wan't a io error - if let Err(err) = &res { - match err { - WebsocketError::Io(_) => {} - _ => { - let buf = encode_close_body(CloseCode::InternalError, None); - let frame = Frame::new(true, OpCode::Close, &buf); - if let Err(err) = self.send_frame(frame).await { - error!("Failed to send close frame: {:?}", err); - } - } - } - self.set_closed(); - info!("Marking write channel as closed"); - } - - res - } - - async fn send_frame(&mut self, frame: Frame<'_>) -> Result<(), WebsocketError> { - if frame.data.len() > self.max_payload_len { - return Err(WebsocketError::PayloadTooLarge); - } - - match self.role { - Role::Server => frame.write_without_mask(&mut self.io).await?, - Role::Client { masking } => { - if masking { - let mask = rand::random::().to_ne_bytes(); - frame.write_with_mask(&mut self.io, mask).await?; - } else { - frame.write_without_mask(&mut self.io).await?; - } - } - } - - self.io.flush().await?; - - Ok(()) - } - - pub async fn flush(&mut self) -> Result<(), WebsocketError> { - self.io.flush().await?; - Ok(()) - } -} - -impl WebSocket { - pub async fn recv<'a>(&mut self, buf: &'a mut [u8]) -> Result, WebsocketError> { - if self.is_closed() { - return Err(WebsocketError::NotConnected)?; - } - - let event = self.recv_message(buf).await; - - // set connection to closed - if let Ok(Message::Close { .. }) | Err(..) = event { - info!("marking read channel as closed"); - self.set_closed(); - } - - event - } - - async fn recv_message<'a>(&mut self, buf: &'a mut [u8]) -> Result, WebsocketError> { - let frame = Frame::read(&mut self.io, buf, self.max_payload_len).await?; - - if !frame.fin { - return Err(WebsocketError::FramedMessagesAreNotSupported); - } - - match frame.opcode { - OpCode::Continuation => Err(WebsocketError::FramedMessagesAreNotSupported), - OpCode::Text => Err(WebsocketError::TextFramesAreNotSupported), - OpCode::Binary => Ok(Message::Binary(frame.data)), - OpCode::Close => Ok(parse_close_body(frame.data)?), - OpCode::Ping => Err(WebsocketError::PingFramesAreNotSupported), - OpCode::Pong => Err(WebsocketError::PongFramesAreNotSupported), - } - } -} - -fn encode_close_body(code: CloseCode, reason: Option<&str>) -> Vec { - if let Some(reason) = reason { - let mut buf = Vec::with_capacity(2 + reason.len()); - buf.copy_from_slice(&(code as u16).to_be_bytes()); - buf.copy_from_slice(reason.as_ref()); - buf - } else { - let mut buf = Vec::with_capacity(2); - buf.copy_from_slice(&(code as u16).to_be_bytes()); - buf - } -} - -fn parse_close_body(msg: &[u8]) -> Result { - let code = msg - .get(..2) - .map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]])) - .unwrap_or(1000); - - match code { - 1000..=1003 | 1007..=1011 | 1015 | 3000..=3999 | 4000..=4999 => { - let msg = msg.get(2..).map(std::str::from_utf8).transpose()?; - - Ok(Message::Close { - code: code.into(), - reason: msg, - }) - } - code => Err(WebsocketError::InvalidCloseCode(code)), - } -} diff --git a/zia-server/Cargo.toml b/zia-server/Cargo.toml index 8b122e3..d143ab9 100644 --- a/zia-server/Cargo.toml +++ b/zia-server/Cargo.toml @@ -15,6 +15,7 @@ clap = { version = "4.4", features = ["derive", "env"] } webpki-roots = "0.25" pin-project = "1.1" once_cell = "1.18" +wsocket = "0.1" tracing = "0.1" anyhow = "1.0" diff --git a/zia-server/src/main.rs b/zia-server/src/main.rs index d5aa895..0e4a18b 100644 --- a/zia-server/src/main.rs +++ b/zia-server/src/main.rs @@ -15,8 +15,8 @@ use tokio::signal::ctrl_c; use tokio::sync::RwLock; use tokio::task::JoinHandle; use tracing::info; +use wsocket::WebSocket; -use zia_common::ws::{Role, WebSocket}; use zia_common::{ReadConnection, ReadPool, WriteConnection, WritePool, MAX_DATAGRAM_SIZE}; use crate::cfg::ServerCfg; @@ -53,7 +53,7 @@ impl Future for FutA { tokio::spawn(async move { let ws = upgrade.await.unwrap().into_inner(); - let ws = WebSocket::new(ws, MAX_DATAGRAM_SIZE, Role::Server); + let ws = WebSocket::server(ws, MAX_DATAGRAM_SIZE); let (read, write) = ws.split(); wread.push(ReadConnection::new(read)).await; From 71c01d4bcfa4066e5a4680bc9bb5c2d02c02980d Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Nov 2023 13:35:30 +0100 Subject: [PATCH 17/19] updated dependencies --- Cargo.lock | 245 ++++++++++++++-------------------------- zia-client/Cargo.toml | 2 +- zia-client/src/main.rs | 1 - zia-common/Cargo.toml | 2 +- zia-common/src/read.rs | 6 +- zia-common/src/write.rs | 2 +- zia-server/Cargo.toml | 2 +- 7 files changed, 89 insertions(+), 171 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8f6de34..f70ea7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,9 +106,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.4" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "block-buffer" @@ -119,12 +119,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "bumpalo" -version = "3.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" - [[package]] name = "bytes" version = "1.5.0" @@ -148,9 +142,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.4.6" +version = "4.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d04704f56c2cde07f43e8e2c154b43f216dc5c92fc98ada720177362f953b956" +checksum = "2275f18819641850fa26c89acc84d465c1bf91ce57bc2748b28c420473352f64" dependencies = [ "clap_builder", "clap_derive", @@ -158,9 +152,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.6" +version = "4.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e231faeaca65ebd1ea3c737966bf858971cd38c3849107aa3ea7de90a804e45" +checksum = "07cdf1b148b25c1e1f7a42225e30a0d99a615cd4637eae7365548dd4529b95bc" dependencies = [ "anstream", "anstyle", @@ -170,9 +164,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.4.2" +version = "4.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0862016ff20d69b84ef8247369fabf5c008a7417002411897d40ee1f4532b873" +checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" dependencies = [ "heck", "proc-macro2", @@ -182,9 +176,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" [[package]] name = "colorchoice" @@ -194,9 +188,9 @@ checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "cpufeatures" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1" +checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0" dependencies = [ "libc", ] @@ -226,7 +220,7 @@ name = "fastwebsockets" version = "0.4.4" source = "git+https://github.com/MarcelCoding/fastwebsockets?branch=split#fa0521b583d88e88a74ac1e0b50957a4d3244c45" dependencies = [ - "base64 0.21.4", + "base64 0.21.5", "hyper", "pin-project", "rand", @@ -253,30 +247,30 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" dependencies = [ "futures-core", ] [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" dependencies = [ "futures-core", "futures-task", @@ -296,9 +290,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "libc", @@ -325,9 +319,9 @@ checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "http" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", @@ -373,7 +367,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.9", + "socket2 0.4.10", "tokio", "tower-service", "tracing", @@ -396,15 +390,6 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" -[[package]] -name = "js-sys" -version = "0.3.64" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" -dependencies = [ - "wasm-bindgen", -] - [[package]] name = "lazy_static" version = "1.4.0" @@ -413,9 +398,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.148" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "log" @@ -440,9 +425,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" dependencies = [ "libc", "wasi", @@ -536,9 +521,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.67" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" dependencies = [ "unicode-ident", ] @@ -584,17 +569,16 @@ dependencies = [ [[package]] name = "ring" -version = "0.16.20" +version = "0.17.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b" dependencies = [ "cc", + "getrandom", "libc", - "once_cell", "spin", "untrusted", - "web-sys", - "winapi", + "windows-sys", ] [[package]] @@ -605,9 +589,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustls" -version = "0.21.7" +version = "0.21.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" +checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9" dependencies = [ "log", "ring", @@ -617,9 +601,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.101.6" +version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c7d5dece342910d9ba34d259310cae3e0154b873b35408b787b59bce53d34fe" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ "ring", "untrusted", @@ -627,9 +611,9 @@ dependencies = [ [[package]] name = "sct" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ "ring", "untrusted", @@ -637,18 +621,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.188" +version = "1.0.192" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.192" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" dependencies = [ "proc-macro2", "quote", @@ -668,9 +652,9 @@ dependencies = [ [[package]] name = "sharded-slab" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1b21f559e07218024e7e9f90f96f601825397de0e25420135f7f952453fed0b" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" dependencies = [ "lazy_static", ] @@ -686,15 +670,15 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "socket2" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" dependencies = [ "libc", "winapi", @@ -702,9 +686,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.4" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", "windows-sys", @@ -712,9 +696,9 @@ dependencies = [ [[package]] name = "spin" -version = "0.5.2" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "strsim" @@ -724,9 +708,9 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" -version = "2.0.37" +version = "2.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" +checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" dependencies = [ "proc-macro2", "quote", @@ -735,18 +719,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.49" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.49" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", @@ -780,9 +764,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.32.0" +version = "1.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17ed6077ed6cd6c74735e21f37eb16dc3935f96878b1fe961074089cc80893f9" +checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" dependencies = [ "backtrace", "bytes", @@ -791,16 +775,16 @@ dependencies = [ "num_cpus", "pin-project-lite", "signal-hook-registry", - "socket2 0.5.4", + "socket2 0.5.5", "tokio-macros", "windows-sys", ] [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", @@ -825,11 +809,10 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "cfg-if", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -837,9 +820,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", @@ -848,9 +831,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", "valuable", @@ -858,20 +841,20 @@ dependencies = [ [[package]] name = "tracing-log" -version = "0.1.3" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" dependencies = [ - "lazy_static", "log", + "once_cell", "tracing-core", ] [[package]] name = "tracing-subscriber" -version = "0.3.17" +version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ "nu-ansi-term", "sharded-slab", @@ -916,9 +899,9 @@ dependencies = [ [[package]] name = "untrusted" -version = "0.7.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" @@ -971,70 +954,6 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" -[[package]] -name = "wasm-bindgen" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" -dependencies = [ - "cfg-if", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" -dependencies = [ - "bumpalo", - "log", - "once_cell", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-backend", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" - -[[package]] -name = "web-sys" -version = "0.3.64" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - [[package]] name = "webpki-roots" version = "0.25.2" @@ -1131,9 +1050,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "wsocket" -version = "0.1.1" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9292da6c1a19221e1438774e8718db7163191c845f948b8b16675765396bdc7" +checksum = "c8405de0c5f7bdd3e25b537e91e3bba56b4a8efd9d5599a369ac4104e739023b" dependencies = [ "rand", "thiserror", diff --git a/zia-client/Cargo.toml b/zia-client/Cargo.toml index 52deb6d..fdec144 100644 --- a/zia-client/Cargo.toml +++ b/zia-client/Cargo.toml @@ -8,7 +8,7 @@ description = "Proxy udp over websocket, useful to use Wireguard in restricted n [dependencies] fastwebsockets = { git = "https://github.com/MarcelCoding/fastwebsockets", branch = "split", default-features = false, features = ["upgrade"] } -tokio = { version = "1.32", default-features = false, features = ["rt-multi-thread", "macros", "net", "sync", "time", "signal"] } +tokio = { version = "1.34", default-features = false, features = ["rt-multi-thread", "macros", "net", "sync", "time", "signal"] } async-http-proxy = { version = "1.2", default-features = false, features = ["runtime-tokio", "basic-auth"] } hyper = { version = "0.14", default-features = false, features = [] } tracing-subscriber = { version = "0.3", features = ["tracing-log"] } diff --git a/zia-client/src/main.rs b/zia-client/src/main.rs index 2b54f5e..953c386 100644 --- a/zia-client/src/main.rs +++ b/zia-client/src/main.rs @@ -81,7 +81,6 @@ async fn listen( for _ in 0..connection_count { let upstream = upstream.clone(); let proxy = proxy.clone(); - let websocket_masking = websocket_masking; conns.spawn(async move { open_connection(&upstream, &proxy, websocket_masking).await }); } diff --git a/zia-common/Cargo.toml b/zia-common/Cargo.toml index 435c65a..8f3d4d3 100644 --- a/zia-common/Cargo.toml +++ b/zia-common/Cargo.toml @@ -7,7 +7,7 @@ license = "AGPL-3.0" description = "Proxy udp over websocket, useful to use Wireguard in restricted networks." [dependencies] -tokio = { version = "1.32", default-features = false, features = ["net", "sync"] } +tokio = { version = "1.34", default-features = false, features = ["net", "sync"] } hyper = { version = "0.14", default-features = false, features = [] } wsocket = "0.1" tracing = "0.1" diff --git a/zia-common/src/read.rs b/zia-common/src/read.rs index a40c2cb..b9e329f 100644 --- a/zia-common/src/read.rs +++ b/zia-common/src/read.rs @@ -27,14 +27,14 @@ impl ReadConnection { addr: &RwLock>, buf: &mut [u8], ) -> anyhow::Result<()> { - let event = self.read.recv(buf).await?; + let message = self.read.recv(buf).await?; - match event { + match message { Message::Binary(data) => { let addr = addr.read().await.unwrap(); socket.send_to(data, addr).await?; } - Message::Close { .. } => {} + _ => unimplemented!(), } Ok(()) diff --git a/zia-common/src/write.rs b/zia-common/src/write.rs index a405af2..ccc2744 100644 --- a/zia-common/src/write.rs +++ b/zia-common/src/write.rs @@ -79,7 +79,7 @@ impl WritePool { // TODO: // maybe just block until it is not empty anymore - // .revc() in self.pool.acquire() would be blocking + // .revc() in self.pool.acquire() would be "blocking" (asynchronous) // until a connection becomes available, therefore // this would be appropriate // - better would be an option to enable the blocking diff --git a/zia-server/Cargo.toml b/zia-server/Cargo.toml index d143ab9..d9f540e 100644 --- a/zia-server/Cargo.toml +++ b/zia-server/Cargo.toml @@ -7,7 +7,7 @@ license = "AGPL-3.0" description = "Proxy udp over websocket, useful to use Wireguard in restricted networks." [dependencies] -tokio = { version = "1.32", default-features = false, features = ["rt-multi-thread", "macros", "net","sync", "time", "signal"] } +tokio = { version = "1.34", default-features = false, features = ["rt-multi-thread", "macros", "net","sync", "time", "signal"] } fastwebsockets = { git = "https://github.com/MarcelCoding/fastwebsockets", branch = "split", default-features = false, features = ["upgrade"] } hyper = { version = "0.14", default-features = false, features = ["tcp"] } tracing-subscriber = { version = "0.3", features = ["tracing-log"] } From b95b40d3a3582a529f097eec952f6402f8b7de5d Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Nov 2023 13:49:54 +0100 Subject: [PATCH 18/19] applied klemens feedback from #107 --- zia-server/src/main.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/zia-server/src/main.rs b/zia-server/src/main.rs index 0e4a18b..9a6771f 100644 --- a/zia-server/src/main.rs +++ b/zia-server/src/main.rs @@ -24,13 +24,13 @@ use crate::cfg::ServerCfg; mod cfg; #[pin_project::pin_project] -struct FutA { +struct HandleRequestFuture { req: Request, read: Arc, write: Arc>, } -impl Future for FutA { +impl Future for HandleRequestFuture { type Output = Result, Infallible>; fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { @@ -65,22 +65,22 @@ impl Future for FutA { } // mod app; -struct Handler { +struct ConnectionHandler { read: Arc, write: Arc>, } -impl Service> for Handler { +impl Service> for ConnectionHandler { type Response = Response; type Error = Infallible; - type Future = FutA; + type Future = HandleRequestFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: Request) -> Self::Future { - FutA { + HandleRequestFuture { req, read: self.read.clone(), write: self.write.clone(), @@ -109,7 +109,7 @@ async fn main() -> anyhow::Result<()> { let read = rp.clone(); let write = wp.clone(); - async move { Ok::<_, Infallible>(Handler { read, write }) } + async move { Ok::<_, Infallible>(ConnectionHandler { read, write }) } }); let server = Server::bind(&config.listen_addr).serve(make_service); From a27d57c74bebd51467c271b9188f07212c0b2cc6 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Nov 2023 13:55:40 +0100 Subject: [PATCH 19/19] applied klemens feedback from #107 --- zia-common/src/read.rs | 4 ++-- zia-server/src/main.rs | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/zia-common/src/read.rs b/zia-common/src/read.rs index b9e329f..8371b08 100644 --- a/zia-common/src/read.rs +++ b/zia-common/src/read.rs @@ -56,7 +56,7 @@ impl ReadPool { } } - async fn wait(&self) -> Option, JoinError>> { + async fn wait_for_connections_to_close(&self) -> Option, JoinError>> { let mut set = self.tasks.lock().await; select! { result = set.join_next() => result, @@ -67,7 +67,7 @@ impl ReadPool { pub async fn join(&self) -> anyhow::Result<()> { // hack loop { - while let Some(result) = self.wait().await { + while let Some(result) = self.wait_for_connections_to_close().await { if let Err(err) = result? { error!("Error while handling websocket frame: {}", err); // TODO: close and remove from write pool diff --git a/zia-server/src/main.rs b/zia-server/src/main.rs index 9a6771f..9b292c6 100644 --- a/zia-server/src/main.rs +++ b/zia-server/src/main.rs @@ -47,8 +47,8 @@ impl Future for HandleRequestFuture { let (resp, upgrade) = fastwebsockets::upgrade::upgrade(this.req).unwrap(); - let wread = this.read.clone(); - let wwrite = this.write.clone(); + let cloned_read = this.read.clone(); + let cloned_write = this.write.clone(); tokio::spawn(async move { let ws = upgrade.await.unwrap().into_inner(); @@ -56,8 +56,8 @@ impl Future for HandleRequestFuture { let ws = WebSocket::server(ws, MAX_DATAGRAM_SIZE); let (read, write) = ws.split(); - wread.push(ReadConnection::new(read)).await; - wwrite.push(WriteConnection::new(write)).await; + cloned_read.push(ReadConnection::new(read)).await; + cloned_write.push(WriteConnection::new(write)).await; }); Poll::Ready(Ok(resp))