diff --git a/CHANGELOG.md b/CHANGELOG.md index 069889fc9..c46f4d709 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,13 @@ +# v4.5.522 +### Client +* Update: Disable "Exclude My Country" when unable to load the IP location database +* Update: Automatically disconnect client after changing exclude my country settings +* Improve: Performance & Speed on the connection when the user has selected exclude my country +* Fix: Freeze network on some situation + +### Server +* Update: Improve performance + # v4.5.520 ### Client * Feature: Allow to select servers by country if the server supports it diff --git a/Pub/PubVersion.json b/Pub/PubVersion.json index 49685a544..271683927 100644 --- a/Pub/PubVersion.json +++ b/Pub/PubVersion.json @@ -1,6 +1,6 @@ { - "Version": "4.5.520", - "BumpTime": "2024-05-29T06:25:48.3640756Z", + "Version": "4.5.522", + "BumpTime": "2024-06-04T09:05:15.4355193Z", "Prerelease": false, "DeprecatedVersion": "4.0.00" } diff --git a/Tests/VpnHood.Test/TestHelper.cs b/Tests/VpnHood.Test/TestHelper.cs index bdb04f21c..2d6ab8c1a 100644 --- a/Tests/VpnHood.Test/TestHelper.cs +++ b/Tests/VpnHood.Test/TestHelper.cs @@ -217,7 +217,7 @@ public static string CreateAccessManagerWorkingDir() return Path.Combine(WorkingPath, $"AccessManager_{Guid.NewGuid()}"); } - public static FileAccessManager CreateFileAccessManager(FileAccessManagerOptions? options = null, string? storagePath = null, + public static FileAccessManager CreateFileAccessManager(FileAccessManagerOptions? options = null, string? storagePath = null, string? serverLocation = null) { storagePath ??= CreateAccessManagerWorkingDir(); @@ -332,7 +332,7 @@ public static async Task CreateClient(Token token, clientId ??= Guid.NewGuid(); clientOptions ??= CreateClientOptions(); if (clientOptions.ConnectTimeout == new ClientOptions().ConnectTimeout) clientOptions.ConnectTimeout = TimeSpan.FromSeconds(3); - clientOptions.PacketCaptureIncludeIpRanges = TestIpAddresses.Select(x => new IpRange(x)).ToArray(); + clientOptions.PacketCaptureIncludeIpRanges = TestIpAddresses.Select(IpRange.FromIpAddress).ToOrderedList(); clientOptions.IncludeLocalNetwork = true; var client = new VpnHoodClient( @@ -362,7 +362,7 @@ public static AppOptions CreateClientAppOptions() { StorageFolderPath = Path.Combine(WorkingPath, "AppData_" + Guid.NewGuid()), SessionTimeout = TimeSpan.FromSeconds(2), - UseIpGroupManager = false, + UseInternalLocationService = false, UseExternalLocationService = false, LogVerbose = LogVerbose }; @@ -435,4 +435,11 @@ internal static void Init() JobRunner.Default.Interval = TimeSpan.FromMilliseconds(200); JobSection.DefaultInterval = TimeSpan.FromMilliseconds(200); } + public static string GetParentDirectory(string path, int level = 1) + { + for (var i = 0; i < level; i++) + path = Path.GetDirectoryName(path) ?? throw new Exception("Invalid path"); + + return path; + } } \ No newline at end of file diff --git a/Tests/VpnHood.Test/TestNullPacketCapture.cs b/Tests/VpnHood.Test/TestNullPacketCapture.cs index 3218b4760..0b5a52218 100644 --- a/Tests/VpnHood.Test/TestNullPacketCapture.cs +++ b/Tests/VpnHood.Test/TestNullPacketCapture.cs @@ -47,7 +47,7 @@ public void SendPacketToInbound(IPPacket ipPacket) // nothing } - public void SendPacketToInbound(IEnumerable packets) + public void SendPacketToInbound(IList packets) { // nothing } @@ -57,7 +57,7 @@ public void SendPacketToOutbound(IPPacket ipPacket) // nothing } - public void SendPacketToOutbound(IEnumerable ipPackets) + public void SendPacketToOutbound(IList ipPackets) { // nothing } diff --git a/Tests/VpnHood.Test/Tests/ClientAppTest.cs b/Tests/VpnHood.Test/Tests/ClientAppTest.cs index 352268e37..25e059a46 100644 --- a/Tests/VpnHood.Test/Tests/ClientAppTest.cs +++ b/Tests/VpnHood.Test/Tests/ClientAppTest.cs @@ -1,6 +1,8 @@ -using System.Net; +using System.IO.Compression; +using System.Net; using System.Net.NetworkInformation; using System.Text; +using System.Text.Json; using EmbedIO; using Microsoft.Extensions.Logging; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -13,8 +15,6 @@ using VpnHood.Common.Net; using VpnHood.Common.Utils; -// ReSharper disable DisposeOnUsingVariable - namespace VpnHood.Test.Tests; [TestClass] @@ -72,25 +72,45 @@ public async Task BuiltIn_AccessKeys_initialization() } } + private static async Task UpdateIp2LocationFile() + { + // update current ipLocation in app project after a week + var solutionFolder = TestHelper.GetParentDirectory(Directory.GetCurrentDirectory(), 5); + var ipLocationFile = Path.Combine(solutionFolder, "VpnHood.Client.App", "Resources", "IpLocations.zip"); + if (File.GetCreationTime(ipLocationFile) >= DateTime.Now - TimeSpan.FromDays(7)) + return; + + // find token + var userSecretFile = Path.Combine(Path.GetDirectoryName(solutionFolder)!, ".user", "credentials.json"); + var document = JsonDocument.Parse(await File.ReadAllTextAsync(userSecretFile)); + var ip2LocationToken = document.RootElement.GetProperty("Ip2LocationToken").GetString(); + + // copy zip to memory + var httpClient = new HttpClient(); + // ReSharper disable once StringLiteralTypo + await using var ipLocationZipNetStream = await httpClient.GetStreamAsync( + $"https://www.ip2location.com/download/?token={ip2LocationToken}&file=DB1LITECSVIPV6"); + using var ipLocationZipStream = new MemoryStream(); + await ipLocationZipNetStream.CopyToAsync(ipLocationZipStream); + ipLocationZipStream.Position = 0; + + // build new ipLocation file + using var ipLocationZipArchive = new ZipArchive(ipLocationZipStream, ZipArchiveMode.Read); + await using var crvStream = ipLocationZipArchive.GetEntry("IP2LOCATION-LITE-DB1.IPV6.CSV")!.Open(); + await IpGroupBuilder.BuildIpGroupArchiveFromIp2Location(crvStream, ipLocationFile); + } [TestMethod] - public async Task Load_country_ip_groups() + public async Task IpLocations_must_be_loaded() { - // ************ - // *** TEST ***: - await using var app1 = TestHelper.CreateClientApp(); - var ipGroups = await app1.GetIpGroups(); - Assert.IsFalse(ipGroups.Any(x => x.IpGroupId == "us"), - "Countries should not be extracted in test due to performance."); - await app1.DisposeAsync(); + await UpdateIp2LocationFile(); - // ************ - // *** TEST ***: var appOptions = TestHelper.CreateClientAppOptions(); - appOptions.UseIpGroupManager = true; - await using var app2 = TestHelper.CreateClientApp(appOptions: appOptions); - var ipGroups2 = await app2.GetIpGroups(); - Assert.IsTrue(ipGroups2.Any(x => x.IpGroupId == "us"), + appOptions.UseInternalLocationService = true; + await using var app = TestHelper.CreateClientApp(appOptions: appOptions); + var ipGroupsManager = await app.GetIpGroupManager(); + var ipGroupIds = await ipGroupsManager.GetIpGroupIds(); + Assert.IsTrue(ipGroupIds.Any(x => x == "us"), "Countries has not been extracted."); } diff --git a/Tests/VpnHood.Test/Tests/DnsConfigurationTest.cs b/Tests/VpnHood.Test/Tests/DnsConfigurationTest.cs index 1ab989e38..db68fd6bb 100644 --- a/Tests/VpnHood.Test/Tests/DnsConfigurationTest.cs +++ b/Tests/VpnHood.Test/Tests/DnsConfigurationTest.cs @@ -86,7 +86,7 @@ public async Task Server_should_not_block_own_dns_servers() await using var client = await TestHelper.CreateClient(token, clientOptions: clientOptions); foreach (var serverDnsServer in serverDnsServers) - Assert.IsNull(server.SessionManager.NetFilter.BlockedIpRanges.FindInSortedRanges(serverDnsServer)); + Assert.IsFalse(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(serverDnsServer)); CollectionAssert.AreEqual(fileAccessManagerOptions.DnsServers, client.DnsServers); } } \ No newline at end of file diff --git a/Tests/VpnHood.Test/Tests/IpNetworkTest.cs b/Tests/VpnHood.Test/Tests/IpNetworkTest.cs index 486e248fc..7a8db9911 100644 --- a/Tests/VpnHood.Test/Tests/IpNetworkTest.cs +++ b/Tests/VpnHood.Test/Tests/IpNetworkTest.cs @@ -28,10 +28,9 @@ public void Invert_Unify_Convert() IpRange.Parse("192.168.10.0 - 192.168.255.255"), IpRange.Parse("127.0.0.0 - 127.255.255.255"), IpRange.Parse("127.0.0.0 - 127.255.255.254") //extra - }; - CollectionAssert.AreEqual(ipRangesSorted, ipRanges.Sort().ToArray()); + }.ToOrderedList(); + CollectionAssert.AreEqual(ipRangesSorted, ipRanges.ToArray()); - var inverted = ipRanges.Invert(); var expected = new[] { IpRange.Parse("0.0.0.0 - 126.255.255.255"), @@ -39,13 +38,13 @@ public void Invert_Unify_Convert() IpRange.Parse("192.169.0.0 - 255.255.255.255"), IpRange.Parse(":: - 99:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF"), IpRange.Parse("AA::01:0000 - FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF") - }; + }.ToOrderedList(); - CollectionAssert.AreEqual(expected, inverted.ToArray()); + CollectionAssert.AreEqual(expected.ToArray(), ipRanges.Invert().ToArray()); // check network - CollectionAssert.AreEqual(expected.ToIpNetworks().ToArray(), ipRanges.ToIpNetworks().Invert().ToArray()); - CollectionAssert.AreEqual(ipRangesSorted, ipRanges.ToIpNetworks().ToIpRanges().Sort().ToArray()); + CollectionAssert.AreEqual(expected.ToIpNetworks().ToArray(), ipRanges.Invert().ToIpNetworks().ToArray()); + CollectionAssert.AreEqual(ipRangesSorted, ipRanges.ToArray()); } [TestMethod] @@ -54,7 +53,7 @@ public void IpNetwork_Unit() var ipNetwork = IpNetwork.Parse("192.168.23.23/32"); var inverted = ipNetwork.Invert().ToArray(); Assert.AreEqual(32, inverted.Length); - CollectionAssert.AreEqual(new[] { ipNetwork }, inverted.Invert(true, false).ToArray()); + CollectionAssert.AreEqual(new[] { ipNetwork }, inverted.ToIpRanges().Invert(true, false).ToIpNetworks().ToArray()); ipNetwork = IpNetwork.AllV4; Assert.AreEqual(0, ipNetwork.Invert().ToArray().Length); @@ -62,7 +61,7 @@ public void IpNetwork_Unit() ipNetwork = IpNetwork.AllV6; Assert.AreEqual(0, ipNetwork.Invert().ToArray().Length); - CollectionAssert.AreEqual(IpNetwork.All, Array.Empty().Invert().ToArray()); + CollectionAssert.AreEqual(IpNetwork.All, Array.Empty().ToIpRanges().Invert().ToIpNetworks().ToArray()); } [TestMethod] @@ -87,14 +86,14 @@ public void IpRange_IsInRange() IpRange.Parse("5.5.5.5-5.5.5.10") }; - ipRanges = ipRanges.Sort().ToArray(); - Assert.IsFalse(ipRanges.IsInSortedRanges(IPAddress.Parse("9.9.9.7"))); - Assert.IsTrue(ipRanges.IsInSortedRanges(IPAddress.Parse("8.8.8.8"))); - Assert.IsTrue(ipRanges.IsInSortedRanges(IPAddress.Parse("9.9.9.9"))); - Assert.IsFalse(ipRanges.IsInSortedRanges(IPAddress.Parse("4.4.4.5"))); - Assert.IsTrue(ipRanges.IsInSortedRanges(IPAddress.Parse("4.4.4.3"))); - Assert.IsTrue(ipRanges.IsInSortedRanges(IPAddress.Parse("FF::F0"))); - Assert.IsFalse(ipRanges.IsInSortedRanges(IPAddress.Parse("AF::F0"))); + var ipRangeOrderedList = ipRanges.ToOrderedList(); + Assert.IsFalse(ipRangeOrderedList.IsInRange(IPAddress.Parse("9.9.9.7"))); + Assert.IsTrue(ipRangeOrderedList.IsInRange(IPAddress.Parse("8.8.8.8"))); + Assert.IsTrue(ipRangeOrderedList.IsInRange(IPAddress.Parse("9.9.9.9"))); + Assert.IsFalse(ipRangeOrderedList.IsInRange(IPAddress.Parse("4.4.4.5"))); + Assert.IsTrue(ipRangeOrderedList.IsInRange(IPAddress.Parse("4.4.4.3"))); + Assert.IsTrue(ipRangeOrderedList.IsInRange(IPAddress.Parse("FF::F0"))); + Assert.IsFalse(ipRangeOrderedList.IsInRange(IPAddress.Parse("AF::F0"))); } [TestMethod] @@ -110,7 +109,7 @@ public void IpRange_Intersect(bool swap) IpRange.Parse("30.30.10.50 - 30.30.10.100"), IpRange.Parse("20.20.10.50 - 20.20.10.55"), IpRange.Parse("20.20.10.60 - 20.20.10.100") - }; + }.ToOrderedList(); var ipRanges2 = new[] { @@ -119,14 +118,13 @@ public void IpRange_Intersect(bool swap) IpRange.Parse("190.190.11.1 - 190.190.11.50"), //ignore IpRange.Parse("30.30.10.70 - 30.30.10.110"), IpRange.Parse("20.20.10.0 - 20.20.10.90") - }; + }.ToOrderedList(); // Expected // AA::FFF5 - AA::FFF6 - var ranges = swap - ? ipRanges2.Intersect(ipRanges1).ToArray() - : ipRanges1.Intersect(ipRanges2).ToArray(); + ? ipRanges2.Intersect(ipRanges1) + : ipRanges1.Intersect(ipRanges2); var i = 0; Assert.AreEqual("20.20.10.50-20.20.10.55", ranges[i++].ToString().ToUpper()); @@ -134,7 +132,7 @@ public void IpRange_Intersect(bool swap) Assert.AreEqual("30.30.10.70-30.30.10.100", ranges[i++].ToString().ToUpper()); Assert.AreEqual("192.168.10.0-192.168.12.12", ranges[i++].ToString().ToUpper()); Assert.AreEqual("AA::FFF5-AA::FFF6", ranges[i++].ToString().ToUpper()); - Assert.AreEqual(i, ranges.Length); + Assert.AreEqual(i, ranges.Count); } } \ No newline at end of file diff --git a/Tests/VpnHood.Test/Tests/ServerNetFilterConfigTest.cs b/Tests/VpnHood.Test/Tests/ServerNetFilterConfigTest.cs index 9f3111983..81baee00a 100644 --- a/Tests/VpnHood.Test/Tests/ServerNetFilterConfigTest.cs +++ b/Tests/VpnHood.Test/Tests/ServerNetFilterConfigTest.cs @@ -22,22 +22,19 @@ public async Task PacketCapture_Include() await using var client = new VpnHoodClient(TestHelper.CreatePacketCapture(), Guid.NewGuid(), token, new ClientOptions { - PacketCaptureIncludeIpRanges = - [ - IpRange.Parse("230.0.0.0-230.0.0.200") - ] + PacketCaptureIncludeIpRanges = new IpRangeOrderedList([IpRange.Parse("230.0.0.0-230.0.0.200")]) }); await client.Connect(); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.0"))); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.10"))); - Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.100"))); - Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.150"))); - Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.200"))); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.220"))); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.0"))); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.10"))); + Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.100"))); + Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.150"))); + Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.200"))); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.220"))); - Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.50"))); + Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(IPAddress.Parse("230.0.0.50"))); } [TestMethod] @@ -54,23 +51,20 @@ public async Task PacketCapture_Exclude() await using var client = new VpnHoodClient(TestHelper.CreatePacketCapture(), Guid.NewGuid(), token, new ClientOptions { - PacketCaptureIncludeIpRanges = - [ - IpRange.Parse("230.0.0.0-230.0.0.200") - ] + PacketCaptureIncludeIpRanges = new IpRangeOrderedList([IpRange.Parse("230.0.0.0-230.0.0.200")]) }); await client.Connect(); - Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.0"))); - Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.10"))); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.100"))); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.150"))); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.200"))); + Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.0"))); + Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.10"))); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.100"))); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.150"))); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.200"))); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.220"))); //block by client - Assert.IsFalse(server.SessionManager.NetFilter.BlockedIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.50"))); - Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.220"))); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.220"))); //block by client + Assert.IsFalse(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(IPAddress.Parse("230.0.0.50"))); + Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(IPAddress.Parse("230.0.0.220"))); } [TestMethod] @@ -89,16 +83,16 @@ public async Task PacketCapture_Include_Exclude_LocalNetwork() await using var client = new VpnHoodClient(TestHelper.CreatePacketCapture(), Guid.NewGuid(), token, new ClientOptions()); await client.Connect(); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("192.168.0.100")), "LocalNetWorks failed"); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.110")), "Excludes failed"); - Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.50")), "Includes failed"); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.240")), "Includes failed"); - Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.254")), "Includes failed"); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("192.168.0.100")), "LocalNetWorks failed"); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.110")), "Excludes failed"); + Assert.IsTrue(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.50")), "Includes failed"); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.240")), "Includes failed"); + Assert.IsFalse(client.PacketCaptureIncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.254")), "Includes failed"); - Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInSortedRanges(IPAddress.Parse("192.168.0.100"))); - Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.110"))); - Assert.IsFalse(server.SessionManager.NetFilter.BlockedIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.50"))); - Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.254"))); + Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(IPAddress.Parse("192.168.0.100"))); + Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(IPAddress.Parse("230.0.0.110"))); + Assert.IsFalse(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(IPAddress.Parse("230.0.0.50"))); + Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(IPAddress.Parse("230.0.0.254"))); } [TestMethod] @@ -116,13 +110,13 @@ public async Task IpRange_Include_Exclude() await using var client = new VpnHoodClient(TestHelper.CreatePacketCapture(), Guid.NewGuid(), token, new ClientOptions()); await client.Connect(); - Assert.IsFalse(client.IncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.110")), "Excludes failed"); - Assert.IsTrue(client.IncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.50")), "Includes failed"); - Assert.IsFalse(client.IncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.240")), "Includes failed"); - Assert.IsFalse(client.IncludeIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.254")), "Includes & Excludes failed"); + Assert.IsFalse(client.IncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.110")), "Excludes failed"); + Assert.IsTrue(client.IncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.50")), "Includes failed"); + Assert.IsFalse(client.IncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.240")), "Includes failed"); + Assert.IsFalse(client.IncludeIpRanges.IsInRange(IPAddress.Parse("230.0.0.254")), "Includes & Excludes failed"); - Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.110"))); - Assert.IsFalse(server.SessionManager.NetFilter.BlockedIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.50"))); - Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInSortedRanges(IPAddress.Parse("230.0.0.254"))); + Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(IPAddress.Parse("230.0.0.110"))); + Assert.IsFalse(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(IPAddress.Parse("230.0.0.50"))); + Assert.IsTrue(server.SessionManager.NetFilter.BlockedIpRanges.IsInRange(IPAddress.Parse("230.0.0.254"))); } } \ No newline at end of file diff --git a/Tests/VpnHood.Test/Tests/TcpDatagramTest.cs b/Tests/VpnHood.Test/Tests/TcpDatagramTest.cs index c6e77fe33..0c083ab75 100644 --- a/Tests/VpnHood.Test/Tests/TcpDatagramTest.cs +++ b/Tests/VpnHood.Test/Tests/TcpDatagramTest.cs @@ -60,7 +60,7 @@ public async Task AutoCloseChannel() // Check sending packet to server // ------ var testPacket = PacketUtil.CreateUdpPacket(IPEndPoint.Parse("1.1.1.1:1"), IPEndPoint.Parse("1.1.1.1:2"), [1, 2, 3]); - await clientTunnel.SendPacket(testPacket); + await clientTunnel.SendPacketsAsync([testPacket], CancellationToken.None); await VhTestUtil.AssertEqualsWait(testPacket.ToString(), () => lastServerReceivedPacket?.ToString()); await VhTestUtil.AssertEqualsWait(0, () => clientTunnel.DatagramChannelCount); await VhTestUtil.AssertEqualsWait(0, () => serverTunnel.DatagramChannelCount); diff --git a/Tests/VpnHood.Test/Tests/TunnelTest.cs b/Tests/VpnHood.Test/Tests/TunnelTest.cs index da15bf88e..ba82148ce 100644 --- a/Tests/VpnHood.Test/Tests/TunnelTest.cs +++ b/Tests/VpnHood.Test/Tests/TunnelTest.cs @@ -194,7 +194,7 @@ public async Task UdpChannel_via_Tunnel() }; // send packet to server through tunnel - await clientTunnel.SendPackets(packets.ToArray()); + await clientTunnel.SendPacketsAsync(packets.ToArray(), CancellationToken.None); await VhTestUtil.AssertEqualsWait(packets.Count, () => serverReceivedPackets.Length); await VhTestUtil.AssertEqualsWait(packets.Count, () => clientReceivedPackets.Length); } diff --git a/VpnHood.Client.App.Android.Common/AndroidAppUiService.cs b/VpnHood.Client.App.Android.Common/AndroidAppUiService.cs index b5ca00b35..6d41f3ddf 100644 --- a/VpnHood.Client.App.Android.Common/AndroidAppUiService.cs +++ b/VpnHood.Client.App.Android.Common/AndroidAppUiService.cs @@ -4,6 +4,7 @@ using VpnHood.Client.Device; using VpnHood.Client.Device.Droid; using VpnHood.Client.Device.Droid.Utils; +using VpnHood.Common.Utils; using Permission = Android.Content.PM.Permission; namespace VpnHood.Client.App.Droid.Common; @@ -23,8 +24,10 @@ public async Task RequestQuickLaunch(IUiContext context, CancellationToken // request for adding tile // result. 0: reject, 1: already granted, 2: granted - var res = await QuickLaunchTileService.RequestAddTile(appUiContext.Activity) -.WaitAsync(cancellationToken); + var res = await QuickLaunchTileService + .RequestAddTile(appUiContext.Activity) + .WaitAsync(cancellationToken) + .VhConfigureAwait(); return res != 0; } @@ -48,7 +51,9 @@ public async Task RequestNotification(IUiContext context, CancellationToke // request for notification _requestPostNotificationsCompletionTask = new TaskCompletionSource(); appUiContext.Activity.RequestPermissions([Manifest.Permission.PostNotifications], RequestPostNotificationId); - var res = await _requestPostNotificationsCompletionTask.Task.WaitAsync(cancellationToken); + var res = await _requestPostNotificationsCompletionTask.Task + .WaitAsync(cancellationToken) + .VhConfigureAwait(); return res == Permission.Granted; } finally diff --git a/VpnHood.Client.App.Android.Common/VpnHood.Client.App.Android.Common.csproj b/VpnHood.Client.App.Android.Common/VpnHood.Client.App.Android.Common.csproj index 0479389fa..0245e394c 100644 --- a/VpnHood.Client.App.Android.Common/VpnHood.Client.App.Android.Common.csproj +++ b/VpnHood.Client.App.Android.Common/VpnHood.Client.App.Android.Common.csproj @@ -21,7 +21,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App.Android.Connect/MainActivity.cs b/VpnHood.Client.App.Android.Connect/MainActivity.cs index 994e113a8..629761ce3 100644 --- a/VpnHood.Client.App.Android.Connect/MainActivity.cs +++ b/VpnHood.Client.App.Android.Connect/MainActivity.cs @@ -2,6 +2,7 @@ using Android.Content.PM; using Android.Service.QuickSettings; using Android.Views; +using Java.Lang; using VpnHood.Client.App.Droid.Common.Activities; using VpnHood.Client.App.Droid.Connect.Properties; @@ -25,6 +26,7 @@ public class MainActivity : AndroidAppMainActivity { protected override AndroidAppMainActivityHandler CreateMainActivityHandler() { + JavaSystem.SetProperty("debug.checkjni", "true"); //todo: remove return new AndroidAppWebViewMainActivityHandler(this, new AndroidMainActivityWebViewOptions { DefaultSpaPort = AssemblyInfo.DefaultSpaPort, diff --git a/VpnHood.Client.App.Android.GooglePlay.Ads/GooglePlayAdService.cs b/VpnHood.Client.App.Android.GooglePlay.Ads/GooglePlayAdService.cs index 40558a2dd..3409e88bd 100644 --- a/VpnHood.Client.App.Android.GooglePlay.Ads/GooglePlayAdService.cs +++ b/VpnHood.Client.App.Android.GooglePlay.Ads/GooglePlayAdService.cs @@ -4,6 +4,7 @@ using VpnHood.Client.Device; using VpnHood.Client.Device.Droid; using VpnHood.Client.Exceptions; +using VpnHood.Common.Utils; using Object = Java.Lang.Object; namespace VpnHood.Client.App.Droid.GooglePlay.Ads; @@ -26,7 +27,7 @@ public async Task LoadRewardedAd(Activity activity, CancellationToke try { if (_rewardedAdLoadCallback != null && _lastLoadRewardedAdTime.AddHours(1) < DateTime.Now) - return await _rewardedAdLoadCallback.Task; + return await _rewardedAdLoadCallback.Task.VhConfigureAwait(); _rewardedAdLoadCallback = new MyRewardedAdLoadCallback(); var adRequest = new AdRequest.Builder().Build(); @@ -34,10 +35,10 @@ public async Task LoadRewardedAd(Activity activity, CancellationToke var cancellationTask = new TaskCompletionSource(); cancellationToken.Register(cancellationTask.SetResult); - await Task.WhenAny(_rewardedAdLoadCallback.Task, cancellationTask.Task); + await Task.WhenAny(_rewardedAdLoadCallback.Task, cancellationTask.Task).VhConfigureAwait(); cancellationToken.ThrowIfCancellationRequested(); - var rewardedAd = await _rewardedAdLoadCallback.Task; + var rewardedAd = await _rewardedAdLoadCallback.Task.VhConfigureAwait(); _lastLoadRewardedAdTime = DateTime.Now; return rewardedAd; } @@ -56,7 +57,7 @@ public async Task ShowAd(IUiContext uiContext, string customData, CancellationTo var activity = appUiContext.Activity; // create ad custom data - var rewardedAd = await LoadRewardedAd(activity, cancellationToken); + var rewardedAd = await LoadRewardedAd(activity, cancellationToken).VhConfigureAwait(); if (activity.IsDestroyed) throw new AdException("MainActivity has been destroyed before showing the ad."); @@ -77,7 +78,7 @@ public async Task ShowAd(IUiContext uiContext, string customData, CancellationTo // wait for earn reward or dismiss var cancellationTask = new TaskCompletionSource(); cancellationToken.Register(cancellationTask.SetResult); - await Task.WhenAny(fullScreenContentCallback.DismissedTask, userEarnedRewardListener.UserEarnedRewardTask, cancellationTask.Task); + await Task.WhenAny(fullScreenContentCallback.DismissedTask, userEarnedRewardListener.UserEarnedRewardTask, cancellationTask.Task).VhConfigureAwait(); cancellationToken.ThrowIfCancellationRequested(); // check task errors diff --git a/VpnHood.Client.App.Android.GooglePlay.Core/GooglePlayAppUpdaterService.cs b/VpnHood.Client.App.Android.GooglePlay.Core/GooglePlayAppUpdaterService.cs index 2e70b6ecc..5d7b6927d 100644 --- a/VpnHood.Client.App.Android.GooglePlay.Core/GooglePlayAppUpdaterService.cs +++ b/VpnHood.Client.App.Android.GooglePlay.Core/GooglePlayAppUpdaterService.cs @@ -4,6 +4,7 @@ using VpnHood.Client.Device; using VpnHood.Client.Device.Droid; using VpnHood.Common.Logging; +using VpnHood.Common.Utils; using Xamarin.Google.Android.Play.Core.AppUpdate; using Xamarin.Google.Android.Play.Core.Install; using Xamarin.Google.Android.Play.Core.Install.Model; @@ -20,7 +21,7 @@ public async Task Update(IUiContext uiContext) using var appUpdateManager = AppUpdateManagerFactory.Create(appUiContext.Activity); try { - var appUpdateInfo = await new GooglePlayTaskCompleteListener(appUpdateManager.AppUpdateInfo).Task; + var appUpdateInfo = await new GooglePlayTaskCompleteListener(appUpdateManager.AppUpdateInfo).Task.VhConfigureAwait(); var updateAvailability = appUpdateInfo.UpdateAvailability(); // play set UpdateAvailability.UpdateNotAvailable even when there is no connection to google @@ -33,16 +34,16 @@ public async Task Update(IUiContext uiContext) // Show Google Play update dialog var updateFlowPlayTask = appUpdateManager.StartUpdateFlow(appUpdateInfo, appUiContext.Activity, AppUpdateOptions.NewBuilder(AppUpdateType.Flexible).Build()); - var updateFlowResult = await new GooglePlayTaskCompleteListener(updateFlowPlayTask).Task; + var updateFlowResult = await new GooglePlayTaskCompleteListener(updateFlowPlayTask).Task.VhConfigureAwait(); if (updateFlowResult.IntValue() != -1) throw new Exception("Could not start update flow."); // Wait for download complete - await googlePlayDownloadStateListener.WaitForCompletion(); + await googlePlayDownloadStateListener.WaitForCompletion().VhConfigureAwait(); // Start install downloaded update var installUpdateTask = appUpdateManager.CompleteUpdate(); - var installUpdateStatus = await new GooglePlayTaskCompleteListener(installUpdateTask).Task; + var installUpdateStatus = await new GooglePlayTaskCompleteListener(installUpdateTask).Task.VhConfigureAwait(); // Could not start install if (installUpdateStatus.IntValue() != -1) diff --git a/VpnHood.Client.App.Android.GooglePlay.Core/VpnHood.Client.App.Android.GooglePlay.Core.csproj b/VpnHood.Client.App.Android.GooglePlay.Core/VpnHood.Client.App.Android.GooglePlay.Core.csproj index 95221c434..84e567f8a 100644 --- a/VpnHood.Client.App.Android.GooglePlay.Core/VpnHood.Client.App.Android.GooglePlay.Core.csproj +++ b/VpnHood.Client.App.Android.GooglePlay.Core/VpnHood.Client.App.Android.GooglePlay.Core.csproj @@ -21,7 +21,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App.Android.GooglePlay/GooglePlayAuthenticationService.cs b/VpnHood.Client.App.Android.GooglePlay/GooglePlayAuthenticationService.cs index b682785d0..f8e021a93 100644 --- a/VpnHood.Client.App.Android.GooglePlay/GooglePlayAuthenticationService.cs +++ b/VpnHood.Client.App.Android.GooglePlay/GooglePlayAuthenticationService.cs @@ -6,6 +6,7 @@ using VpnHood.Client.Device; using VpnHood.Client.Device.Droid; using VpnHood.Client.Device.Droid.Utils; +using VpnHood.Common.Utils; namespace VpnHood.Client.App.Droid.GooglePlay; @@ -24,7 +25,7 @@ public async Task SilentSignIn(IUiContext uiContext) var appUiContext = (AndroidUiContext)uiContext; using var googleSignInClient = GoogleSignIn.GetClient(appUiContext.Activity, _googleSignInOptions); - var account = await googleSignInClient.SilentSignInAsync(); + var account = await googleSignInClient.SilentSignInAsync().VhConfigureAwait(); return account?.IdToken ?? throw new AuthenticationException("Could not perform SilentSignIn by Google."); } @@ -39,7 +40,7 @@ public async Task SignIn(IUiContext uiContext) _taskCompletionSource = new TaskCompletionSource(); appUiContext.ActivityEvent.ActivityResultEvent += Activity_OnActivityResult; appUiContext.ActivityEvent.Activity.StartActivityForResult(googleSignInClient.SignInIntent, SignInIntentId); - var account = await _taskCompletionSource.Task; + var account = await _taskCompletionSource.Task.VhConfigureAwait(); if (account.IdToken == null) throw new ArgumentNullException(account.IdToken); @@ -63,7 +64,7 @@ public async Task SignOut(IUiContext uiContext) { var appUiContext = (AndroidUiContext)uiContext; using var googleSignInClient = GoogleSignIn.GetClient(appUiContext.Activity, _googleSignInOptions); - await googleSignInClient.SignOutAsync(); + await googleSignInClient.SignOutAsync().VhConfigureAwait(); } private void Activity_OnActivityResult(object? sender, ActivityResultEventArgs e) @@ -77,7 +78,7 @@ private async Task ProcessSignedInAccountFromIntent(Intent? intent) { try { - var googleSignInAccount = await GoogleSignIn.GetSignedInAccountFromIntentAsync(intent); + var googleSignInAccount = await GoogleSignIn.GetSignedInAccountFromIntentAsync(intent).VhConfigureAwait(); _taskCompletionSource?.SetResult(googleSignInAccount); } catch (Exception e) diff --git a/VpnHood.Client.App.Android.GooglePlay/GooglePlayBillingService.cs b/VpnHood.Client.App.Android.GooglePlay/GooglePlayBillingService.cs index afda8b78b..9e7b1d0f3 100644 --- a/VpnHood.Client.App.Android.GooglePlay/GooglePlayBillingService.cs +++ b/VpnHood.Client.App.Android.GooglePlay/GooglePlayBillingService.cs @@ -5,6 +5,7 @@ using VpnHood.Client.Device; using VpnHood.Client.Device.Droid; using VpnHood.Common.Logging; +using VpnHood.Common.Utils; namespace VpnHood.Client.App.Droid.GooglePlay; @@ -48,7 +49,7 @@ private void PurchasesUpdatedListener(BillingResult billingResult, IList GetSubscriptionPlans() { - await EnsureConnected(); + await EnsureConnected().VhConfigureAwait(); // Check if the purchase subscription is supported on the user's device try @@ -76,7 +77,7 @@ public async Task GetSubscriptionPlans() // Get products list from GooglePlay. try { - var response = await _billingClient.QueryProductDetailsAsync(productDetailsParams); + var response = await _billingClient.QueryProductDetailsAsync(productDetailsParams).VhConfigureAwait(); if (response.Result.ResponseCode != BillingResponseCode.Ok) throw new Exception($"Could not get products from google play. BillingResponseCode: {response.Result.ResponseCode}"); if (!response.ProductDetails.Any()) throw new Exception($"Product list is empty. ProductList: {response.ProductDetails}"); @@ -107,7 +108,7 @@ public async Task GetSubscriptionPlans() public async Task Purchase(IUiContext uiContext, string planId) { var appUiContext = (AndroidUiContext)uiContext; - await EnsureConnected(); + await EnsureConnected().VhConfigureAwait(); if (_authenticationService.UserId == null) throw new AuthenticationException(); @@ -138,7 +139,7 @@ public async Task Purchase(IUiContext uiContext, string planId) throw CreateBillingResultException(billingResult); _taskCompletionSource = new TaskCompletionSource(); - var orderId = await _taskCompletionSource.Task; + var orderId = await _taskCompletionSource.Task.VhConfigureAwait(); return orderId; } catch (TaskCanceledException ex) @@ -165,7 +166,7 @@ private async Task EnsureConnected() try { - var billingResult = await _billingClient.StartConnectionAsync(); + var billingResult = await _billingClient.StartConnectionAsync().VhConfigureAwait(); if (billingResult.ResponseCode != BillingResponseCode.Ok) throw new Exception(billingResult.DebugMessage); diff --git a/VpnHood.Client.App.Android.GooglePlay/VpnHood.Client.App.Android.GooglePlay.csproj b/VpnHood.Client.App.Android.GooglePlay/VpnHood.Client.App.Android.GooglePlay.csproj index 86024fe1d..5e4094bb7 100644 --- a/VpnHood.Client.App.Android.GooglePlay/VpnHood.Client.App.Android.GooglePlay.csproj +++ b/VpnHood.Client.App.Android.GooglePlay/VpnHood.Client.App.Android.GooglePlay.csproj @@ -21,7 +21,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App.Android/MainActivity.cs b/VpnHood.Client.App.Android/MainActivity.cs index 37f7cb8d9..a759da81a 100644 --- a/VpnHood.Client.App.Android/MainActivity.cs +++ b/VpnHood.Client.App.Android/MainActivity.cs @@ -2,6 +2,7 @@ using Android.Content.PM; using Android.Service.QuickSettings; using Android.Views; +using Java.Lang; using VpnHood.Client.App.Droid.Common.Activities; using VpnHood.Client.App.Droid.Properties; @@ -35,6 +36,7 @@ public class MainActivity : AndroidAppMainActivity protected override AndroidAppMainActivityHandler CreateMainActivityHandler() { + JavaSystem.SetProperty("debug.checkjni", "true"); //todo: remove return new AndroidAppWebViewMainActivityHandler(this, new AndroidMainActivityWebViewOptions { DefaultSpaPort = AssemblyInfo.DefaultSpaPort, diff --git a/VpnHood.Client.App.Android/VpnHood.Client.App.Android.csproj b/VpnHood.Client.App.Android/VpnHood.Client.App.Android.csproj index 2ba105366..b1dfa602a 100644 --- a/VpnHood.Client.App.Android/VpnHood.Client.App.Android.csproj +++ b/VpnHood.Client.App.Android/VpnHood.Client.App.Android.csproj @@ -5,8 +5,8 @@ VpnHood.Client.App.Droid Exe com.vpnhood.client.android.web - 520 - 4.5.520 + 522 + 4.5.522 23.0 @@ -30,7 +30,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App.Maui.Common/VpnHood.Client.App.Maui.Common.csproj b/VpnHood.Client.App.Maui.Common/VpnHood.Client.App.Maui.Common.csproj index 1b804243e..87b23d818 100644 --- a/VpnHood.Client.App.Maui.Common/VpnHood.Client.App.Maui.Common.csproj +++ b/VpnHood.Client.App.Maui.Common/VpnHood.Client.App.Maui.Common.csproj @@ -30,7 +30,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App.Resources/VpnHood.Client.App.Resources.csproj b/VpnHood.Client.App.Resources/VpnHood.Client.App.Resources.csproj index 0b9e57749..da6914cbf 100644 --- a/VpnHood.Client.App.Resources/VpnHood.Client.App.Resources.csproj +++ b/VpnHood.Client.App.Resources/VpnHood.Client.App.Resources.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App.Store/StoreAccountService.cs b/VpnHood.Client.App.Store/StoreAccountService.cs index 098b28ea9..67dacc065 100644 --- a/VpnHood.Client.App.Store/StoreAccountService.cs +++ b/VpnHood.Client.App.Store/StoreAccountService.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.Logging; using VpnHood.Client.App.Abstractions; using VpnHood.Common.Logging; +using VpnHood.Common.Utils; using VpnHood.Store.Api; namespace VpnHood.Client.App.Store; @@ -27,10 +28,10 @@ public StoreAccountService(IAppAuthenticationService authenticationService, var httpClient = Authentication.HttpClient; var authenticationClient = new AuthenticationClient(httpClient); - var currentUser = await authenticationClient.GetCurrentUserAsync(); + var currentUser = await authenticationClient.GetCurrentUserAsync().VhConfigureAwait(); var currentVpnUserClient = new CurrentVpnUserClient(httpClient); - var activeSubscription = await currentVpnUserClient.ListSubscriptionsAsync(_storeAppId, false, false); + var activeSubscription = await currentVpnUserClient.ListSubscriptionsAsync(_storeAppId, false, false).VhConfigureAwait(); var subscriptionLastOrder = activeSubscription.SingleOrDefault()?.LastOrder; var appAccount = new AppAccount @@ -55,7 +56,7 @@ public async Task WaitForProcessProviderOrder(string providerOrderId) { try { - var subscriptionOrder = await currentVpnUserClient.GetSubscriptionOrderByProviderOrderIdAsync(_storeAppId, providerOrderId); + var subscriptionOrder = await currentVpnUserClient.GetSubscriptionOrderByProviderOrderIdAsync(_storeAppId, providerOrderId).VhConfigureAwait(); if (subscriptionOrder.IsProcessed) return; throw new Exception("Order has not processed yet."); @@ -65,7 +66,7 @@ public async Task WaitForProcessProviderOrder(string providerOrderId) // We might encounter a ‘not exist’ exception. Therefore, we need to wait for Google to send the provider order to the Store. VhLogger.Instance.LogWarning(ex, ex.Message); if (counter == 5) throw; - await Task.Delay(TimeSpan.FromSeconds(5)); + await Task.Delay(TimeSpan.FromSeconds(5)).VhConfigureAwait(); } } } @@ -76,12 +77,12 @@ public async Task GetAccessKeys(string subscriptionId) var currentVpnUserClient = new CurrentVpnUserClient(httpClient); // todo: add includeAccessKey parameter and return accessKey in accessToken - var accessTokens = await currentVpnUserClient.ListAccessTokensAsync(_storeAppId, subscriptionId: Guid.Parse(subscriptionId)); + var accessTokens = await currentVpnUserClient.ListAccessTokensAsync(_storeAppId, subscriptionId: Guid.Parse(subscriptionId)).VhConfigureAwait(); var accessKeyList = new List(); foreach (var accessToken in accessTokens) { - var accessKey = await currentVpnUserClient.GetAccessKeyAsync(_storeAppId, accessToken.AccessTokenId); + var accessKey = await currentVpnUserClient.GetAccessKeyAsync(_storeAppId, accessToken.AccessTokenId).VhConfigureAwait(); accessKeyList.Add(accessKey); } diff --git a/VpnHood.Client.App.Store/StoreAuthenticationService.cs b/VpnHood.Client.App.Store/StoreAuthenticationService.cs index d165cefcb..58ed705cf 100644 --- a/VpnHood.Client.App.Store/StoreAuthenticationService.cs +++ b/VpnHood.Client.App.Store/StoreAuthenticationService.cs @@ -80,7 +80,7 @@ private ApiKey? ApiKey if (ApiKey.RefreshToken != null && ApiKey.RefreshToken.ExpirationTime < DateTime.UtcNow) { var authenticationClient = new AuthenticationClient(_httpClientWithoutAuth); - ApiKey = await authenticationClient.RefreshTokenAsync(new RefreshTokenRequest { RefreshToken = ApiKey.RefreshToken.Value }); + ApiKey = await authenticationClient.RefreshTokenAsync(new RefreshTokenRequest { RefreshToken = ApiKey.RefreshToken.Value }).VhConfigureAwait(); return ApiKey; } } @@ -95,11 +95,11 @@ private ApiKey? ApiKey if (uiContext == null) throw new Exception("UI context is not available."); - var idToken = _externalAuthenticationService != null ? await _externalAuthenticationService.SilentSignIn(uiContext) : null; + var idToken = _externalAuthenticationService != null ? await _externalAuthenticationService.SilentSignIn(uiContext).VhConfigureAwait() : null; if (!string.IsNullOrWhiteSpace(idToken)) { var authenticationClient = new AuthenticationClient(_httpClientWithoutAuth); - ApiKey = await authenticationClient.SignInAsync(new SignInRequest { IdToken = idToken }); + ApiKey = await authenticationClient.SignInAsync(new SignInRequest { IdToken = idToken }).VhConfigureAwait(); return ApiKey; } } @@ -116,8 +116,8 @@ public async Task SignInWithGoogle(IUiContext uiContext) if (_externalAuthenticationService == null) throw new InvalidOperationException("Google sign in is not supported."); - var idToken = await _externalAuthenticationService.SignIn(uiContext); - await SignInToVpnHoodStore(idToken, true); + var idToken = await _externalAuthenticationService.SignIn(uiContext).VhConfigureAwait(); + await SignInToVpnHoodStore(idToken, true).VhConfigureAwait(); } public async Task SignOut(IUiContext uiContext) @@ -128,7 +128,7 @@ public async Task SignOut(IUiContext uiContext) if (_externalAuthenticationService != null) - await _externalAuthenticationService.SignOut(uiContext); + await _externalAuthenticationService.SignOut(uiContext).VhConfigureAwait(); } private async Task SignInToVpnHoodStore(string idToken, bool autoSignUp) @@ -136,16 +136,18 @@ private async Task SignInToVpnHoodStore(string idToken, bool autoSignUp) var authenticationClient = new AuthenticationClient(_httpClientWithoutAuth); try { - ApiKey = await authenticationClient.SignInAsync(new SignInRequest - { - IdToken = idToken, - RefreshTokenType = RefreshTokenType.None - }); + ApiKey = await authenticationClient.SignInAsync( + new SignInRequest + { + IdToken = idToken, + RefreshTokenType = RefreshTokenType.None + }) + .VhConfigureAwait(); } catch (ApiException ex) { if (ex.ExceptionTypeName == "UnregisteredUserException" && autoSignUp) - await SignUpToVpnHoodStore(idToken); + await SignUpToVpnHoodStore(idToken).VhConfigureAwait(); else throw; } @@ -154,11 +156,13 @@ private async Task SignInToVpnHoodStore(string idToken, bool autoSignUp) private async Task SignUpToVpnHoodStore(string idToken) { var authenticationClient = new AuthenticationClient(_httpClientWithoutAuth); - ApiKey = await authenticationClient.SignUpAsync(new SignUpRequest - { - IdToken = idToken, - RefreshTokenType = RefreshTokenType.None - }); + ApiKey = await authenticationClient.SignUpAsync( + new SignUpRequest + { + IdToken = idToken, + RefreshTokenType = RefreshTokenType.None + }) + .VhConfigureAwait(); } public void Dispose() @@ -175,9 +179,9 @@ public class HttpClientHandlerAuth(StoreAuthenticationService accountService) : { protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - var apiKey = await accountService.TryGetApiKey(VpnHoodApp.Instance.UiContext); + var apiKey = await accountService.TryGetApiKey(VpnHoodApp.Instance.UiContext).VhConfigureAwait(); request.Headers.Authorization = apiKey != null ? new AuthenticationHeaderValue(apiKey.AccessToken.Scheme, apiKey.AccessToken.Value) : null; - return await base.SendAsync(request, cancellationToken); + return await base.SendAsync(request, cancellationToken).VhConfigureAwait(); } } } \ No newline at end of file diff --git a/VpnHood.Client.App.Store/StoreBillingService.cs b/VpnHood.Client.App.Store/StoreBillingService.cs index 09dc2b52f..68b6a6008 100644 --- a/VpnHood.Client.App.Store/StoreBillingService.cs +++ b/VpnHood.Client.App.Store/StoreBillingService.cs @@ -1,5 +1,6 @@ using VpnHood.Client.App.Abstractions; using VpnHood.Client.Device; +using VpnHood.Common.Utils; namespace VpnHood.Client.App.Store; @@ -22,9 +23,9 @@ public async Task Purchase(IUiContext uiContext, string planId) try { PurchaseState = BillingPurchaseState.Started; - var providerOrderId = await billingService.Purchase(uiContext, planId); + var providerOrderId = await billingService.Purchase(uiContext, planId).VhConfigureAwait(); PurchaseState = BillingPurchaseState.Processing; - await storeAccountService.WaitForProcessProviderOrder(providerOrderId); + await storeAccountService.WaitForProcessProviderOrder(providerOrderId).VhConfigureAwait(); return providerOrderId; } diff --git a/VpnHood.Client.App.Store/VpnHood.Client.App.Store.csproj b/VpnHood.Client.App.Store/VpnHood.Client.App.Store.csproj index 0bbf8df80..f91a2f224 100644 --- a/VpnHood.Client.App.Store/VpnHood.Client.App.Store.csproj +++ b/VpnHood.Client.App.Store/VpnHood.Client.App.Store.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App.Swagger/Controllers/AppController.cs b/VpnHood.Client.App.Swagger/Controllers/AppController.cs index 1d6ef5401..eff0b2c24 100644 --- a/VpnHood.Client.App.Swagger/Controllers/AppController.cs +++ b/VpnHood.Client.App.Swagger/Controllers/AppController.cs @@ -73,7 +73,7 @@ public Task GetInstalledApps() } [HttpGet("ip-groups")] - public Task GetIpGroups() + public Task GetIpGroups() { throw new NotImplementedException(); } diff --git a/VpnHood.Client.App.WebServer/Api/IAppController.cs b/VpnHood.Client.App.WebServer/Api/IAppController.cs index 6a3f9e97f..d1172333c 100644 --- a/VpnHood.Client.App.WebServer/Api/IAppController.cs +++ b/VpnHood.Client.App.WebServer/Api/IAppController.cs @@ -19,7 +19,7 @@ public interface IAppController Task SetUserSettings(UserSettings userSettings); Task Log(); Task GetInstalledApps(); - Task GetIpGroups(); + Task GetIpGroups(); Task VersionCheck(); void VersionCheckPostpone(); void OpenAlwaysOnPage(); diff --git a/VpnHood.Client.App.WebServer/Controllers/AppController.cs b/VpnHood.Client.App.WebServer/Controllers/AppController.cs index d49a3f5ad..e2d1898d6 100644 --- a/VpnHood.Client.App.WebServer/Controllers/AppController.cs +++ b/VpnHood.Client.App.WebServer/Controllers/AppController.cs @@ -6,6 +6,7 @@ using VpnHood.Client.App.Settings; using VpnHood.Client.App.WebServer.Api; using VpnHood.Client.Device; +using VpnHood.Common.Utils; namespace VpnHood.Client.App.WebServer.Controllers; @@ -14,7 +15,7 @@ internal class AppController : WebApiController, IAppController private static VpnHoodApp App => VpnHoodApp.Instance; private async Task GetRequestDataAsync() { - var json = await HttpContext.GetRequestBodyAsByteArrayAsync(); + var json = await HttpContext.GetRequestBodyAsByteArrayAsync().VhConfigureAwait(); var res = JsonSerializer.Deserialize(json, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); @@ -24,12 +25,12 @@ private async Task GetRequestDataAsync() [Route(HttpVerbs.Patch, "/configure")] public async Task Configure(ConfigParams configParams) { - configParams = await GetRequestDataAsync(); + configParams = await GetRequestDataAsync().VhConfigureAwait(); App.Services.AppCultureService.AvailableCultures = configParams.AvailableCultures; if (configParams.Strings != null) App.Resource.Strings = configParams.Strings; App.UpdateUi(); - return await GetConfig(); + return await GetConfig().VhConfigureAwait(); } [Route(HttpVerbs.Get, "/config")] @@ -104,7 +105,7 @@ public void ClearLastError() [Route(HttpVerbs.Put, "/user-settings")] public async Task SetUserSettings(UserSettings userSettings) { - userSettings = await GetRequestDataAsync(); + userSettings = await GetRequestDataAsync().VhConfigureAwait(); App.Settings.UserSettings = userSettings; App.Settings.Save(); } @@ -115,8 +116,8 @@ public async Task Log() Response.ContentType = MimeType.PlainText; await using var stream = HttpContext.OpenResponseStream(); await using var streamWriter = new StreamWriter(stream); - var log = await App.LogService.GetLog(); - await streamWriter.WriteAsync(log); + var log = await App.LogService.GetLog().VhConfigureAwait(); + await streamWriter.WriteAsync(log).VhConfigureAwait(); return ""; } @@ -127,15 +128,17 @@ public Task GetInstalledApps() } [Route(HttpVerbs.Get, "/ip-groups")] - public Task GetIpGroups() + public async Task GetIpGroups() { - return App.GetIpGroups(); + var ipGroupManager = await App.GetIpGroupManager().VhConfigureAwait(); + var ipGroupIds = await ipGroupManager.GetIpGroupIds().VhConfigureAwait(); + return ipGroupIds.Select(x=>new IpGroupInfo{IpGroupId = x}).ToArray(); } [Route(HttpVerbs.Patch, "/client-profiles/{clientProfileId}")] public async Task UpdateClientProfile(Guid clientProfileId, ClientProfileUpdateParams updateParams) { - updateParams = await GetRequestDataAsync(); + updateParams = await GetRequestDataAsync().VhConfigureAwait(); var clientProfile = App.ClientProfileService.Update(clientProfileId, updateParams); return clientProfile.ToInfo(); } @@ -144,7 +147,7 @@ public async Task UpdateClientProfile(Guid clientProfileId, C public async Task DeleteClientProfile(Guid clientProfileId) { if (clientProfileId == App.CurrentClientProfile?.ClientProfileId) - await App.Disconnect(true); + await App.Disconnect(true).VhConfigureAwait(); App.ClientProfileService.Remove(clientProfileId); } diff --git a/VpnHood.Client.App.WebServer/VpnHood.Client.App.WebServer.csproj b/VpnHood.Client.App.WebServer/VpnHood.Client.App.WebServer.csproj index 84fd70749..a914fe18e 100644 --- a/VpnHood.Client.App.WebServer/VpnHood.Client.App.WebServer.csproj +++ b/VpnHood.Client.App.WebServer/VpnHood.Client.App.WebServer.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App.WebServer/VpnHoodAppWebServer.cs b/VpnHood.Client.App.WebServer/VpnHoodAppWebServer.cs index c0219db96..91d58b743 100644 --- a/VpnHood.Client.App.WebServer/VpnHoodAppWebServer.cs +++ b/VpnHood.Client.App.WebServer/VpnHoodAppWebServer.cs @@ -142,7 +142,8 @@ private static async Task ResponseSerializerCallback(IHttpContext context, objec context.Response.ContentType = MimeType.Json; await using var text = context.OpenResponseText(new UTF8Encoding(false)); await text.WriteAsync(JsonSerializer.Serialize(data, - new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase })); + new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase })) + .VhConfigureAwait(); } // manage SPA fallback diff --git a/VpnHood.Client.App.Win.Common/VpnHood.Client.App.Win.Common.csproj b/VpnHood.Client.App.Win.Common/VpnHood.Client.App.Win.Common.csproj index 7b4767c29..b4965c2f1 100644 --- a/VpnHood.Client.App.Win.Common/VpnHood.Client.App.Win.Common.csproj +++ b/VpnHood.Client.App.Win.Common/VpnHood.Client.App.Win.Common.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App.Win.Common/WinAppUpdaterService.cs b/VpnHood.Client.App.Win.Common/WinAppUpdaterService.cs index b4c95c8a4..257a0c654 100644 --- a/VpnHood.Client.App.Win.Common/WinAppUpdaterService.cs +++ b/VpnHood.Client.App.Win.Common/WinAppUpdaterService.cs @@ -4,6 +4,7 @@ using VpnHood.Client.App.Abstractions; using VpnHood.Client.Device; using VpnHood.Common.Logging; +using VpnHood.Common.Utils; namespace VpnHood.Client.App.Win.Common; @@ -25,7 +26,7 @@ public async Task Update(IUiContext uiContext) var process = Process.Start(updaterFilePath, "/justcheck"); if (process == null) return false; while (process is { HasExited: false }) - await Task.Delay(500); + await Task.Delay(500).VhConfigureAwait(); // install update if (process.ExitCode == 0) @@ -33,7 +34,7 @@ public async Task Update(IUiContext uiContext) process = Process.Start(updaterFilePath); if (process == null) return false; while (process is { HasExited: false }) - await Task.Delay(500); + await Task.Delay(500).VhConfigureAwait(); } // https://www.advancedinstaller.com/user-guide/updater.html#updater-return-codes diff --git a/VpnHood.Client.App.Win/VpnHood.Client.App.Win.csproj b/VpnHood.Client.App.Win/VpnHood.Client.App.Win.csproj index 957a98406..1c3a5b231 100644 --- a/VpnHood.Client.App.Win/VpnHood.Client.App.Win.csproj +++ b/VpnHood.Client.App.Win/VpnHood.Client.App.Win.csproj @@ -27,7 +27,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App/AppOptions.cs b/VpnHood.Client.App/AppOptions.cs index 9256397be..869b64f41 100644 --- a/VpnHood.Client.App/AppOptions.cs +++ b/VpnHood.Client.App/AppOptions.cs @@ -12,7 +12,7 @@ public class AppOptions public SocketFactory? SocketFactory { get; set; } public TimeSpan VersionCheckInterval { get; set; } = TimeSpan.FromHours(24); public Uri? UpdateInfoUrl { get; set; } - public bool UseIpGroupManager { get; set; } = true; + public bool UseInternalLocationService { get; set; } = true; public bool UseExternalLocationService { get; set; } = true; public AppResource Resource { get; set; } = new(); public string? AppGa4MeasurementId { get; set; } = "G-4LE99XKZYE"; diff --git a/VpnHood.Client.App/ClientProfiles/ClientProfileService.cs b/VpnHood.Client.App/ClientProfiles/ClientProfileService.cs index b7b515356..4907e863f 100644 --- a/VpnHood.Client.App/ClientProfiles/ClientProfileService.cs +++ b/VpnHood.Client.App/ClientProfiles/ClientProfileService.cs @@ -178,7 +178,7 @@ public async Task UpdateServerTokenByUrl(Token token) try { using var client = new HttpClient(); - var encryptedServerToken = await VhUtil.RunTask(client.GetStringAsync(token.ServerToken.Url), TimeSpan.FromSeconds(20)); + var encryptedServerToken = await VhUtil.RunTask(client.GetStringAsync(token.ServerToken.Url), TimeSpan.FromSeconds(20)).VhConfigureAwait(); var newServerToken = ServerToken.Decrypt(token.ServerToken.Secret, encryptedServerToken); // return older only if token body is same and created time is newer diff --git a/VpnHood.Client.App/IpGroup.cs b/VpnHood.Client.App/IpGroup.cs index e78b97e57..a11d50fed 100644 --- a/VpnHood.Client.App/IpGroup.cs +++ b/VpnHood.Client.App/IpGroup.cs @@ -1,8 +1,16 @@ -namespace VpnHood.Client.App; +using VpnHood.Common.Net; + +namespace VpnHood.Client.App; public class IpGroup { public required string IpGroupId { get; init; } - public required string IpGroupName { get; init; } - + public required IpRangeOrderedList IpRanges { get; init; } + public IpGroupInfo ToInfo() + { + return new IpGroupInfo + { + IpGroupId = IpGroupId, + }; + } } \ No newline at end of file diff --git a/VpnHood.Client.App/IpGroupBuilder.cs b/VpnHood.Client.App/IpGroupBuilder.cs new file mode 100644 index 000000000..95f67dc06 --- /dev/null +++ b/VpnHood.Client.App/IpGroupBuilder.cs @@ -0,0 +1,61 @@ +using System.IO.Compression; +using System.Net.Sockets; +using System.Numerics; +using Microsoft.Extensions.Logging; +using VpnHood.Common.Logging; +using VpnHood.Common.Net; +using VpnHood.Common.Utils; + +namespace VpnHood.Client.App; + +public class IpGroupBuilder +{ + public static async Task BuildIpGroupArchiveFromIp2Location(Stream crvStream, string outputZipFile) + { + var ipGroups = await LoadIp2Location(crvStream).VhConfigureAwait(); + + // Building the IpGroups directory structure + VhLogger.Instance.LogTrace("Building the optimized Ip2Location archive..."); + await using var outputStream = File.Create(outputZipFile); + using var newArchive = new ZipArchive(outputStream, ZipArchiveMode.Create, leaveOpen: true); + foreach (var ipGroup in ipGroups) + { + var ipRanges = new IpRangeOrderedList(ipGroup.Value); + var entry = newArchive.CreateEntry($"{ipGroup.Key}.ips"); + await using var entryStream = entry.Open(); + ipRanges.Serialize(entryStream); + } + } + + private static async Task>> LoadIp2Location(Stream ipLocationsStream) + { + // extract IpGroups + var ipGroupIpRanges = new Dictionary>(); + using var streamReader = new StreamReader(ipLocationsStream); + while (!streamReader.EndOfStream) + { + var line = await streamReader.ReadLineAsync().VhConfigureAwait(); + var items = line.Replace("\"", "").Split(','); + if (items.Length != 4) + continue; + + var ipGroupId = items[2].ToLower(); + if (ipGroupId == "-") continue; + if (ipGroupId == "um") ipGroupId = "us"; + if (!ipGroupIpRanges.TryGetValue(ipGroupId, out var ipRanges)) + { + ipRanges = []; + ipGroupIpRanges.Add(ipGroupId, ipRanges); + } + + var addressFamily = items[0].Length > 10 || items[1].Length > 10 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork; + var ipRange = new IpRange( + IPAddressUtil.FromBigInteger(BigInteger.Parse(items[0]), addressFamily), + IPAddressUtil.FromBigInteger(BigInteger.Parse(items[1]), addressFamily)); + + ipRanges.Add(ipRange.IsIPv4MappedToIPv6 ? ipRange.MapToIPv4() : ipRange); + } + + return ipGroupIpRanges; + } +} \ No newline at end of file diff --git a/VpnHood.Client.App/IpGroupInfo.cs b/VpnHood.Client.App/IpGroupInfo.cs new file mode 100644 index 000000000..3353e4f1e --- /dev/null +++ b/VpnHood.Client.App/IpGroupInfo.cs @@ -0,0 +1,24 @@ +using System.Globalization; + +namespace VpnHood.Client.App; + +public class IpGroupInfo +{ + public required string IpGroupId { get; init; } + + public string IpGroupName + { + get + { + try + { + var regionInfo = new RegionInfo(IpGroupId); + return regionInfo.EnglishName; + } + catch (Exception) + { + return IpGroupId; + } + } + } +} \ No newline at end of file diff --git a/VpnHood.Client.App/IpGroupManager.cs b/VpnHood.Client.App/IpGroupManager.cs index af812beb1..f32170e09 100644 --- a/VpnHood.Client.App/IpGroupManager.cs +++ b/VpnHood.Client.App/IpGroupManager.cs @@ -1,10 +1,8 @@ using System.IO.Compression; using System.Net; using System.Net.Sockets; -using System.Numerics; -using System.Text.Json; -using System.Text.RegularExpressions; using Microsoft.Extensions.Logging; +using VpnHood.Common.Exceptions; using VpnHood.Common.Logging; using VpnHood.Common.Net; using VpnHood.Common.Utils; @@ -13,174 +11,93 @@ namespace VpnHood.Client.App; public class IpGroupManager { - private IpGroup[]? _ipGroups; + private readonly ZipArchive _zipArchive; + private string[]? _ipGroupIds; + private readonly Dictionary _ipGroupIpRanges = new(); - private string IpGroupsFolderPath => Path.Combine(StorageFolder, "ipgroups"); - private string IpGroupsFilePath => Path.Combine(StorageFolder, "ipgroups.json"); - private string VersionFilePath => Path.Combine(StorageFolder, "version.txt"); - private string GetIpGroupFilePath(string ipGroup) => Path.Combine(IpGroupsFolderPath, ipGroup + ".json"); - public string StorageFolder { get; } - - private IpGroupManager(string storageFolder) + private IpGroupManager(ZipArchive zipArchive) { - StorageFolder = storageFolder; + _zipArchive = zipArchive; } - public static Task Create(string storageFolder) + + public static Task Create(ZipArchive zipArchive) { - var ret = new IpGroupManager(storageFolder); + var ret = new IpGroupManager(zipArchive); return Task.FromResult(ret); } - public async Task InitByIp2LocationZipStream(ZipArchiveEntry archiveEntry) + public Task GetIpGroupIds() { - var newVersion = archiveEntry.LastWriteTime.ToUniversalTime().ToString("u"); - var oldVersion = "NotFound"; - - // check is version changed - if (File.Exists(VersionFilePath)) - { - try - { - oldVersion = await File.ReadAllTextAsync(VersionFilePath); - if (oldVersion == newVersion) - return; - } - catch (Exception ex) - { - VhLogger.Instance.LogError(ex, "Could not read last version file. File: {File}", VersionFilePath); - } - } - - // Build new structure - VhLogger.Instance.LogInformation("Building IPLocation. OldVersion: {OldVersion}, NewVersion {NewVersion},", oldVersion, newVersion); - - // delete all files and other versions if any - if (Directory.Exists(IpGroupsFolderPath)) - { - VhLogger.Instance.LogTrace("Deleting the old IpGroups..."); - Directory.Delete(IpGroupsFolderPath, true); - } - - // Loading the ip2Location stream - VhLogger.Instance.LogTrace("Loading the ip2Location stream..."); - await using var ipLocationsStream = archiveEntry.Open(); - var ipGroupNetworks = await LoadIp2Location(ipLocationsStream); - - // Building the IpGroups directory structure - VhLogger.Instance.LogTrace("Building the IpGroups directory structure..."); - Directory.CreateDirectory(IpGroupsFolderPath); - foreach (var ipGroupNetwork in ipGroupNetworks) - { - ipGroupNetwork.Value.IpRanges = ipGroupNetwork.Value.IpRanges.ToArray().Sort().ToList(); - await File.WriteAllTextAsync(GetIpGroupFilePath(ipGroupNetwork.Key), JsonSerializer.Serialize(ipGroupNetwork.Value.IpRanges)); - } - - // write IpGroups file - var ipGroups = ipGroupNetworks.Select(x => - new IpGroup - { - IpGroupId = x.Value.IpGroupId, - IpGroupName = x.Value.IpGroupName - }) - .OrderBy(x => x.IpGroupName) + _ipGroupIds ??= _zipArchive.Entries + .Where(x=>Path.GetExtension(x.Name)==".ips") + .Select(x=>Path.GetFileNameWithoutExtension(x.Name)) .ToArray(); - await File.WriteAllTextAsync(IpGroupsFilePath, JsonSerializer.Serialize(ipGroups)); - // write version - await File.WriteAllTextAsync(VersionFilePath, newVersion); - _ipGroups = null; // clear cache + return Task.FromResult(_ipGroupIds); } - private static async Task> LoadIp2Location(Stream ipLocationsStream) + public async Task GetIpRanges(string ipGroupId) { - // extract IpGroups - var ipGroupNetworks = new Dictionary(); - using var streamReader = new StreamReader(ipLocationsStream); - while (!streamReader.EndOfStream) - { - var line = await streamReader.ReadLineAsync(); - var items = line.Replace("\"", "").Split(','); - if (items.Length != 4) - continue; - - var ipGroupId = items[2].ToLower(); - if (ipGroupId == "-") continue; - if (ipGroupId == "um") ipGroupId = "us"; - if (!ipGroupNetworks.TryGetValue(ipGroupId, out var ipGroupNetwork)) - { - var ipGroupName = ipGroupId switch - { - "us" => "United States", - "gb" => "United Kingdom", - _ => items[3] - }; - - ipGroupName = Regex.Replace(ipGroupName, @"\(.*?\)", "").Replace(" ", " "); - ipGroupNetwork = new IpGroupNetwork - { - IpGroupId = ipGroupId, - IpGroupName = ipGroupName - }; - - ipGroupNetworks.Add(ipGroupId, ipGroupNetwork); - } - - var addressFamily = items[0].Length > 10 || items[1].Length > 10 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork; - var ipRange = new IpRange( - IPAddressUtil.FromBigInteger(BigInteger.Parse(items[0]), addressFamily), - IPAddressUtil.FromBigInteger(BigInteger.Parse(items[1]), addressFamily)); - - ipGroupNetwork.IpRanges.Add(ipRange.IsIPv4MappedToIPv6 ? ipRange.MapToIPv4() : ipRange); - } - - return ipGroupNetworks; + var ipRanges = await GetIpRangesInternal(ipGroupId).VhConfigureAwait(); + _ipGroupIpRanges.TryAdd(ipGroupId, ipRanges); + return ipRanges; } - public async Task GetIpGroups() + private async Task GetIpRangesInternal(string ipGroupId) { - if (_ipGroups != null) - return _ipGroups; - - // no countries if there is no import - if (!File.Exists(IpGroupsFilePath)) - return []; + if (_ipGroupIpRanges.TryGetValue(ipGroupId, out var ipGroupRangeCache)) + return ipGroupRangeCache; - var json = await File.ReadAllTextAsync(IpGroupsFilePath); - _ipGroups = VhUtil.JsonDeserialize(json); - return _ipGroups; + try + { + await using var stream = _zipArchive.GetEntry($"{ipGroupId}.ips")?.Open() ?? throw new NotExistsException(); + return IpRangeOrderedList.Deserialize(stream); + } + catch (Exception ex) + { + VhLogger.Instance.LogError(ex, "Could not load ip ranges for {IpGroupId}", ipGroupId); + return IpRangeOrderedList.Empty; + } } - public async Task> GetIpRanges(string ipGroupId) + public async Task GetIpGroup(IPAddress ipAddress, string? lastIpGroupId) { - var filePath = GetIpGroupFilePath(ipGroupId); - var json = await File.ReadAllTextAsync(filePath); - var ipRanges = JsonSerializer.Deserialize(json) ?? throw new Exception($"Could not deserialize {filePath}!"); - var ip4MappedRanges = ipRanges.Where(x => x.AddressFamily==AddressFamily.InterNetwork).Select(x => x.MapToIPv6()); - var ret = ipRanges.Concat(ip4MappedRanges); - return ret; + return await FindIpGroup(ipAddress, lastIpGroupId).VhConfigureAwait() + ?? throw new NotExistsException($"Could not find any ip group for the given ip. IP: {VhLogger.Format(ipAddress)}"); } - // it is sequential search public async Task FindIpGroup(IPAddress ipAddress, string? lastIpGroupId) { - var ipGroups = await GetIpGroups(); - var lastIpGroup = ipGroups.FirstOrDefault(x => x.IpGroupId == lastIpGroupId); - // IpGroup - if (lastIpGroup != null) + if (lastIpGroupId != null) { - var ipRanges = await GetIpRanges(lastIpGroup.IpGroupId); + var ipRanges = await GetIpRanges(lastIpGroupId).VhConfigureAwait(); if (ipRanges.Any(x => x.IsInRange(ipAddress))) - return lastIpGroup; + { + _ipGroupIpRanges.TryAdd(lastIpGroupId, ipRanges); + return new IpGroup + { + IpGroupId = lastIpGroupId, + IpRanges = ipRanges + }; + } } // iterate through all groups - foreach (var ipGroup in ipGroups) + var ipGroupIds = await GetIpGroupIds(); + foreach (var ipGroupId in ipGroupIds) { - var ipRanges = await GetIpRanges(ipGroup.IpGroupId); + var ipRanges = await GetIpRanges(ipGroupId).VhConfigureAwait(); if (ipRanges.Any(x => x.IsInRange(ipAddress))) - return ipGroup; + { + _ipGroupIpRanges.TryAdd(ipGroupId, ipRanges); + return new IpGroup + { + IpGroupId = ipGroupId, + IpRanges = ipRanges + }; + } } return null; @@ -191,13 +108,13 @@ public async Task> GetIpRanges(string ipGroupId) try { var ipAddress = - await IPAddressUtil.GetPublicIpAddress(AddressFamily.InterNetwork) ?? - await IPAddressUtil.GetPublicIpAddress(AddressFamily.InterNetworkV6); + await IPAddressUtil.GetPublicIpAddress(AddressFamily.InterNetwork).VhConfigureAwait() ?? + await IPAddressUtil.GetPublicIpAddress(AddressFamily.InterNetworkV6).VhConfigureAwait(); if (ipAddress == null) return null; - var ipGroup = await FindIpGroup(ipAddress, null); + var ipGroup = await FindIpGroup(ipAddress, null).VhConfigureAwait(); return ipGroup?.IpGroupId; } catch (Exception ex) @@ -206,9 +123,4 @@ await IPAddressUtil.GetPublicIpAddress(AddressFamily.InterNetwork) ?? return null; } } - - private class IpGroupNetwork : IpGroup - { - public List IpRanges { get; set; } = []; - } } \ No newline at end of file diff --git a/VpnHood.Client.App/IpGroupRange.cs b/VpnHood.Client.App/IpGroupRange.cs new file mode 100644 index 000000000..dc01b5f61 --- /dev/null +++ b/VpnHood.Client.App/IpGroupRange.cs @@ -0,0 +1,10 @@ +using VpnHood.Common.Net; + +namespace VpnHood.Client.App; + +public class IpGroupRange +{ + public required string IpGroupId { get; init; } + public required IpRangeOrderedList IpRanges { get; init; } + +} \ No newline at end of file diff --git a/VpnHood.Client.App/Resource.Designer.cs b/VpnHood.Client.App/Resource.Designer.cs index 77a1c39cc..dfb33280f 100644 --- a/VpnHood.Client.App/Resource.Designer.cs +++ b/VpnHood.Client.App/Resource.Designer.cs @@ -128,9 +128,9 @@ internal static string Exit { /// /// Looks up a localized resource of type System.Byte[]. /// - internal static byte[] IP2LOCATION_LITE_DB1_IPV6_CSV { + internal static byte[] IpLocations { get { - object obj = ResourceManager.GetObject("IP2LOCATION_LITE_DB1_IPV6_CSV", resourceCulture); + object obj = ResourceManager.GetObject("IpLocations", resourceCulture); return ((byte[])(obj)); } } diff --git a/VpnHood.Client.App/Resource.resx b/VpnHood.Client.App/Resource.resx index b055d8752..99d538c3e 100644 --- a/VpnHood.Client.App/Resource.resx +++ b/VpnHood.Client.App/Resource.resx @@ -6,7 +6,7 @@ Version 2.0 The primary goals of this format is to allow a simple XML format - that is mostly human-readable. The generation and parsing of the + that is mostly human readable. The generation and parsing of the various data types are done through the TypeConverter classes associated with the data types. @@ -139,9 +139,6 @@ Exit - - Resources\IP2LOCATION-LITE-DB1.IPV6.CSV.zip;System.Byte[], mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 - Manage @@ -157,7 +154,7 @@ Unsupported file type. - + This server requires a display Ad but could not display it. @@ -166,4 +163,7 @@ Open in Browser + + Resources\IpLocations.zip;System.Byte[], mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + \ No newline at end of file diff --git a/VpnHood.Client.App/Resources/IP2LOCATION-LITE-DB1.IPV6.CSV.zip b/VpnHood.Client.App/Resources/IP2LOCATION-LITE-DB1.IPV6.CSV.zip deleted file mode 100644 index b94d77eb6..000000000 Binary files a/VpnHood.Client.App/Resources/IP2LOCATION-LITE-DB1.IPV6.CSV.zip and /dev/null differ diff --git a/VpnHood.Client.App/Resources/IpLocations.zip b/VpnHood.Client.App/Resources/IpLocations.zip new file mode 100644 index 000000000..0ae3c649c Binary files /dev/null and b/VpnHood.Client.App/Resources/IpLocations.zip differ diff --git a/VpnHood.Client.App/Services/AppAccountService.cs b/VpnHood.Client.App/Services/AppAccountService.cs index af95fe61d..9c3265468 100644 --- a/VpnHood.Client.App/Services/AppAccountService.cs +++ b/VpnHood.Client.App/Services/AppAccountService.cs @@ -28,9 +28,9 @@ internal class AppAccountService(VpnHoodApp vpnHoodApp, IAppAccountService accou return _appAccount; // Update cache from server and update local cache - _appAccount = await accountService.GetAccount(); + _appAccount = await accountService.GetAccount().VhConfigureAwait(); Directory.CreateDirectory(Path.GetDirectoryName(AppAccountFilePath)!); - await File.WriteAllTextAsync(AppAccountFilePath, JsonSerializer.Serialize(_appAccount)); + await File.WriteAllTextAsync(AppAccountFilePath, JsonSerializer.Serialize(_appAccount)).VhConfigureAwait(); return _appAccount; } diff --git a/VpnHood.Client.App/Services/AppAuthenticationService.cs b/VpnHood.Client.App/Services/AppAuthenticationService.cs index 99159fc37..e9a451c2c 100644 --- a/VpnHood.Client.App/Services/AppAuthenticationService.cs +++ b/VpnHood.Client.App/Services/AppAuthenticationService.cs @@ -1,5 +1,6 @@ using VpnHood.Client.App.Abstractions; using VpnHood.Client.Device; +using VpnHood.Common.Utils; namespace VpnHood.Client.App.Services; @@ -12,14 +13,14 @@ internal class AppAuthenticationService(VpnHoodApp vpnHoodApp, IAppAuthenticatio public async Task SignInWithGoogle(IUiContext uiContext) { - await accountService.SignInWithGoogle(uiContext); - await vpnHoodApp.RefreshAccount(updateCurrentClientProfile:true); + await accountService.SignInWithGoogle(uiContext).VhConfigureAwait(); + await vpnHoodApp.RefreshAccount(updateCurrentClientProfile:true).VhConfigureAwait(); } public async Task SignOut(IUiContext uiContext) { - await accountService.SignOut(uiContext); - await vpnHoodApp.RefreshAccount(updateCurrentClientProfile: true); + await accountService.SignOut(uiContext).VhConfigureAwait(); + await vpnHoodApp.RefreshAccount(updateCurrentClientProfile: true).VhConfigureAwait(); } public void Dispose() diff --git a/VpnHood.Client.App/Services/AppBillingService.cs b/VpnHood.Client.App/Services/AppBillingService.cs index bff862932..dd1bae72d 100644 --- a/VpnHood.Client.App/Services/AppBillingService.cs +++ b/VpnHood.Client.App/Services/AppBillingService.cs @@ -1,5 +1,6 @@ using VpnHood.Client.App.Abstractions; using VpnHood.Client.Device; +using VpnHood.Common.Utils; namespace VpnHood.Client.App.Services; @@ -15,8 +16,8 @@ public Task GetSubscriptionPlans() public async Task Purchase(IUiContext uiContext, string planId) { - var ret = await billingService.Purchase(uiContext, planId); - await vpnHoodApp.RefreshAccount(updateCurrentClientProfile: true); + var ret = await billingService.Purchase(uiContext, planId).VhConfigureAwait(); + await vpnHoodApp.RefreshAccount(updateCurrentClientProfile: true).VhConfigureAwait(); return ret; } diff --git a/VpnHood.Client.App/VpnHood.Client.App.csproj b/VpnHood.Client.App/VpnHood.Client.App.csproj index f02614e7f..e3ff964fa 100644 --- a/VpnHood.Client.App/VpnHood.Client.App.csproj +++ b/VpnHood.Client.App/VpnHood.Client.App.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.App/VpnHoodApp.cs b/VpnHood.Client.App/VpnHoodApp.cs index 0324a00ac..ac26e8699 100644 --- a/VpnHood.Client.App/VpnHoodApp.cs +++ b/VpnHood.Client.App/VpnHoodApp.cs @@ -30,7 +30,7 @@ public class VpnHoodApp : Singleton, private const string FileNamePersistState = "state.json"; private const string FolderNameProfiles = "profiles"; private readonly SocketFactory? _socketFactory; - private readonly bool _useIpGroupManager; + private readonly bool _useInternalLocationService; private readonly bool _useExternalLocationService; private readonly string? _appGa4MeasurementId; private bool _hasConnectRequested; @@ -55,11 +55,10 @@ public class VpnHoodApp : Singleton, private VpnHoodClient? _client; private readonly bool? _logVerbose; private readonly bool? _logAnonymous; + private UserSettings _oldUserSettings; private SessionStatus? LastSessionStatus => _client?.SessionStatus ?? _lastSessionStatus; - private string TempFolderPath => Path.Combine(StorageFolderPath, "Temp"); - private string IpGroupsFolderPath => Path.Combine(TempFolderPath, "ipgroups"); private string VersionCheckFilePath => Path.Combine(StorageFolderPath, "version.json"); - + public string TempFolderPath => Path.Combine(StorageFolderPath, "Temp"); public event EventHandler? ConnectionStateChanged; public event EventHandler? UiHasChanged; public bool IsIdle => ConnectionState == AppConnectionState.None; @@ -95,8 +94,9 @@ private VpnHoodApp(IDevice device, AppOptions? options = default) Settings.BeforeSave += SettingsBeforeSave; ClientProfileService = new ClientProfileService(Path.Combine(StorageFolderPath, FolderNameProfiles)); SessionTimeout = options.SessionTimeout; + _oldUserSettings = VhUtil.JsonClone(UserSettings); _socketFactory = options.SocketFactory; - _useIpGroupManager = options.UseIpGroupManager; + _useInternalLocationService = options.UseInternalLocationService; _useExternalLocationService = options.UseExternalLocationService; _appGa4MeasurementId = options.AppGa4MeasurementId; _versionCheckInterval = options.VersionCheckInterval; @@ -259,7 +259,7 @@ private void FireConnectionStateChanged() public async ValueTask DisposeAsync() { - await Disconnect(); + await Disconnect().VhConfigureAwait(); Device.Dispose(); LogService.Dispose(); DisposeSingleton(); @@ -299,10 +299,10 @@ public async Task Connect(Guid? clientProfileId = null, string? serverLocation = { // disconnect current connection if (!IsIdle) - await Disconnect(true); + await Disconnect(true).VhConfigureAwait(); // request features for the first time - await RequestFeatures(cancellationToken); + await RequestFeatures(cancellationToken).VhConfigureAwait(); // set use default clientProfile and serverLocation serverLocation ??= UserSettings.ServerLocation; @@ -348,7 +348,7 @@ public async Task Connect(Guid? clientProfileId = null, string? serverLocation = // it slows down tests and does not need to be logged in normal situation if (diagnose) - VhLogger.Instance.LogInformation("Country: {Country}", await GetClientCountry()); + VhLogger.Instance.LogInformation("Country: {Country}", await GetClientCountry().VhConfigureAwait()); VhLogger.Instance.LogInformation("VpnHood Client is Connecting ..."); @@ -358,7 +358,7 @@ public async Task Connect(Guid? clientProfileId = null, string? serverLocation = cancellationToken = linkedCts.Token; // connect - await ConnectInternal(clientProfile.Token, _activeServerLocation, userAgent, true, cancellationToken); + await ConnectInternal(clientProfile.Token, _activeServerLocation, userAgent, true, cancellationToken).VhConfigureAwait(); } catch (Exception ex) { @@ -386,7 +386,7 @@ public async Task Connect(Guid? clientProfileId = null, string? serverLocation = private async Task CreatePacketCapture() { // create packet capture - var packetCapture = await Device.CreatePacketCapture(UiContext); + var packetCapture = await Device.CreatePacketCapture(UiContext).VhConfigureAwait(); // init packet capture if (packetCapture.IsMtuSupported) @@ -410,7 +410,7 @@ private async Task ConnectInternal(Token token, string? serverLocationInfo, stri $"TokenId: {VhLogger.FormatId(token.TokenId)}, SupportId: {VhLogger.FormatId(token.SupportId)}"); // calculate packetCaptureIpRanges - var packetCaptureIpRanges = IpNetwork.All.ToIpRanges(); + var packetCaptureIpRanges = new IpRangeOrderedList(IpNetwork.All.ToIpRanges()); if (!VhUtil.IsNullOrEmpty(UserSettings.PacketCaptureIncludeIpRanges)) packetCaptureIpRanges = packetCaptureIpRanges.Intersect(UserSettings.PacketCaptureIncludeIpRanges); if (!VhUtil.IsNullOrEmpty(UserSettings.PacketCaptureExcludeIpRanges)) @@ -425,7 +425,7 @@ private async Task ConnectInternal(Token token, string? serverLocationInfo, stri IncludeLocalNetwork = UserSettings.IncludeLocalNetwork, IpRangeProvider = this, AdProvider = this, - PacketCaptureIncludeIpRanges = packetCaptureIpRanges.ToArray(), + PacketCaptureIncludeIpRanges = packetCaptureIpRanges, MaxDatagramChannelCount = UserSettings.MaxDatagramChannelCount, ConnectTimeout = TcpTimeout, AllowAnonymousTracker = UserSettings.AllowAnonymousTracker, @@ -440,16 +440,16 @@ private async Task ConnectInternal(Token token, string? serverLocationInfo, stri // Create Client with a new PacketCapture if (_client != null) throw new Exception("Last client has not been disposed properly."); - var packetCapture = await CreatePacketCapture(); + var packetCapture = await CreatePacketCapture().VhConfigureAwait(); _client = new VpnHoodClient(packetCapture, Settings.ClientId, token, clientOptions); _client.StateChanged += Client_StateChanged; try { if (_hasDiagnoseStarted) - await Diagnoser.Diagnose(_client, cancellationToken); + await Diagnoser.Diagnose(_client, cancellationToken).VhConfigureAwait(); else - await Diagnoser.Connect(_client, cancellationToken); + await Diagnoser.Connect(_client, cancellationToken).VhConfigureAwait(); // set connected time ConnectedTime = DateTime.Now; @@ -467,16 +467,16 @@ private async Task ConnectInternal(Token token, string? serverLocationInfo, stri } catch (Exception) { - await _client.DisposeAsync(); + await _client.DisposeAsync().VhConfigureAwait(); _client = null; // try to update token from url after connection or error if ResponseAccessKey is not set // check _client is not null to make sure if (allowUpdateToken && !string.IsNullOrEmpty(token.ServerToken.Url) && - await ClientProfileService.UpdateServerTokenByUrl(token)) + await ClientProfileService.UpdateServerTokenByUrl(token).VhConfigureAwait()) { token = ClientProfileService.GetToken(token.TokenId); - await ConnectInternal(token, serverLocationInfo, userAgent, false, cancellationToken); + await ConnectInternal(token, serverLocationInfo, userAgent, false, cancellationToken).VhConfigureAwait(); return; } @@ -494,7 +494,7 @@ private async Task RequestFeatures(CancellationToken cancellationToken) try { Settings.IsQuickLaunchEnabled = - await Services.UiService.RequestQuickLaunch(RequiredUiContext, cancellationToken); + await Services.UiService.RequestQuickLaunch(RequiredUiContext, cancellationToken).VhConfigureAwait(); } catch (Exception ex) { @@ -512,7 +512,7 @@ private async Task RequestFeatures(CancellationToken cancellationToken) try { Settings.IsNotificationEnabled = - await Services.UiService.RequestNotification(RequiredUiContext, cancellationToken); + await Services.UiService.RequestNotification(RequiredUiContext, cancellationToken).VhConfigureAwait(); } catch (Exception ex) { @@ -549,9 +549,10 @@ private void SettingsBeforeSave(object sender, EventArgs e) // check is disconnect required var disconnectRequired = + (_oldUserSettings.TunnelClientCountry != UserSettings.TunnelClientCountry) || (_activeClientProfileId != null && UserSettings.ClientProfileId != _activeClientProfileId) || //ClientProfileId has been changed - (state.CanDisconnect && _activeServerLocation != state.ClientServerLocationInfo?.ServerLocation) || //ClientProfileId has been changed - (state.CanDisconnect && UserSettings.IncludeLocalNetwork != client.IncludeLocalNetwork); // IncludeLocalNetwork has been changed + (_activeServerLocation != state.ClientServerLocationInfo?.ServerLocation) || //ClientProfileId has been changed + (UserSettings.IncludeLocalNetwork != client.IncludeLocalNetwork); // IncludeLocalNetwork has been changed // disconnect if (state.CanDisconnect && disconnectRequired) @@ -571,6 +572,7 @@ private void SettingsBeforeSave(object sender, EventArgs e) UserSettings.CultureCode != null ? [UserSettings.CultureCode] : []; InitCulture(); + _oldUserSettings = VhUtil.JsonClone(UserSettings); } public async Task GetClientCountry() @@ -581,7 +583,7 @@ private void SettingsBeforeSave(object sender, EventArgs e) try { var ipLocationProvider = new IpLocationProviderFactory().CreateDefault("VpnHood-Client"); - var ipLocation = await ipLocationProvider.GetLocation(new HttpClient()); + var ipLocation = await ipLocationProvider.GetLocation(new HttpClient()).VhConfigureAwait(); _appPersistState.ClientCountryCode = ipLocation.CountryCode; } catch (Exception ex) @@ -591,10 +593,17 @@ private void SettingsBeforeSave(object sender, EventArgs e) } // try to get by ip group - if (_appPersistState.ClientCountryCode == null && _useIpGroupManager) + if (_appPersistState.ClientCountryCode == null && _useInternalLocationService) { - var ipGroupManager = await GetIpGroupManager(); - _appPersistState.ClientCountryCode ??= await ipGroupManager.GetCountryCodeByCurrentIp(); + try + { + var ipGroupManager = await GetIpGroupManager().VhConfigureAwait(); + _appPersistState.ClientCountryCode ??= await ipGroupManager.GetCountryCodeByCurrentIp().VhConfigureAwait(); + } + catch (Exception ex) + { + VhLogger.Instance.LogError(ex, "Could not find country code."); + } } // return last country @@ -605,7 +614,7 @@ public async Task ShowAd(string sessionId, CancellationToken cancellatio { if (Services.AdService == null) throw new Exception("AdService has not been initialized."); var adData = $"sid:{sessionId};ad:{Guid.NewGuid()}"; - await Services.AdService.ShowAd(RequiredUiContext, adData, cancellationToken); + await Services.AdService.ShowAd(RequiredUiContext, adData, cancellationToken).VhConfigureAwait(); return adData; } @@ -626,7 +635,7 @@ private void Client_StateChanged(object sender, EventArgs e) private readonly AsyncLock _disconnectLock = new(); public async Task Disconnect(bool byUser = false) { - using var lockAsync = await _disconnectLock.LockAsync(); + using var lockAsync = await _disconnectLock.LockAsync().VhConfigureAwait(); if (_isDisconnecting || IsIdle) return; @@ -651,7 +660,7 @@ public async Task Disconnect(bool byUser = false) // close client // do not wait for bye if user request disconnection if (_client != null) - await _client.DisposeAsync(waitForBye: !byUser); + await _client.DisposeAsync(waitForBye: !byUser).VhConfigureAwait(); LogService.Stop(); } @@ -673,41 +682,14 @@ public async Task Disconnect(bool byUser = false) } } - public async Task GetIpGroups() - { - var ipGroupManager = await GetIpGroupManager(); - return await ipGroupManager.GetIpGroups(); - } - - private async Task GetIpGroupManager() + public async Task GetIpGroupManager() { - using var asyncLock = await AsyncLock.LockAsync("GetIpGroupManager"); if (_ipGroupManager != null) return _ipGroupManager; - _ipGroupManager = await IpGroupManager.Create(IpGroupsFolderPath); - - // ignore country ip groups if not required usually by tests - if (!_useIpGroupManager) - return _ipGroupManager; - - // AddFromIp2Location if hash has been changed - try - { - _isLoadingIpGroup = true; - FireConnectionStateChanged(); - await using var memZipStream = new MemoryStream(App.Resource.IP2LOCATION_LITE_DB1_IPV6_CSV); - using var zipArchive = new ZipArchive(memZipStream); - var entry = zipArchive.GetEntry("IP2LOCATION-LITE-DB1.IPV6.CSV") ?? - throw new Exception("Could not find ip2location database."); - await _ipGroupManager.InitByIp2LocationZipStream(entry); - return _ipGroupManager; - } - finally - { - _isLoadingIpGroup = false; - FireConnectionStateChanged(); - } + var zipArchive = new ZipArchive(new MemoryStream(App.Resource.IpLocations), ZipArchiveMode.Read, leaveOpen: false); + _ipGroupManager = await IpGroupManager.Create(zipArchive).VhConfigureAwait(); + return _ipGroupManager; } public void VersionCheckPostpone() @@ -731,7 +713,7 @@ public async Task VersionCheck(bool force = false) // check version by app container try { - if (UiContext != null && Services.UpdaterService != null && await Services.UpdaterService.Update(UiContext)) + if (UiContext != null && Services.UpdaterService != null && await Services.UpdaterService.Update(UiContext).VhConfigureAwait()) { VersionCheckPostpone(); return; @@ -743,11 +725,12 @@ public async Task VersionCheck(bool force = false) } // check version by UpdateInfoUrl - _versionCheckResult = await VersionCheckByUpdateInfo(); + _versionCheckResult = await VersionCheckByUpdateInfo().VhConfigureAwait(); // save the result if (_versionCheckResult != null) - await File.WriteAllTextAsync(VersionCheckFilePath, JsonSerializer.Serialize(_versionCheckResult)); + await File.WriteAllTextAsync(VersionCheckFilePath, JsonSerializer.Serialize(_versionCheckResult)).VhConfigureAwait(); + else if (File.Exists(VersionCheckFilePath)) File.Delete(VersionCheckFilePath); } @@ -762,7 +745,7 @@ public async Task VersionCheck(bool force = false) VhLogger.Instance.LogTrace("Retrieving the latest publish info..."); using var httpClient = new HttpClient(); - var publishInfoJson = await httpClient.GetStringAsync(Features.UpdateInfoUrl); + var publishInfoJson = await httpClient.GetStringAsync(Features.UpdateInfoUrl).VhConfigureAwait(); var latestPublishInfo = VhUtil.JsonDeserialize(publishInfoJson); VersionStatus versionStatus; @@ -803,7 +786,7 @@ public async Task VersionCheck(bool force = false) } } - public async Task GetIncludeIpRanges(IPAddress clientIp) + public async Task GetIncludeIpRanges(IPAddress clientIp) { // calculate packetCaptureIpRanges var ipRanges = IpNetwork.All.ToIpRanges(); @@ -811,18 +794,40 @@ public async Task VersionCheck(bool force = false) if (!VhUtil.IsNullOrEmpty(UserSettings.ExcludeIpRanges)) ipRanges = ipRanges.Exclude(UserSettings.ExcludeIpRanges); // exclude client country IPs - if (!UserSettings.TunnelClientCountry) + if (UserSettings.TunnelClientCountry) + return ipRanges; + + VhLogger.Instance.LogTrace("Finding Country IPs for split tunneling. Country: {Country}", _appPersistState.ClientCountryName); + _isLoadingIpGroup = true; + FireConnectionStateChanged(); + try { - var ipGroupManager = await GetIpGroupManager(); - var ipGroup = await ipGroupManager.FindIpGroup(clientIp, _appPersistState.ClientCountryCode); - _appPersistState.ClientCountryCode = ipGroup?.IpGroupId; + if (!_useInternalLocationService) + throw new InvalidOperationException("Could not use internal location service because it is disabled."); + + var ipGroupManager = await GetIpGroupManager().VhConfigureAwait(); + var ipGroup = await ipGroupManager.GetIpGroup(clientIp, _appPersistState.ClientCountryCode ?? RegionInfo.CurrentRegion.Name).VhConfigureAwait(); + _appPersistState.ClientCountryCode = ipGroup.IpGroupId; VhLogger.Instance.LogInformation("Client Country is: {Country}", _appPersistState.ClientCountryName); - if (ipGroup != null) - ipRanges = ipRanges.Exclude(await ipGroupManager.GetIpRanges(ipGroup.IpGroupId)); + ipRanges = ipRanges.Exclude(ipGroup.IpRanges); } + catch (Exception ex) + { + VhLogger.Instance.LogError(ex, "Could not get ip locations of your country."); + if (!UserSettings.TunnelClientCountry) + { + UserSettings.TunnelClientCountry = true; + Settings.Save(); + } + } + + finally + { + _isLoadingIpGroup = false; + } - return ipRanges.ToArray(); + return ipRanges; } public async Task RefreshAccount(bool updateCurrentClientProfile = false) @@ -835,9 +840,9 @@ public async Task RefreshAccount(bool updateCurrentClientProfile = false) // update profiles // get access tokens from account - var account = await Services.AccountService.GetAccount(); + var account = await Services.AccountService.GetAccount().VhConfigureAwait(); var accessKeys = account?.SubscriptionId != null - ? await Services.AccountService.GetAccessKeys(account.SubscriptionId) + ? await Services.AccountService.GetAccessKeys(account.SubscriptionId).VhConfigureAwait() : []; ClientProfileService.UpdateFromAccount(accessKeys); diff --git a/VpnHood.Client.Device.Android/AndroidDevice.cs b/VpnHood.Client.Device.Android/AndroidDevice.cs index 3c3c99c0c..a4bd2b433 100644 --- a/VpnHood.Client.Device.Android/AndroidDevice.cs +++ b/VpnHood.Client.Device.Android/AndroidDevice.cs @@ -136,7 +136,7 @@ public async Task CreatePacketCapture(IUiContext? uiContext) try { androidUiContext.Activity.StartActivityForResult(prepareIntent, RequestVpnPermissionId); - await Task.WhenAny(_grantPermissionTaskSource.Task, Task.Delay(TimeSpan.FromMinutes(2))); + await Task.WhenAny(_grantPermissionTaskSource.Task, Task.Delay(TimeSpan.FromMinutes(2))).VhConfigureAwait(); if (!_grantPermissionTaskSource.Task.IsCompletedSuccessfully) throw new Exception("Could not grant VPN permission in the given time."); @@ -164,7 +164,7 @@ public async Task CreatePacketCapture(IUiContext? uiContext) // check is service started _startServiceTaskSource = new TaskCompletionSource(); - await Task.WhenAny(_startServiceTaskSource.Task, Task.Delay(TimeSpan.FromSeconds(10))); + await Task.WhenAny(_startServiceTaskSource.Task, Task.Delay(TimeSpan.FromSeconds(10))).VhConfigureAwait(); if (_packetCapture == null) throw new Exception("Could not start VpnService in the given time."); diff --git a/VpnHood.Client.Device.Android/AndroidPacketCapture.cs b/VpnHood.Client.Device.Android/AndroidPacketCapture.cs index f87c8cb9f..3f2c14f11 100644 --- a/VpnHood.Client.Device.Android/AndroidPacketCapture.cs +++ b/VpnHood.Client.Device.Android/AndroidPacketCapture.cs @@ -113,13 +113,13 @@ public void SendPacketToInbound(IPPacket ipPacket) _outStream?.Write(ipPacket.Bytes); } - public void SendPacketToInbound(IEnumerable ipPackets) + public void SendPacketToInbound(IList ipPackets) { foreach (var ipPacket in ipPackets) _outStream?.Write(ipPacket.Bytes); } - public void SendPacketToOutbound(IEnumerable ipPackets) + public void SendPacketToOutbound(IList ipPackets) { throw new NotSupportedException(); } diff --git a/VpnHood.Client.Device.Android/VpnHood.Client.Device.Android.csproj b/VpnHood.Client.Device.Android/VpnHood.Client.Device.Android.csproj index 223a4a889..924278826 100644 --- a/VpnHood.Client.Device.Android/VpnHood.Client.Device.Android.csproj +++ b/VpnHood.Client.Device.Android/VpnHood.Client.Device.Android.csproj @@ -21,7 +21,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.Device.WinDivert/VpnHood.Client.Device.WinDivert.csproj b/VpnHood.Client.Device.WinDivert/VpnHood.Client.Device.WinDivert.csproj index c2578a805..5885e0e90 100644 --- a/VpnHood.Client.Device.WinDivert/VpnHood.Client.Device.WinDivert.csproj +++ b/VpnHood.Client.Device.WinDivert/VpnHood.Client.Device.WinDivert.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client.Device.WinDivert/WinDivertPacketCapture.cs b/VpnHood.Client.Device.WinDivert/WinDivertPacketCapture.cs index 2b52f2101..a4ce42645 100644 --- a/VpnHood.Client.Device.WinDivert/WinDivertPacketCapture.cs +++ b/VpnHood.Client.Device.WinDivert/WinDivertPacketCapture.cs @@ -54,7 +54,7 @@ public virtual void ProtectSocket(Socket socket) $"{nameof(ProtectSocket)} is not supported by {GetType().Name}"); } - public void SendPacketToInbound(IEnumerable ipPackets) + public void SendPacketToInbound(IList ipPackets) { foreach (var ipPacket in ipPackets) SendPacket(ipPacket, false); @@ -70,7 +70,7 @@ public void SendPacketToOutbound(IPPacket ipPacket) SendPacket(ipPacket, true); } - public void SendPacketToOutbound(IEnumerable ipPackets) + public void SendPacketToOutbound(IList ipPackets) { foreach (var ipPacket in ipPackets) SendPacket(ipPacket, true); @@ -105,7 +105,7 @@ public void StartCapture() var phraseX = "true"; if (IncludeNetworks != null) { - var ipRanges = IpNetwork.ToIpRange(IncludeNetworks); + var ipRanges = IncludeNetworks.ToIpRanges(); var phrases = ipRanges.Select(x => x.FirstIpAddress.Equals(x.LastIpAddress) ? $"{Ip(x)}.DstAddr=={x.FirstIpAddress}" : $"({Ip(x)}.DstAddr>={x.FirstIpAddress} and {Ip(x)}.DstAddr<={x.LastIpAddress})"); diff --git a/VpnHood.Client.Device/IPacketCapture.cs b/VpnHood.Client.Device/IPacketCapture.cs index b6336d871..a49b35044 100644 --- a/VpnHood.Client.Device/IPacketCapture.cs +++ b/VpnHood.Client.Device/IPacketCapture.cs @@ -27,7 +27,7 @@ public interface IPacketCapture : IDisposable void StopCapture(); void ProtectSocket(Socket socket); void SendPacketToInbound(IPPacket ipPacket); - void SendPacketToInbound(IEnumerable packets); + void SendPacketToInbound(IList packets); void SendPacketToOutbound(IPPacket ipPacket); - void SendPacketToOutbound(IEnumerable ipPackets); + void SendPacketToOutbound(IList ipPackets); } \ No newline at end of file diff --git a/VpnHood.Client.Device/PacketReceivedEventArgs.cs b/VpnHood.Client.Device/PacketReceivedEventArgs.cs index 7a40dcf8d..3fe2eef1d 100644 --- a/VpnHood.Client.Device/PacketReceivedEventArgs.cs +++ b/VpnHood.Client.Device/PacketReceivedEventArgs.cs @@ -2,8 +2,8 @@ namespace VpnHood.Client.Device; -public sealed class PacketReceivedEventArgs(IPPacket[] ipPackets, IPacketCapture packetCapture) : EventArgs +public sealed class PacketReceivedEventArgs(IList ipPackets, IPacketCapture packetCapture) : EventArgs { - public IPPacket[] IpPackets { get; } = ipPackets; + public IList IpPackets { get; } = ipPackets; public IPacketCapture PacketCapture { get; } = packetCapture; } \ No newline at end of file diff --git a/VpnHood.Client.Device/VpnHood.Client.Device.csproj b/VpnHood.Client.Device/VpnHood.Client.Device.csproj index 72a9067f5..ab3f0d640 100644 --- a/VpnHood.Client.Device/VpnHood.Client.Device.csproj +++ b/VpnHood.Client.Device/VpnHood.Client.Device.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client/Abstractions/IIpRangeProvider.cs b/VpnHood.Client/Abstractions/IIpRangeProvider.cs index 3b508f7d6..100a7c00e 100644 --- a/VpnHood.Client/Abstractions/IIpRangeProvider.cs +++ b/VpnHood.Client/Abstractions/IIpRangeProvider.cs @@ -5,5 +5,5 @@ namespace VpnHood.Client.Abstractions; public interface IIpRangeProvider { - Task GetIncludeIpRanges(IPAddress clientIp); + Task GetIncludeIpRanges(IPAddress clientIp); } \ No newline at end of file diff --git a/VpnHood.Client/ClientHost.cs b/VpnHood.Client/ClientHost.cs index c6af03e1b..873271d5d 100644 --- a/VpnHood.Client/ClientHost.cs +++ b/VpnHood.Client/ClientHost.cs @@ -71,7 +71,7 @@ private async Task AcceptTcpClientLoop(TcpListener tcpListener) { while (!cancellationToken.IsCancellationRequested) { - var tcpClient = await tcpListener.AcceptTcpClientAsync(); + var tcpClient = await tcpListener.AcceptTcpClientAsync().VhConfigureAwait(); _ = ProcessClient(tcpClient, cancellationToken); } } @@ -201,8 +201,8 @@ private async Task ProcessClient(TcpClient orgTcpClient, CancellationToken cance await vpnHoodClient.AddPassthruTcpStream( new TcpClientStream(orgTcpClient, orgTcpClient.GetStream(), channelId), new IPEndPoint(natItem.DestinationAddress, natItem.DestinationPort), - channelId, - cancellationToken); + channelId, cancellationToken) + .VhConfigureAwait(); return; } @@ -218,7 +218,7 @@ await vpnHoodClient.AddPassthruTcpStream( }; // read the response - requestResult = await vpnHoodClient.SendRequest(request, cancellationToken); + requestResult = await vpnHoodClient.SendRequest(request, cancellationToken).VhConfigureAwait(); var proxyClientStream = requestResult.ClientStream; // create a StreamProxyChannel @@ -231,8 +231,8 @@ await vpnHoodClient.AddPassthruTcpStream( } catch (Exception ex) { - if (channel != null) await channel.DisposeAsync(); - if (requestResult != null) await requestResult.DisposeAsync(); + if (channel != null) await channel.DisposeAsync().VhConfigureAwait(); + if (requestResult != null) await requestResult.DisposeAsync().VhConfigureAwait(); orgTcpClient.Dispose(); VhLogger.LogError(GeneralEventId.StreamProxyChannel, ex, ""); } diff --git a/VpnHood.Client/ClientOptions.cs b/VpnHood.Client/ClientOptions.cs index 8f23013cb..13afd2b78 100644 --- a/VpnHood.Client/ClientOptions.cs +++ b/VpnHood.Client/ClientOptions.cs @@ -33,7 +33,7 @@ public class ClientOptions public bool IncludeLocalNetwork { get; set; } public IIpRangeProvider? IpRangeProvider { get; set; } public IAdProvider? AdProvider { get; set; } - public IpRange[] PacketCaptureIncludeIpRanges { get; set; } = IpNetwork.All.ToIpRanges().ToArray(); + public IpRangeOrderedList PacketCaptureIncludeIpRanges { get; set; } = new (IpNetwork.All.ToIpRanges()); public SocketFactory SocketFactory { get; set; } = new(); public int MaxDatagramChannelCount { get; set; } = 4; public string UserAgent { get; set; } = Environment.OSVersion.ToString(); diff --git a/VpnHood.Client/ClientUsageTracker.cs b/VpnHood.Client/ClientUsageTracker.cs index 6268f9552..99209066d 100644 --- a/VpnHood.Client/ClientUsageTracker.cs +++ b/VpnHood.Client/ClientUsageTracker.cs @@ -33,7 +33,7 @@ public Task RunJob() public async Task Report() { - using var lockAsync = await _reportLock.LockAsync(); + using var lockAsync = await _reportLock.LockAsync().VhConfigureAwait(); if (_disposed) throw new ObjectDisposedException(GetType().Name); @@ -56,7 +56,7 @@ public async Task Report() } }; - await _ga4Tracker.Track(tagEvent); + await _ga4Tracker.Track(tagEvent).VhConfigureAwait(); _lastTraffic = traffic; _lastRequestCount = requestCount; _lastConnectionCount = connectionCount; @@ -68,7 +68,7 @@ public async ValueTask DisposeAsync() { // Make sure no exception in dispose if (_clientStat.SessionTraffic - _lastTraffic != new Traffic()) - await Report(); + await Report().VhConfigureAwait(); } catch { diff --git a/VpnHood.Client/ConnectorServices/ConnectorService.cs b/VpnHood.Client/ConnectorServices/ConnectorService.cs index 9b67671a8..c81ab77b4 100644 --- a/VpnHood.Client/ConnectorServices/ConnectorService.cs +++ b/VpnHood.Client/ConnectorServices/ConnectorService.cs @@ -33,8 +33,8 @@ public async Task> SendRequest(ClientRequest reques await using var mem = new MemoryStream(); mem.WriteByte(1); mem.WriteByte(request.RequestCode); - await StreamUtil.WriteJsonAsync(mem, request, cancellationToken); - var ret = await SendRequest(mem.ToArray(), request.RequestId, cancellationToken); + await StreamUtil.WriteJsonAsync(mem, request, cancellationToken).VhConfigureAwait(); + var ret = await SendRequest(mem.ToArray(), request.RequestId, cancellationToken).VhConfigureAwait(); // log the response VhLogger.Instance.LogTrace(eventId, "Received a response... ErrorCode: {ErrorCode}.", ret.Response.ErrorCode); @@ -57,8 +57,8 @@ private async Task> SendRequest(byte[] request, str try { // we may use this buffer to encrypt so clone it for retry - await clientStream.Stream.WriteAsync((byte[])request.Clone(), cancellationToken); - var response = await ReadSessionResponse(clientStream.Stream, cancellationToken); + await clientStream.Stream.WriteAsync((byte[])request.Clone(), cancellationToken).VhConfigureAwait(); + var response = await ReadSessionResponse(clientStream.Stream, cancellationToken).VhConfigureAwait(); lock (Stat) Stat.ReusedConnectionSucceededCount++; return new ConnectorRequestResult { @@ -78,13 +78,13 @@ private async Task> SendRequest(byte[] request, str } // create free connection - clientStream = await GetTlsConnectionToServer(requestId, cancellationToken); + clientStream = await GetTlsConnectionToServer(requestId, cancellationToken).VhConfigureAwait(); // send request try { - await clientStream.Stream.WriteAsync(request, cancellationToken); - var response2 = await ReadSessionResponse(clientStream.Stream, cancellationToken); + await clientStream.Stream.WriteAsync(request, cancellationToken).VhConfigureAwait(); + var response2 = await ReadSessionResponse(clientStream.Stream, cancellationToken).VhConfigureAwait(); return new ConnectorRequestResult { Response = response2, @@ -100,7 +100,7 @@ private async Task> SendRequest(byte[] request, str private static async Task ReadSessionResponse(Stream stream, CancellationToken cancellationToken) where T : SessionResponse { - var message = await StreamUtil.ReadMessage(stream, cancellationToken); + var message = await StreamUtil.ReadMessage(stream, cancellationToken).VhConfigureAwait(); try { var response = VhUtil.JsonDeserialize(message); diff --git a/VpnHood.Client/ConnectorServices/ConnectorServiceBase.cs b/VpnHood.Client/ConnectorServices/ConnectorServiceBase.cs index b76cc8044..07f53434c 100644 --- a/VpnHood.Client/ConnectorServices/ConnectorServiceBase.cs +++ b/VpnHood.Client/ConnectorServices/ConnectorServiceBase.cs @@ -72,8 +72,8 @@ private async Task CreateClientStream(TcpClient tcpClient, Stream "\r\n"; // Send header and wait for its response - await sslStream.WriteAsync(Encoding.UTF8.GetBytes(header), cancellationToken); - await HttpUtil.ReadHeadersAsync(sslStream, cancellationToken); + await sslStream.WriteAsync(Encoding.UTF8.GetBytes(header), cancellationToken).VhConfigureAwait(); + await HttpUtil.ReadHeadersAsync(sslStream, cancellationToken).VhConfigureAwait(); return binaryStreamType == BinaryStreamType.None ? new TcpClientStream(tcpClient, sslStream, streamId) : new TcpClientStream(tcpClient, new BinaryStreamStandard(tcpClient.GetStream(), streamId, useBuffer), streamId, ReuseStreamClient); @@ -90,7 +90,7 @@ protected async Task GetTlsConnectionToServer(string streamId, Ca // Client.SessionTimeout does not affect in ConnectAsync VhLogger.Instance.LogTrace(GeneralEventId.Tcp, "Connecting to Server... EndPoint: {EndPoint}", VhLogger.Format(tcpEndPoint)); - await VhUtil.RunTask(tcpClient.ConnectAsync(tcpEndPoint.Address, tcpEndPoint.Port), TcpConnectTimeout, cancellationToken); + await VhUtil.RunTask(tcpClient.ConnectAsync(tcpEndPoint.Address, tcpEndPoint.Port), TcpConnectTimeout, cancellationToken).VhConfigureAwait(); // Establish a TLS connection var sslStream = new SslStream(tcpClient.GetStream(), true, UserCertificateValidationCallback); @@ -99,9 +99,10 @@ await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions { TargetHost = hostName, EnabledSslProtocols = SslProtocols.None // auto - }, cancellationToken); + }, cancellationToken) + .VhConfigureAwait(); - var clientStream = await CreateClientStream(tcpClient, sslStream, streamId, cancellationToken); + var clientStream = await CreateClientStream(tcpClient, sslStream, streamId, cancellationToken).VhConfigureAwait(); lock (Stat) Stat.CreatedConnectionCount++; return clientStream; } diff --git a/VpnHood.Client/Diagnosing/DiagnoseUtil.cs b/VpnHood.Client/Diagnosing/DiagnoseUtil.cs index c1d09189c..d801e73d6 100644 --- a/VpnHood.Client/Diagnosing/DiagnoseUtil.cs +++ b/VpnHood.Client/Diagnosing/DiagnoseUtil.cs @@ -33,7 +33,7 @@ public class DiagnoseUtil Exception? exception = null; while (tasks.Length > 0) { - var task = await Task.WhenAny(tasks); + var task = await Task.WhenAny(tasks).VhConfigureAwait(); exception = task.Result; if (task.Result == null) return null; //at least one task is success @@ -53,7 +53,7 @@ public class DiagnoseUtil using var httpClient = new HttpClient(); httpClient.Timeout = TimeSpan.FromMilliseconds(timeout); - var result = await httpClient.GetStringAsync(uri); + var result = await httpClient.GetStringAsync(uri).VhConfigureAwait(); if (result.Length < 100) throw new Exception("The http response data length is not expected!"); @@ -82,7 +82,7 @@ public class DiagnoseUtil "UdpTest: {UdpTestStatus}, DnsName: {DnsName}, NsServer: {NsServer}, Timeout: {Timeout}...", "Started", dnsName, nsIpEndPoint, timeout); - var res = await GetHostEntry(dnsName, nsIpEndPoint, udpClient, timeout); + var res = await GetHostEntry(dnsName, nsIpEndPoint, udpClient, timeout).VhConfigureAwait(); if (res.AddressList.Length == 0) throw new Exception("Could not find any host!"); @@ -113,7 +113,7 @@ public class DiagnoseUtil "PingTest: {PingTestStatus}, RemoteAddress: {RemoteAddress}, Timeout: {Timeout}...", "Started", logIpAddress, timeout); - var pingReply = await ping.SendPingAsync(ipAddress, timeout); + var pingReply = await ping.SendPingAsync(ipAddress, timeout).VhConfigureAwait(); if (pingReply.Status != IPStatus.Success) throw new Exception($"Status: {pingReply.Status}"); @@ -175,8 +175,8 @@ public static async Task GetHostEntry(string host, IPEndPoint dnsEn var buffer = ms.ToArray(); udpClient.Client.SendTimeout = timeout; udpClient.Client.ReceiveTimeout = timeout; - await udpClient.SendAsync(buffer, buffer.Length, dnsEndPoint); - var receiveTask = await VhUtil.RunTask(udpClient.ReceiveAsync(), TimeSpan.FromMilliseconds(timeout)); + await udpClient.SendAsync(buffer, buffer.Length, dnsEndPoint).VhConfigureAwait(); + var receiveTask = await VhUtil.RunTask(udpClient.ReceiveAsync(), TimeSpan.FromMilliseconds(timeout)).VhConfigureAwait(); buffer = receiveTask.Buffer; //The response message has the same header and question structure, so we move index to the answer part directly. diff --git a/VpnHood.Client/Diagnosing/Diagnoser.cs b/VpnHood.Client/Diagnosing/Diagnoser.cs index 95974eed4..fc1650ffa 100644 --- a/VpnHood.Client/Diagnosing/Diagnoser.cs +++ b/VpnHood.Client/Diagnosing/Diagnoser.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.Logging; using VpnHood.Client.Exceptions; using VpnHood.Common.Logging; +using VpnHood.Common.Utils; namespace VpnHood.Client.Diagnosing; @@ -34,7 +35,7 @@ public async Task Connect(VpnHoodClient vpnHoodClient, CancellationToken cancell { try { - await vpnHoodClient.Connect(cancellationToken); + await vpnHoodClient.Connect(cancellationToken).VhConfigureAwait(); } catch (OperationCanceledException) { @@ -44,7 +45,7 @@ public async Task Connect(VpnHoodClient vpnHoodClient, CancellationToken cancell { VhLogger.Instance.LogTrace("Checking the Internet connection..."); IsWorking = true; - if (!await NetworkCheck()) + if (!await NetworkCheck().VhConfigureAwait()) throw new NoInternetException(); throw; @@ -61,13 +62,13 @@ public async Task Diagnose(VpnHoodClient vpnHoodClient, CancellationToken cancel { VhLogger.Instance.LogTrace("Checking the Internet connection..."); IsWorking = true; - if (!await NetworkCheck()) + if (!await NetworkCheck().VhConfigureAwait()) throw new NoInternetException(); // ping server VhLogger.Instance.LogTrace("Checking the VpnServer ping..."); - var hostEndPoint = await ServerTokenHelper.ResolveHostEndPoint(vpnHoodClient.Token.ServerToken); - var pingRes = await DiagnoseUtil.CheckPing([hostEndPoint.Address], NsTimeout, true); + var hostEndPoint = await ServerTokenHelper.ResolveHostEndPoint(vpnHoodClient.Token.ServerToken).VhConfigureAwait(); + var pingRes = await DiagnoseUtil.CheckPing([hostEndPoint.Address], NsTimeout, true).VhConfigureAwait(); if (pingRes == null) VhLogger.Instance.LogTrace("Pinging server is OK."); else @@ -75,12 +76,12 @@ public async Task Diagnose(VpnHoodClient vpnHoodClient, CancellationToken cancel // VpnConnect IsWorking = false; - await vpnHoodClient.Connect(cancellationToken); + await vpnHoodClient.Connect(cancellationToken).VhConfigureAwait(); VhLogger.Instance.LogTrace("Checking the Vpn Connection..."); IsWorking = true; - await Task.Delay(2000, cancellationToken); // connections can not be established on android immediately - if (!await NetworkCheck()) + await Task.Delay(2000, cancellationToken).VhConfigureAwait(); // connections can not be established on android immediately + if (!await NetworkCheck().VhConfigureAwait()) throw new NoStableVpnException(); VhLogger.Instance.LogTrace("VPN has been established and tested successfully."); } @@ -102,7 +103,7 @@ private async Task NetworkCheck(bool checkPing = true, bool checkUdp = tru var taskHttps = DiagnoseUtil.CheckHttps(TestHttpUris, HttpTimeout); - await Task.WhenAll(taskPing, taskUdp, taskHttps); + await Task.WhenAll(taskPing, taskUdp, taskHttps).VhConfigureAwait(); var hasInternet = taskPing.Result == null && taskUdp.Result == null && taskHttps.Result == null; return hasInternet; } diff --git a/VpnHood.Client/ServerFinder.cs b/VpnHood.Client/ServerFinder.cs index 5a44d50c1..a65812216 100644 --- a/VpnHood.Client/ServerFinder.cs +++ b/VpnHood.Client/ServerFinder.cs @@ -35,7 +35,7 @@ public class ServerFinder(int maxDegreeOfParallelism = 10) }); // find endpoint status - HostEndPointStatus = await VerifyServersStatus(connectors, cancellationToken); + HostEndPointStatus = await VerifyServersStatus(connectors, cancellationToken).VhConfigureAwait(); return HostEndPointStatus.FirstOrDefault(x=>x.Value).Key; //todo check if it is null } @@ -51,12 +51,12 @@ private async Task> VerifyServersStatus(IE using var linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationTokenSource.Token, cancellationToken); await VhUtil.ParallelForEachAsync(connectors, async connector => { - var serverStatus = await VerifyServerStatus(connector, linkedCancellationTokenSource.Token); + var serverStatus = await VerifyServerStatus(connector, linkedCancellationTokenSource.Token).VhConfigureAwait(); hostEndPointStatus[connector.EndPointInfo.TcpEndPoint] = serverStatus; if (serverStatus) linkedCancellationTokenSource.Cancel(); // no need to continue, we find a server - }, maxDegreeOfParallelism, linkedCancellationTokenSource.Token); + }, maxDegreeOfParallelism, linkedCancellationTokenSource.Token).VhConfigureAwait(); } catch (OperationCanceledException) @@ -73,11 +73,13 @@ private static async Task VerifyServerStatus(ConnectorService connector, C try { var requestResult = await connector.SendRequest( - new ServerStatusRequest - { - RequestId = Guid.NewGuid().ToString(), - Message = "Hi, How are you?" - }, cancellationToken); + new ServerStatusRequest + { + RequestId = Guid.NewGuid().ToString(), + Message = "Hi, How are you?" + }, + cancellationToken) + .VhConfigureAwait(); // this should be already handled by the connector and never happen if (requestResult.Response.ErrorCode != SessionErrorCode.Ok) diff --git a/VpnHood.Client/ServerTokenHelper.cs b/VpnHood.Client/ServerTokenHelper.cs index 12942063b..b5ba3f38a 100644 --- a/VpnHood.Client/ServerTokenHelper.cs +++ b/VpnHood.Client/ServerTokenHelper.cs @@ -17,7 +17,7 @@ private static async Task ResolveHostEndPointsInternal(ServerToken try { VhLogger.Instance.LogInformation("Resolving IP from host name: {HostName}...", VhLogger.FormatHostName(serverToken.HostName)); - var hostEntities = await Dns.GetHostEntryAsync(serverToken.HostName); + var hostEntities = await Dns.GetHostEntryAsync(serverToken.HostName).VhConfigureAwait(); if (!VhUtil.IsNullOrEmpty(hostEntities.AddressList)) { return hostEntities.AddressList @@ -39,7 +39,7 @@ private static async Task ResolveHostEndPointsInternal(ServerToken public static async Task ResolveHostEndPoints(ServerToken serverToken) { - var endPoints = await ResolveHostEndPointsInternal(serverToken); + var endPoints = await ResolveHostEndPointsInternal(serverToken).VhConfigureAwait(); if (VhUtil.IsNullOrEmpty(endPoints)) throw new Exception("Could not resolve any host endpoint from AccessToken!"); @@ -48,13 +48,13 @@ public static async Task ResolveHostEndPoints(ServerToken serverTo if (ipV6EndPoints.Length == 0) return ipV4EndPoints; if (ipV4EndPoints.Length == 0) return ipV6EndPoints; - var publicAddressesIpV6 = await IPAddressUtil.GetPublicIpAddress(AddressFamily.InterNetworkV6); + var publicAddressesIpV6 = await IPAddressUtil.GetPublicIpAddress(AddressFamily.InterNetworkV6).VhConfigureAwait(); return publicAddressesIpV6 != null ? ipV6EndPoints : ipV4EndPoints; //return IPv6 if user has access to IpV6 } public static async Task ResolveHostEndPoint(ServerToken serverToken) { - var endPoints = await ResolveHostEndPoints(serverToken); + var endPoints = await ResolveHostEndPoints(serverToken).VhConfigureAwait(); if (VhUtil.IsNullOrEmpty(endPoints)) throw new Exception("Could not resolve any host endpoint!"); diff --git a/VpnHood.Client/VpnHood.Client.csproj b/VpnHood.Client/VpnHood.Client.csproj index 7e9ac153a..b8647895b 100644 --- a/VpnHood.Client/VpnHood.Client.csproj +++ b/VpnHood.Client/VpnHood.Client.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Client/VpnHoodClient.cs b/VpnHood.Client/VpnHoodClient.cs index 381c4f868..01a52d216 100644 --- a/VpnHood.Client/VpnHoodClient.cs +++ b/VpnHood.Client/VpnHoodClient.cs @@ -31,7 +31,7 @@ public class VpnHoodClient : IDisposable, IAsyncDisposable private readonly Dictionary _includeIps = new(); private readonly int _maxDatagramChannelCount; private readonly IPacketCapture _packetCapture; - private readonly SendingPackets _sendingPacket = new(); + private readonly SendingPackets _sendingPackets = new(); private readonly ClientHost _clientHost; private readonly SemaphoreSlim _datagramChannelsSemaphore = new(1, 1); private readonly IIpRangeProvider? _ipRangeProvider; @@ -75,8 +75,8 @@ public class VpnHoodClient : IDisposable, IAsyncDisposable public SessionStatus SessionStatus { get; private set; } = new(); public Version Version { get; } public bool IncludeLocalNetwork { get; } - public IpRange[] IncludeIpRanges { get; private set; } = IpNetwork.All.ToIpRanges().ToArray(); - public IpRange[] PacketCaptureIncludeIpRanges { get; private set; } + public IpRangeOrderedList IncludeIpRanges { get; private set; } = new(IpNetwork.All.ToIpRanges()); + public IpRangeOrderedList PacketCaptureIncludeIpRanges { get; private set; } public string UserAgent { get; } public IPEndPoint? HostTcpEndPoint => _connectorService?.EndPointInfo.TcpEndPoint; public IPEndPoint? HostUdpEndPoint { get; private set; } @@ -152,7 +152,7 @@ public VpnHoodClient(IPacketCapture packetCapture, Guid clientId, Token token, C public IPAddress[] DnsServers { get => _dnsServers; - set + private set { _dnsServersIpV4 = value.Where(x => x.AddressFamily == AddressFamily.InterNetwork).ToArray(); _dnsServersIpV6 = value.Where(x => x.AddressFamily == AddressFamily.InterNetworkV6).ToArray(); @@ -199,14 +199,14 @@ internal async Task AddPassthruTcpStream(IClientStream orgTcpClientStream, IPEnd // connect to host var tcpClient = SocketFactory.CreateTcpClient(hostEndPoint.AddressFamily); - await VhUtil.RunTask(tcpClient.ConnectAsync(hostEndPoint.Address, hostEndPoint.Port), cancellationToken: cancellationToken); + await VhUtil.RunTask(tcpClient.ConnectAsync(hostEndPoint.Address, hostEndPoint.Port), cancellationToken: cancellationToken).VhConfigureAwait(); // create and add the channel var bypassChannel = new StreamProxyChannel(channelId, orgTcpClientStream, new TcpClientStream(tcpClient, tcpClient.GetStream(), channelId + ":host")); try { _proxyManager.AddChannel(bypassChannel); } - catch { await bypassChannel.DisposeAsync(); throw; } + catch { await bypassChannel.DisposeAsync().VhConfigureAwait(); throw; } } public async Task Connect(CancellationToken cancellationToken = default) @@ -240,14 +240,14 @@ public async Task Connect(CancellationToken cancellationToken = default) var endPointInfo = new ConnectorEndPointInfo { HostName = Token.ServerToken.HostName, - TcpEndPoint = await ServerTokenHelper.ResolveHostEndPoint(Token.ServerToken), + TcpEndPoint = await ServerTokenHelper.ResolveHostEndPoint(Token.ServerToken).VhConfigureAwait(), CertificateHash = Token.ServerToken.CertificateHash }; _connectorService = new ConnectorService(endPointInfo, SocketFactory, _tcpConnectTimeout); // Establish first connection and create a session using var linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(_cancellationTokenSource.Token, cancellationToken); - await ConnectInternal(linkedCancellationTokenSource.Token); + await ConnectInternal(linkedCancellationTokenSource.Token).VhConfigureAwait(); // Create Tcp Proxy Host _clientHost.Start(); @@ -260,7 +260,7 @@ public async Task Connect(CancellationToken cancellationToken = default) _packetCapture.StartCapture(); // disable IncludeIpRanges if it contains all networks - if (IncludeIpRanges.ToIpNetworks().IsAll()) + if (IncludeIpRanges.IsAll()) IncludeIpRanges = []; State = ClientState.Connected; @@ -287,24 +287,24 @@ private void ConfigPacketFilter(IPAddress hostIpAddress) _packetCapture.AddIpV6Address = true; //lets block ipV6 if not supported // Start with user PacketCaptureIncludeIpRanges - var includeIpRanges = (IEnumerable)PacketCaptureIncludeIpRanges; + var includeIpRanges = PacketCaptureIncludeIpRanges; // exclude server if ProtectSocket is not supported to prevent loop if (!_packetCapture.CanProtectSocket) - includeIpRanges = includeIpRanges.Exclude(new[] { new IpRange(hostIpAddress) }); + includeIpRanges = includeIpRanges.Exclude(hostIpAddress); // exclude local networks if (!IncludeLocalNetwork) includeIpRanges = includeIpRanges.Exclude(IpNetwork.LocalNetworks.ToIpRanges()); // Make sure CatcherAddress is included - includeIpRanges = includeIpRanges.Concat(new[] + includeIpRanges = includeIpRanges.Union(new[] { new IpRange(_clientHost.CatcherAddressIpV4), new IpRange(_clientHost.CatcherAddressIpV6) }); - _packetCapture.IncludeNetworks = includeIpRanges.Sort().ToIpNetworks().ToArray(); //sort and unify + _packetCapture.IncludeNetworks = includeIpRanges.ToIpNetworks().ToArray(); //sort and unify VhLogger.Instance.LogInformation($"PacketCapture Include Networks: {string.Join(", ", _packetCapture.IncludeNetworks.Select(x => x.ToString()))}"); } @@ -328,14 +328,14 @@ private void PacketCapture_OnPacketReceivedFromInbound(object sender, PacketRece try { - lock (_sendingPacket) // this method should not be called in multi-thread, if so we need to allocate the list per call + lock (_sendingPackets) // this method should not be called in multi-thread, if so we need to allocate the list per call { - _sendingPacket.Clear(); // prevent reallocation in this intensive event - var tunnelPackets = _sendingPacket.TunnelPackets; - var tcpHostPackets = _sendingPacket.TcpHostPackets; - var passthruPackets = _sendingPacket.PassthruPackets; - var proxyPackets = _sendingPacket.ProxyPackets; - var droppedPackets = _sendingPacket.DroppedPackets; + _sendingPackets.Clear(); // prevent reallocation in this intensive event + var tunnelPackets = _sendingPackets.TunnelPackets; + var tcpHostPackets = _sendingPackets.TcpHostPackets; + var passthruPackets = _sendingPackets.PassthruPackets; + var proxyPackets = _sendingPackets.ProxyPackets; + var droppedPackets = _sendingPackets.DroppedPackets; foreach (var ipPacket in e.IpPackets) { if (_disposed) return; @@ -420,23 +420,22 @@ private void PacketCapture_OnPacketReceivedFromInbound(object sender, PacketRece if (_autoWaitTime != null) { if (FastDateTime.Now - _autoWaitTime.Value < AutoWaitTimeout) - { tunnelPackets.Clear(); - } else - { - State = ClientState.Connecting; _autoWaitTime = null; - } } // send packets if (tunnelPackets.Count > 0 && ShouldManageDatagramChannels) _ = ManageDatagramChannels(_cancellationTokenSource.Token); - if (tunnelPackets.Count > 0) Tunnel.SendPackets(tunnelPackets).Wait(_cancellationTokenSource.Token); - if (passthruPackets.Count > 0) _packetCapture.SendPacketToOutbound(passthruPackets.ToArray()); + if (tunnelPackets.Count > 0) Tunnel.SendPackets(tunnelPackets, _cancellationTokenSource.Token); + if (passthruPackets.Count > 0) _packetCapture.SendPacketToOutbound(passthruPackets); if (proxyPackets.Count > 0) _proxyManager.SendPackets(proxyPackets).Wait(_cancellationTokenSource.Token); if (tcpHostPackets.Count > 0) _packetCapture.SendPacketToInbound(_clientHost.ProcessOutgoingPacket(tcpHostPackets)); } + + // set state outside the lock as it may raise an event + if (_autoWaitTime == null && State == ClientState.Waiting) + State = ClientState.Connecting; } catch (Exception ex) { @@ -479,7 +478,7 @@ private static bool IsIcmpControlMessage(IPPacket ipPacket) public bool IsInIpRange(IPAddress ipAddress) { // all IPs are included if there is no filter - if (VhUtil.IsNullOrEmpty(IncludeIpRanges)) + if (IncludeIpRanges.Count == 0) return true; // check tcp-loopback @@ -492,7 +491,7 @@ public bool IsInIpRange(IPAddress ipAddress) return isInRange; // check include - isInRange = IpRange.IsInSortedRanges(IncludeIpRanges, ipAddress); + isInRange = IncludeIpRanges.IsInRange(ipAddress); // cache the result // we really don't need to keep that much ips in the cache @@ -562,7 +561,7 @@ private bool UpdateDnsRequest(IPPacket ipPacket, bool outgoing) private async Task ManageDatagramChannels(CancellationToken cancellationToken) { - if (_disposed || !await _datagramChannelsSemaphore.WaitAsync(0, cancellationToken)) + if (_disposed || !await _datagramChannelsSemaphore.WaitAsync(0, cancellationToken).VhConfigureAwait()) return; if (!ShouldManageDatagramChannels) @@ -572,9 +571,9 @@ private async Task ManageDatagramChannels(CancellationToken cancellationToken) { // make sure only one UdpChannel exists for DatagramChannels if UseUdpChannel is on if (UseUdpChannel) - await AddUdpChannel(); + await AddUdpChannel().VhConfigureAwait(); else - await AddTcpDatagramChannel(cancellationToken); + await AddTcpDatagramChannel(cancellationToken).VhConfigureAwait(); } catch (Exception ex) { @@ -609,7 +608,7 @@ private async Task AddUdpChannel() catch { udpClient.Dispose(); - await udpChannel.DisposeAsync(); + await udpChannel.DisposeAsync().VhConfigureAwait(); UseUdpChannel = false; throw; } @@ -638,7 +637,7 @@ private async Task ConnectInternal(CancellationToken cancellationToken, bool all AllowRedirect = allowRedirect }; - await using var requestResult = await SendRequest(request, cancellationToken); + await using var requestResult = await SendRequest(request, cancellationToken).VhConfigureAwait(); var sessionResponse = requestResult.Response; if (sessionResponse.ServerProtocolVersion < 4) throw new SessionException(SessionErrorCode.UnsupportedServer, "This server is outdated and does not support this client!"); @@ -707,40 +706,36 @@ private async Task ConnectInternal(CancellationToken cancellationToken, bool all // PacketCaptureIpRanges if (!VhUtil.IsNullOrEmpty(sessionResponse.PacketCaptureIncludeIpRanges)) - PacketCaptureIncludeIpRanges = PacketCaptureIncludeIpRanges.Intersect(sessionResponse.PacketCaptureIncludeIpRanges).ToArray(); + PacketCaptureIncludeIpRanges = PacketCaptureIncludeIpRanges.Intersect(sessionResponse.PacketCaptureIncludeIpRanges); // IncludeIpRanges - if (!VhUtil.IsNullOrEmpty(sessionResponse.IncludeIpRanges) && !sessionResponse.IncludeIpRanges.ToIpNetworks().IsAll()) - IncludeIpRanges = IncludeIpRanges.Intersect(sessionResponse.IncludeIpRanges).ToArray(); + if (!VhUtil.IsNullOrEmpty(sessionResponse.IncludeIpRanges) && !sessionResponse.IncludeIpRanges.ToOrderedList().IsAll()) + IncludeIpRanges = IncludeIpRanges.Intersect(sessionResponse.IncludeIpRanges); // Get IncludeIpRange for clientIp - var filterIpRanges = _ipRangeProvider != null ? await _ipRangeProvider.GetIncludeIpRanges(sessionResponse.ClientPublicAddress) : null; + var filterIpRanges = _ipRangeProvider != null ? await _ipRangeProvider.GetIncludeIpRanges(sessionResponse.ClientPublicAddress).VhConfigureAwait() : null; if (!VhUtil.IsNullOrEmpty(filterIpRanges)) { - filterIpRanges = filterIpRanges.Concat(DnsServers.Select((x => new IpRange(x)))).ToArray(); - IncludeIpRanges = IncludeIpRanges.Intersect(filterIpRanges).ToArray(); + filterIpRanges = filterIpRanges.Union(DnsServers.Select((x => new IpRange(x)))); + IncludeIpRanges = IncludeIpRanges.Intersect(filterIpRanges); } // set DNS after setting IpFilters + VhLogger.Instance.LogInformation("Configuring Client DNS servers... DnsServers: {DnsServers}", string.Join(", ", DnsServers.Select(x => x.ToString()))); Stat.IsDnsServersAccepted = VhUtil.IsNullOrEmpty(DnsServers) || DnsServers.Any(IsInIpRange); // no servers means accept default - DnsServers = DnsServers.Where(IsInIpRange).ToArray(); if (!Stat.IsDnsServersAccepted) VhLogger.Instance.LogWarning("Client DNS servers have been ignored because the server does not route them."); + DnsServers = DnsServers.Where(IsInIpRange).ToArray(); if (VhUtil.IsNullOrEmpty(DnsServers)) { DnsServers = VhUtil.IsNullOrEmpty(sessionResponse.DnsServers) ? IPAddressUtil.GoogleDnsServers : sessionResponse.DnsServers; - IncludeIpRanges = IncludeIpRanges.Concat(DnsServers.Select(x => new IpRange(x))).Sort().ToArray(); + IncludeIpRanges = IncludeIpRanges.Union(DnsServers.Select(IpRange.FromIpAddress)); } if (VhUtil.IsNullOrEmpty(DnsServers?.Where(IsInIpRange).ToArray())) // make sure there is at least one DNS server throw new Exception("Could not specify any DNS server. The server is not configured properly."); - // Preparing tunnel - Tunnel.MaxDatagramChannelCount = sessionResponse.MaxDatagramChannelCount != 0 - ? Tunnel.MaxDatagramChannelCount = Math.Min(_maxDatagramChannelCount, sessionResponse.MaxDatagramChannelCount) - : _maxDatagramChannelCount; - // report Suppressed if (sessionResponse.SuppressedTo == SessionSuppressType.YourSelf) VhLogger.Instance.LogWarning("You suppressed a session of yourself!"); @@ -750,17 +745,23 @@ private async Task ConnectInternal(CancellationToken cancellationToken, bool all // show ad if required if (sessionResponse.IsAdRequired || sessionResponse.AdRequirement is not AdRequirement.None) - await ShowAd(sessionResponse.AdRequirement is AdRequirement.Flexible, cancellationToken); + await ShowAd(sessionResponse.AdRequirement is AdRequirement.Flexible, cancellationToken).VhConfigureAwait(); + + // Preparing tunnel + VhLogger.Instance.LogInformation("Configuring Datagram Channels..."); + Tunnel.MaxDatagramChannelCount = sessionResponse.MaxDatagramChannelCount != 0 + ? Tunnel.MaxDatagramChannelCount = Math.Min(_maxDatagramChannelCount, sessionResponse.MaxDatagramChannelCount) + : _maxDatagramChannelCount; // manage datagram channels - await ManageDatagramChannels(cancellationToken); + await ManageDatagramChannels(cancellationToken).VhConfigureAwait(); } catch (RedirectHostException ex) when (allowRedirect) { // todo; init new connector ConnectorService.EndPointInfo.TcpEndPoint = ex.RedirectHostEndPoint; - await ConnectInternal(cancellationToken, false); + await ConnectInternal(cancellationToken, false).VhConfigureAwait(); } } @@ -774,7 +775,7 @@ private async Task AddTcpDatagramChannel(CancellationToken cancellationToken) SessionKey = SessionKey }; - var requestResult = await SendRequest(request, cancellationToken); + var requestResult = await SendRequest(request, cancellationToken).VhConfigureAwait(); StreamDatagramChannel? channel = null; try { @@ -789,8 +790,8 @@ private async Task AddTcpDatagramChannel(CancellationToken cancellationToken) } catch { - if (channel != null) await channel.DisposeAsync(); - await requestResult.DisposeAsync(); + if (channel != null) await channel.DisposeAsync().VhConfigureAwait(); + await requestResult.DisposeAsync().VhConfigureAwait(); throw; } } @@ -804,7 +805,7 @@ internal async Task> SendRequest(ClientRequest requ try { // create a connection and send the request - var requestResult = await ConnectorService.SendRequest(request, cancellationToken); + var requestResult = await ConnectorService.SendRequest(request, cancellationToken).VhConfigureAwait(); // set SessionStatus if (requestResult.Response.AccessUsage != null) @@ -824,7 +825,7 @@ internal async Task> SendRequest(ClientRequest requ if (ex.SessionResponse.AccessUsage != null) SessionStatus.AccessUsage = ex.SessionResponse.AccessUsage; - // GeneralError and RedirectHost mean that the request accepted by server but there is an error for that request + // SessionException means that the request accepted by server but there is an error for that request _lastConnectionErrorTime = null; // close session if server has ended the session @@ -856,6 +857,7 @@ internal async Task> SendRequest(ClientRequest requ { _autoWaitTime = now; State = ClientState.Waiting; + VhLogger.Instance.LogWarning("Client is paused because of too many connection errors."); } // set connecting state if it could not establish any connection @@ -878,7 +880,8 @@ public async Task UpdateSessionStatus(CancellationToken cancellationToken = defa SessionId = SessionId, SessionKey = SessionKey }, - linkedCancellationTokenSource.Token); + linkedCancellationTokenSource.Token) + .VhConfigureAwait(); } private async Task SendByeRequest(CancellationToken cancellationToken) @@ -893,7 +896,8 @@ private async Task SendByeRequest(CancellationToken cancellationToken) SessionId = SessionId, SessionKey = SessionKey }, - cancellationToken); + cancellationToken) + .VhConfigureAwait(); } catch (Exception ex) { @@ -912,7 +916,7 @@ private async Task ShowAd(bool flexible, CancellationToken cancellationToken) throw new Exception("AppAdService has not been initialized."); _isWaitingForAd = true; - var adData = await _adProvider.ShowAd(SessionId.ToString(), cancellationToken); + var adData = await _adProvider.ShowAd(SessionId.ToString(), cancellationToken).VhConfigureAwait(); _ = SendAdReward(adData, cancellationToken); } catch (AdLoadException ex) when (flexible) @@ -948,7 +952,8 @@ private async Task SendAdReward(string adData, CancellationToken cancellationTok SessionKey = SessionKey, AdData = adData }, - cancellationToken); + cancellationToken) + .VhConfigureAwait(); } catch (Exception ex) { @@ -1002,7 +1007,7 @@ public ValueTask DisposeAsync() private readonly AsyncLock _disposeLock = new(); public async ValueTask DisposeAsync(bool waitForBye) { - using var lockResult = await _disposeLock.LockAsync(); + using var lockResult = await _disposeLock.LockAsync().VhConfigureAwait(); if (_disposed) return; _disposed = true; @@ -1031,24 +1036,24 @@ public async ValueTask DisposeAsync(bool waitForBye) var finalizeTask = Finalize(wasConnected); if (waitForBye) - await finalizeTask; + await finalizeTask.VhConfigureAwait(); } private async Task Finalize(bool wasConnected) { // Anonymous usage tracker - _clientUsageTracker?.DisposeAsync(); + _ = _clientUsageTracker?.DisposeAsync(); VhLogger.Instance.LogTrace("Disposing ClientHost..."); - await _clientHost.DisposeAsync(); + await _clientHost.DisposeAsync().VhConfigureAwait(); // Tunnel VhLogger.Instance.LogTrace("Disposing Tunnel..."); Tunnel.PacketReceived -= Tunnel_OnPacketReceived; - await Tunnel.DisposeAsync(); + await Tunnel.DisposeAsync().VhConfigureAwait(); VhLogger.Instance.LogTrace("Disposing ProxyManager..."); - await _proxyManager.DisposeAsync(); + await _proxyManager.DisposeAsync().VhConfigureAwait(); // dispose NAT VhLogger.Instance.LogTrace("Disposing Nat..."); @@ -1058,12 +1063,12 @@ private async Task Finalize(bool wasConnected) if (wasConnected && SessionId != 0 && SessionStatus.ErrorCode == SessionErrorCode.Ok) { using var cancellationTokenSource = new CancellationTokenSource(TunnelDefaults.TcpGracefulTimeout); - await SendByeRequest(cancellationTokenSource.Token); + await SendByeRequest(cancellationTokenSource.Token).VhConfigureAwait(); } // dispose ConnectorService VhLogger.Instance.LogTrace("Disposing ConnectorService..."); - await ConnectorService.DisposeAsync(); + await ConnectorService.DisposeAsync().VhConfigureAwait(); State = ClientState.Disposed; VhLogger.Instance.LogInformation("Bye Bye!"); @@ -1099,7 +1104,7 @@ public class ClientStat public bool IsUdpChannelSupported => _client.HostUdpEndPoint != null; public bool IsWaitingForAd => _client._isWaitingForAd; public bool IsDnsServersAccepted { get; internal set; } - public ServerLocationInfo? ServerLocationInfo{ get; internal set; } + public ServerLocationInfo? ServerLocationInfo { get; internal set; } internal ClientStat(VpnHoodClient vpnHoodClient) { diff --git a/VpnHood.Common/Client/ApiClientBase.cs b/VpnHood.Common/Client/ApiClientBase.cs index 5489730c1..61b698cb3 100644 --- a/VpnHood.Common/Client/ApiClientBase.cs +++ b/VpnHood.Common/Client/ApiClientBase.cs @@ -6,6 +6,7 @@ using System.Text.Json; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using VpnHood.Common.Utils; // ReSharper disable UnusedMember.Global namespace VpnHood.Common.Client; @@ -51,7 +52,7 @@ protected virtual JsonSerializerOptions CreateSerializerSettings() { if (ReadResponseAsString) { - var responseText = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + var responseText = await response.Content.ReadAsStringAsync().VhConfigureAwait(); try { var typedBody = JsonSerializer.Deserialize(responseText, JsonSerializerSettings); @@ -66,8 +67,8 @@ protected virtual JsonSerializerOptions CreateSerializerSettings() try { - await using var responseStream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false); - var typedBody = await JsonSerializer.DeserializeAsync(responseStream, JsonSerializerSettings, cancellationToken).ConfigureAwait(false); + await using var responseStream = await response.Content.ReadAsStreamAsync().VhConfigureAwait(); + var typedBody = await JsonSerializer.DeserializeAsync(responseStream, JsonSerializerSettings, cancellationToken).VhConfigureAwait(); return new HttpResult { ResponseMessage = response, Object = typedBody, Text= string.Empty }; } catch (JsonException exception) @@ -123,14 +124,14 @@ protected string ConvertToString(object? value, CultureInfo cultureInfo) protected async Task HttpSendAsync(HttpMethod httpMethod, string urlPart, Dictionary? parameters = null, object? data = null, CancellationToken cancellationToken = default) { - var res = await HttpSendExAsync(httpMethod, urlPart, parameters, data, cancellationToken); + var res = await HttpSendExAsync(httpMethod, urlPart, parameters, data, cancellationToken).VhConfigureAwait(); return res.Text; } protected async Task HttpSendAsync(HttpMethod httpMethod, string urlPart, Dictionary? parameters = null, object? data = null, CancellationToken cancellationToken = default) { - var res = await HttpSendExAsync(httpMethod, urlPart, parameters, data, cancellationToken); + var res = await HttpSendExAsync(httpMethod, urlPart, parameters, data, cancellationToken).VhConfigureAwait(); return res.Object; } @@ -149,13 +150,14 @@ protected async Task> HttpSendExAsync(HttpMethod httpMethod, st request.Content = content; } - return await HttpSendAsync(urlPart, parameters, request, cancellationToken); + // don't return Task as request will be disposed + return await HttpSendAsync(urlPart, parameters, request, cancellationToken).VhConfigureAwait(); } protected async Task HttpSendAsync(string urlPart, Dictionary? parameters, HttpRequestMessage request, CancellationToken cancellationToken) { - var res = await HttpSendAsync(urlPart, parameters, request, cancellationToken); + var res = await HttpSendAsync(urlPart, parameters, request, cancellationToken).VhConfigureAwait(); return res.Text; } @@ -164,7 +166,7 @@ protected virtual async Task> HttpSendAsync(string urlPart, Dic { try { - var ret = await HttpSendAsyncImpl(urlPart, parameters, request, cancellationToken); + var ret = await HttpSendAsyncImpl(urlPart, parameters, request, cancellationToken).VhConfigureAwait(); // report the log Logger.LogInformation(LoggerEventId, "API Called. Method: {Method}, Uri: {RequestUri} => StatusCode: {StatusCode}.", @@ -205,10 +207,10 @@ private async Task> HttpSendAsyncImpl(string urlPart, Dictionar } var client = HttpClient ?? throw new Exception("HttpClient has not been set."); - await PrepareRequestAsync(client, request, urlBuilder, cancellationToken).ConfigureAwait(false); + await PrepareRequestAsync(client, request, urlBuilder, cancellationToken).VhConfigureAwait(); - using var response = await HttpClientSendAsync(client, request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + using var response = await HttpClientSendAsync(client, request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).VhConfigureAwait(); var headers = response.Headers.ToDictionary(h => h.Key, h => h.Value); // ReSharper disable once ConditionalAccessQualifierIsNonNullableAccordingToAPIContract @@ -218,7 +220,7 @@ private async Task> HttpSendAsyncImpl(string urlPart, Dictionar headers[item.Key] = item.Value; } - await ProcessResponseAsync(client, response, cancellationToken).ConfigureAwait(false); + await ProcessResponseAsync(client, response, cancellationToken).VhConfigureAwait(); var status = (int)response.StatusCode; if (status is >= 200 and < 300) @@ -226,14 +228,14 @@ private async Task> HttpSendAsyncImpl(string urlPart, Dictionar if (typeof(T) == typeof(HttpNoResult)) return new HttpResult { ResponseMessage = response, Object = default!, Text= string.Empty }; - var objectResponse = await ReadObjectResponseAsync(response, headers, cancellationToken).ConfigureAwait(false); + var objectResponse = await ReadObjectResponseAsync(response, headers, cancellationToken).VhConfigureAwait(); if (objectResponse.Object == null) throw new ApiException("Response was null which was not expected.", status, objectResponse.Text, headers, null); return objectResponse!; } - var responseData = response.Content != null ? await response.Content.ReadAsStringAsync().ConfigureAwait(false) : null; + var responseData = response.Content != null ? await response.Content.ReadAsStringAsync().VhConfigureAwait() : null; throw new ApiException("The HTTP status code of the response was not expected (" + status + ").", status, responseData, headers, null); } @@ -284,9 +286,9 @@ protected Task HttpPatchAsync(string urlPart, Dictionary? param return HttpSendAsync(HttpMethod.Patch, urlPart, parameters, data, cancellationToken); } - protected async Task HttpDeleteAsync(string urlPart, + protected Task HttpDeleteAsync(string urlPart, Dictionary? parameters = null, CancellationToken cancellationToken = default) { - await HttpSendAsync(HttpMethod.Delete, urlPart, parameters, null, cancellationToken); + return HttpSendAsync(HttpMethod.Delete, urlPart, parameters, null, cancellationToken); } } \ No newline at end of file diff --git a/VpnHood.Common/Collections/TaskCollection.cs b/VpnHood.Common/Collections/TaskCollection.cs index 0b6f8c82c..cfc076f8b 100644 --- a/VpnHood.Common/Collections/TaskCollection.cs +++ b/VpnHood.Common/Collections/TaskCollection.cs @@ -1,5 +1,6 @@ using System.Collections.Concurrent; using VpnHood.Common.Jobs; +using VpnHood.Common.Utils; namespace VpnHood.Common.Collections; @@ -23,7 +24,7 @@ public void Add(ValueTask valueTask) } public async ValueTask DisposeAsync() { - await Task.WhenAll(_tasks.Keys); + await Task.WhenAll(_tasks.Keys).VhConfigureAwait(); } public Task RunJob() diff --git a/VpnHood.Common/Ga4Tracking/Ga4Tracker.cs b/VpnHood.Common/Ga4Tracking/Ga4Tracker.cs index b1fe67771..05cc3fb8b 100644 --- a/VpnHood.Common/Ga4Tracking/Ga4Tracker.cs +++ b/VpnHood.Common/Ga4Tracking/Ga4Tracker.cs @@ -3,6 +3,7 @@ using System.Runtime.InteropServices; using System.Text.Json; using System.Text.RegularExpressions; +using VpnHood.Common.Utils; // ReSharper disable once CheckNamespace namespace Ga4.Ga4Tracking; @@ -206,23 +207,25 @@ private async Task SendHttpRequest(HttpRequestMessage requestMessage, string nam { if (IsLogEnabled) { - await Console.Out.WriteLineAsync($"* Sending {name}..."); - await Console.Out.WriteLineAsync($"Url: {requestMessage.RequestUri}"); - await Console.Out.WriteLineAsync($"Headers: {JsonSerializer.Serialize(requestMessage.Headers, new JsonSerializerOptions { WriteIndented = true })}"); + await Console.Out.WriteLineAsync($"* Sending {name}...").VhConfigureAwait(); + await Console.Out.WriteLineAsync($"Url: {requestMessage.RequestUri}").VhConfigureAwait(); + await Console.Out.WriteLineAsync( + $"Headers: {JsonSerializer.Serialize(requestMessage.Headers, new JsonSerializerOptions { WriteIndented = true })}").VhConfigureAwait(); } if (jsonData != null) { requestMessage.Content = new StringContent(JsonSerializer.Serialize(jsonData)); requestMessage.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json"); - await Console.Out.WriteLineAsync($"Data: {JsonSerializer.Serialize(jsonData, new JsonSerializerOptions { WriteIndented = true })}"); + await Console.Out.WriteLineAsync( + $"Data: {JsonSerializer.Serialize(jsonData, new JsonSerializerOptions { WriteIndented = true })}").VhConfigureAwait(); } - var res = await HttpClient.SendAsync(requestMessage); + var res = await HttpClient.SendAsync(requestMessage).VhConfigureAwait(); if (IsLogEnabled) { - await Console.Out.WriteLineAsync("Result: "); - await Console.Out.WriteLineAsync(await res.Content.ReadAsStringAsync()); + await Console.Out.WriteLineAsync("Result: ").VhConfigureAwait(); + await Console.Out.WriteLineAsync(await res.Content.ReadAsStringAsync()).VhConfigureAwait(); } } catch diff --git a/VpnHood.Common/IpLocations/Providers/IpApiCoLocationProvider.cs b/VpnHood.Common/IpLocations/Providers/IpApiCoLocationProvider.cs index 550ad983c..b16b10c5e 100644 --- a/VpnHood.Common/IpLocations/Providers/IpApiCoLocationProvider.cs +++ b/VpnHood.Common/IpLocations/Providers/IpApiCoLocationProvider.cs @@ -49,9 +49,9 @@ private static async Task GetLocation(HttpClient httpClient, Uri url // get json from the service provider var requestMessage = new HttpRequestMessage(HttpMethod.Get, url); requestMessage.Headers.Add("User-Agent", userAgent); - var responseMessage = await httpClient.SendAsync(requestMessage); + var responseMessage = await httpClient.SendAsync(requestMessage).VhConfigureAwait(); responseMessage.EnsureSuccessStatusCode(); - var json = await responseMessage.Content.ReadAsStringAsync(); + var json = await responseMessage.Content.ReadAsStringAsync().VhConfigureAwait(); var apiLocation = VhUtil.JsonDeserialize(json); var ipLocation = new IpLocation diff --git a/VpnHood.Common/Jobs/JobRunner.cs b/VpnHood.Common/Jobs/JobRunner.cs index 6524eb51f..70dbf067b 100644 --- a/VpnHood.Common/Jobs/JobRunner.cs +++ b/VpnHood.Common/Jobs/JobRunner.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using VpnHood.Common.Utils; namespace VpnHood.Common.Jobs; @@ -87,7 +88,7 @@ private async Task RunJobs(IEnumerable jobs) // run jobs foreach (var job in jobs) { - await _semaphore.WaitAsync(); + await _semaphore.WaitAsync().VhConfigureAwait(); _ = RunJob(job); } } @@ -104,7 +105,7 @@ private async Task RunJob(IJob job) // run the job try { - await job.RunJob(); + await job.RunJob().VhConfigureAwait(); } catch (ObjectDisposedException) { diff --git a/VpnHood.Common/Jobs/JobSection.cs b/VpnHood.Common/Jobs/JobSection.cs index d0ffcb5fd..32805e6d7 100644 --- a/VpnHood.Common/Jobs/JobSection.cs +++ b/VpnHood.Common/Jobs/JobSection.cs @@ -51,7 +51,7 @@ public async Task Enter(Func action, bool force = false) if (!jobLock.IsEntered) return false; - await action(); + await action().VhConfigureAwait(); return true; } diff --git a/VpnHood.Common/Net/IPAddressUtil.cs b/VpnHood.Common/Net/IPAddressUtil.cs index 28d04673b..97c2d9656 100644 --- a/VpnHood.Common/Net/IPAddressUtil.cs +++ b/VpnHood.Common/Net/IPAddressUtil.cs @@ -3,6 +3,7 @@ using System.Net.Sockets; using System.Numerics; using System.Text.Json; +using VpnHood.Common.Utils; namespace VpnHood.Common.Net; @@ -23,7 +24,7 @@ public static async Task GetPrivateIpAddresses() var ipV4Task = GetPrivateIpAddress(AddressFamily.InterNetwork); var ipV6Task = GetPrivateIpAddress(AddressFamily.InterNetworkV6); - await Task.WhenAll(ipV4Task, ipV6Task); + await Task.WhenAll(ipV4Task, ipV6Task).VhConfigureAwait(); if (ipV4Task.Result != null) ret.Add(ipV4Task.Result); if (ipV6Task.Result != null) ret.Add(ipV6Task.Result); @@ -37,11 +38,11 @@ public static async Task IsIpv6Supported() { // it may throw error if IPv6 is not supported before creating task var ping = new Ping(); - var ping1 = ping.SendPingAsync("2001:4860:4860::8888"); + var pingTask = ping.SendPingAsync("2001:4860:4860::8888"); var ping2 = ping.SendPingAsync("2001:4860:4860::8844"); try { - if ((await ping1).Status == IPStatus.Success) + if ((await pingTask.VhConfigureAwait()).Status == IPStatus.Success) return true; } catch @@ -51,7 +52,7 @@ public static async Task IsIpv6Supported() try { - if ((await ping2).Status == IPStatus.Success) + if ((await ping2.VhConfigureAwait()).Status == IPStatus.Success) return true; } catch @@ -72,8 +73,8 @@ public static async Task GetPublicIpAddresses() var ret = new List(); //note: api.ipify.org may not work in parallel call - var ipV4Task = await GetPublicIpAddress(AddressFamily.InterNetwork, TimeSpan.FromSeconds(10)); - var ipV6Task = await GetPublicIpAddress(AddressFamily.InterNetworkV6, TimeSpan.FromSeconds(4)); + var ipV4Task = await GetPublicIpAddress(AddressFamily.InterNetwork, TimeSpan.FromSeconds(10)).VhConfigureAwait(); + var ipV6Task = await GetPublicIpAddress(AddressFamily.InterNetworkV6, TimeSpan.FromSeconds(4)).VhConfigureAwait(); if (ipV4Task != null) ret.Add(ipV4Task); if (ipV6Task != null) ret.Add(ipV6Task); @@ -117,7 +118,7 @@ public static async Task GetPublicIpAddresses() using var httpClient = new HttpClient(handler); httpClient.Timeout = timeout ?? TimeSpan.FromSeconds(5); - var json = await httpClient.GetStringAsync(url); + var json = await httpClient.GetStringAsync(url).VhConfigureAwait(); var document = JsonDocument.Parse(json); var ipString = document.RootElement.GetProperty("ip").GetString(); var ipAddress = IPAddress.Parse(ipString ?? throw new InvalidOperationException()); @@ -162,6 +163,11 @@ public static int Compare(IPAddress ipAddress1, IPAddress ipAddress2) Verify(ipAddress1); Verify(ipAddress2); + // resolve mapped addresses + if (ipAddress1.IsIPv4MappedToIPv6) ipAddress1 = ipAddress1.MapToIPv4(); + if (ipAddress2.IsIPv4MappedToIPv6) ipAddress2 = ipAddress2.MapToIPv4(); + + // compare if (ipAddress1.AddressFamily == AddressFamily.InterNetwork && ipAddress2.AddressFamily == AddressFamily.InterNetworkV6) return -1; diff --git a/VpnHood.Common/Net/IpGroup.cs b/VpnHood.Common/Net/IpGroup.cs deleted file mode 100644 index c05449b34..000000000 --- a/VpnHood.Common/Net/IpGroup.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace VpnHood.Common.Net; - -public class IpGroup(string ipGroupId, string ipGroupName) -{ - public string IpGroupId { get; set; } = ipGroupId; - public string IpGroupName { get; set; } = ipGroupName; -} \ No newline at end of file diff --git a/VpnHood.Common/Net/IpGroupManager.cs b/VpnHood.Common/Net/IpGroupManager.cs deleted file mode 100644 index fec859bd6..000000000 --- a/VpnHood.Common/Net/IpGroupManager.cs +++ /dev/null @@ -1,132 +0,0 @@ -using System.Net; -using System.Numerics; -using System.Text.Json; -using System.Text.RegularExpressions; -using Microsoft.Extensions.Logging; -using VpnHood.Common.Logging; - -namespace VpnHood.Common.Net; - -public class IpGroupManager -{ - private readonly string _ipGroupsFilePath; - - public IpGroup[] IpGroups { get; private set; } = []; - private readonly Dictionary _ipRangeGroups = new(); - private IpRange[]? _sortedIpRanges; - - public IpGroupManager(string ipGroupsFilePath) - { - _ipGroupsFilePath = ipGroupsFilePath; - try - { - IpGroups = JsonSerializer.Deserialize(File.ReadAllText(ipGroupsFilePath)) - ?? throw new FormatException($"Could deserialize {ipGroupsFilePath}!"); - } - catch - { - // ignored - } - } - - private string IpGroupsFolderPath => Path.Combine(Path.GetDirectoryName(_ipGroupsFilePath)!, "ipgroups"); - - public async Task AddFromIp2Location(Stream ipLocationsStream) - { - // extract IpGroups - Dictionary ipGroupNetworks = new(); - using var streamReader = new StreamReader(ipLocationsStream); - while (!streamReader.EndOfStream) - { - var line = await streamReader.ReadLineAsync(); - var items = line.Replace("\"", "").Split(','); - if (items.Length != 4) - continue; - - var ipGroupId = items[2].ToLower(); - if (ipGroupId == "-") continue; - if (ipGroupId == "um") ipGroupId = "us"; - - if (!ipGroupNetworks.TryGetValue(ipGroupId, out var ipGroupNetwork)) - { - var ipGroupName = ipGroupId switch - { - "us" => "United States", - "gb" => "United Kingdom", - _ => items[3] - }; - ipGroupName = Regex.Replace(ipGroupName, @"\(.*?\)", "").Replace(" ", " "); - - ipGroupNetwork = new IpGroupNetwork(ipGroupId, ipGroupName); - ipGroupNetworks.Add(ipGroupId, ipGroupNetwork); - } - - var ip1 = new IPAddress(BigInteger.Parse(items[0]).ToByteArray(true, true)); - var ip2 = new IPAddress(BigInteger.Parse(items[1]).ToByteArray(true, true)); - var ipRange = new IpRange(ip1, ip2); - ipGroupNetwork.IpRanges.Add(ipRange); - } - - //generating files - VhLogger.Instance.LogTrace($"Generating IpGroups files. IpGroupCount: {ipGroupNetworks.Count}"); - Directory.CreateDirectory(IpGroupsFolderPath); - foreach (var item in ipGroupNetworks) - { - var ipGroup = item.Value; - var filePath = Path.Combine(IpGroupsFolderPath, $"{ipGroup.IpGroupId}.json"); - await using var fileStream = File.Create(filePath); - await JsonSerializer.SerializeAsync(fileStream, ipGroup.IpRanges); - } - - // creating IpGroups - IpGroups = IpGroups.Concat(ipGroupNetworks.Values.Select(x => new IpGroup(x.IpGroupId, x.IpGroupName))) - .ToArray(); - _sortedIpRanges = null; - - // save - await File.WriteAllTextAsync(_ipGroupsFilePath, JsonSerializer.Serialize(IpGroups)); - } - - public async Task GetIpRanges(string ipGroupId) - { - var filePath = Path.Combine(IpGroupsFolderPath, $"{ipGroupId}.json"); - var json = await File.ReadAllTextAsync(filePath); - return JsonSerializer.Deserialize(json) ?? throw new Exception($"Could not deserialize {filePath}!"); - } - - private readonly SemaphoreSlim _sortedIpRangesSemaphore = new(1, 1); - private async Task LoadIpRangeGroups() - { - // load all groups - try - { - await _sortedIpRangesSemaphore.WaitAsync(); - _ipRangeGroups.Clear(); - List ipRanges = []; - foreach (var ipGroup in IpGroups) - foreach (var ipRange in await GetIpRanges(ipGroup.IpGroupId)) - { - ipRanges.Add(ipRange); - _ipRangeGroups.Add(ipRange, ipGroup); - } - _sortedIpRanges = IpRange.Sort(ipRanges, false).ToArray(); - } - finally - { - _sortedIpRangesSemaphore.Release(); - } - } - - public async Task FindIpGroup(IPAddress ipAddress) - { - await LoadIpRangeGroups(); - var findIpRange = IpRange.FindInSortedRanges(_sortedIpRanges!, ipAddress); - return findIpRange != null ? _ipRangeGroups[findIpRange] : null; - } - - private class IpGroupNetwork(string ipGroupId, string ipGroupName) - : IpGroup(ipGroupId, ipGroupName) - { - public List IpRanges { get; } = []; - } -} \ No newline at end of file diff --git a/VpnHood.Common/Net/IpNetwork.cs b/VpnHood.Common/Net/IpNetwork.cs index ed2886a9c..5e38360da 100644 --- a/VpnHood.Common/Net/IpNetwork.cs +++ b/VpnHood.Common/Net/IpNetwork.cs @@ -51,6 +51,7 @@ public IpNetwork(IPAddress prefix, int prefixLength) public static IpNetwork[] LoopbackNetworksV4 { get; } = [Parse("127.0.0.0/8")]; public static IpNetwork[] LoopbackNetworksV6 { get; } = [Parse("::1/128")]; + public static IpNetwork[] LoopbackNetworks { get; } = LoopbackNetworksV4.Concat(LoopbackNetworksV6).ToArray(); public static IpNetwork AllV6 { get; } = Parse("::/0"); public static IpNetwork AllGlobalUnicastV6 { get; } = Parse("2000::/3"); public static IpNetwork[] LocalNetworksV6 { get; } = AllGlobalUnicastV6.Invert().ToArray(); @@ -58,17 +59,7 @@ public IpNetwork(IPAddress prefix, int prefixLength) public static IpNetwork[] All { get; } = [AllV4, AllV6]; public static IpNetwork[] None { get; } = []; - public static bool IsAll(IOrderedEnumerable ipNetworks) - { - return ipNetworks.SequenceEqual(All); - } - - public static IEnumerable FromIpRange(IpRange ipRange) - { - return FromIpRange(ipRange.FirstIpAddress, ipRange.LastIpAddress); - } - - public static IEnumerable FromIpRange(IPAddress firstIpAddress, IPAddress lastIpAddress) + public static IEnumerable FromRange(IPAddress firstIpAddress, IPAddress lastIpAddress) { if (firstIpAddress.AddressFamily != lastIpAddress.AddressFamily) throw new ArgumentException("AddressFamilies don't match!"); @@ -107,7 +98,17 @@ public static IEnumerable FromIpRange(IPAddress firstIpAddress, IPAdd public IOrderedEnumerable Invert() { - return Invert(new[] { this }, AddressFamily == AddressFamily.InterNetwork, AddressFamily == AddressFamily.InterNetworkV6); + return new[] { this } + .ToIpRanges() + .Invert( + includeIPv4: AddressFamily == AddressFamily.InterNetwork, + includeIPv6: AddressFamily == AddressFamily.InterNetworkV6) + .ToIpNetworks(); + } + + public IpRange ToIpRange() + { + return new IpRange(FirstIpAddress, LastIpAddress); } public static IpNetwork Parse(string value) @@ -123,50 +124,6 @@ public static IpNetwork Parse(string value) } } - public static IOrderedEnumerable Sort(IEnumerable ipNetworks) - { - return FromIpRange(ToIpRange(ipNetworks)); - } - - public static IOrderedEnumerable Invert(IEnumerable ipNetworks, bool includeIPv4 = true, bool includeIPv6 = true) - { - return FromIpRange(IpRange.Invert(ToIpRange(ipNetworks), includeIPv4, includeIPv6)); - } - - public static IOrderedEnumerable Intersect(IEnumerable ipNetworks1, IEnumerable ipNetworks2) - { - return FromIpRange(IpRange.Intersect(ToIpRange(ipNetworks1), ToIpRange(ipNetworks2))); - } - - public static IOrderedEnumerable Union(IEnumerable ipNetworks1, IEnumerable ipNetworks2) - { - return FromIpRange(IpRange.Union(ToIpRange(ipNetworks1), ToIpRange(ipNetworks2))); - } - - public static IOrderedEnumerable Exclude(IEnumerable ipNetworks, IEnumerable excludeIpNetworks) - { - return FromIpRange(IpRange.Exclude(ToIpRange(ipNetworks), ToIpRange(excludeIpNetworks))); - } - - public IpRange ToIpRange() - { - return new IpRange(FirstIpAddress, LastIpAddress); - } - - public static IEnumerable ToIpRange(IEnumerable ipNetworks) - { - return ipNetworks.Select(x => x.ToIpRange()).Sort(); - } - - public static IOrderedEnumerable FromIpRange(IEnumerable ipRanges) - { - var ipNetworks = new List(); - foreach (var ipRange in IpRange.Sort(ipRanges)) - ipNetworks.AddRange(FromIpRange(ipRange)); - - return ipNetworks.OrderBy(x => x.FirstIpAddress, new IPAddressComparer()); - } - public override string ToString() { return $"{Prefix}/{PrefixLength}"; diff --git a/VpnHood.Common/Net/IpNetworkExtension.cs b/VpnHood.Common/Net/IpNetworkExtension.cs index 67d76fea9..d1b7235f7 100644 --- a/VpnHood.Common/Net/IpNetworkExtension.cs +++ b/VpnHood.Common/Net/IpNetworkExtension.cs @@ -1,88 +1,23 @@ -using System.Net; - -namespace VpnHood.Common.Net; +namespace VpnHood.Common.Net; public static class IpNetworkExtension { public static IOrderedEnumerable Sort(this IEnumerable ipNetworks) { - return IpNetwork.Sort(ipNetworks); - } - - public static IEnumerable ToIpRanges(this IEnumerable ipNetworks) - { - return IpNetwork.ToIpRange(ipNetworks); + return ipNetworks + .ToIpRanges() + .ToIpNetworks(); } - public static IOrderedEnumerable ToIpNetworks(this IEnumerable ipRanges) + public static IpRangeOrderedList ToIpRanges(this IEnumerable ipNetworks) { - return IpNetwork.FromIpRange(ipRanges); - } - - public static IEnumerable ToIpNetworks(this IpRange ipRange) - { - return IpNetwork.FromIpRange(ipRange.FirstIpAddress, ipRange.LastIpAddress); - } - - public static IOrderedEnumerable Invert(this IEnumerable ipNetworks, bool includeIPv4 = true, bool includeIPv6 = true) - { - return IpNetwork.Invert(ipNetworks, includeIPv4, includeIPv6); - } - - public static IOrderedEnumerable Intersect(this IEnumerable ipNetworks1, IEnumerable ipNetworks2) - { - return IpNetwork.Intersect(ipNetworks1, ipNetworks2); - } - - public static IOrderedEnumerable Union(IEnumerable ipNetworks1, IEnumerable ipNetworks2) - { - return IpNetwork.Union(ipNetworks1, ipNetworks2); - } - - public static IOrderedEnumerable Exclude(this IEnumerable ipNetworks, IEnumerable excludeIpNetworks) - { - return IpNetwork.Exclude(ipNetworks, excludeIpNetworks); + return ipNetworks + .Select(x => x.ToIpRange()) + .ToOrderedList(); } public static bool IsAll(this IOrderedEnumerable ipNetworks) { - return IpNetwork.IsAll(ipNetworks); - } - - - public static IOrderedEnumerable Exclude(this IEnumerable ipRanges, IEnumerable excludeIpRanges) - { - return IpRange.Exclude(ipRanges, excludeIpRanges); - } - - public static IOrderedEnumerable Sort(this IEnumerable ipRanges, bool unify = true) - { - return IpRange.Sort(ipRanges, unify); - } - - public static IOrderedEnumerable Invert(this IEnumerable ipRanges, - bool includeIPv4 = true, bool includeIPv6 = true) - { - return IpRange.Invert(ipRanges, includeIPv4, includeIPv6); - } - - public static bool IsInSortedRanges(this IpRange[] sortedIpRanges, IPAddress ipAddress) - { - return IpRange.IsInSortedRanges(sortedIpRanges, ipAddress); - } - - public static IpRange? FindInSortedRanges(this IpRange[] sortedIpRanges, IPAddress ipAddress) - { - return IpRange.FindInSortedRanges(sortedIpRanges, ipAddress); - } - - public static IOrderedEnumerable Intersect(this IEnumerable ipRanges1, IEnumerable ipRanges2) - { - return IpRange.Intersect(ipRanges1, ipRanges2); - } - - public static IOrderedEnumerable Union(this IEnumerable ipRanges1, IEnumerable ipRanges2) - { - return IpRange.Union(ipRanges1, ipRanges2); + return ipNetworks.SequenceEqual(IpNetwork.All); } } \ No newline at end of file diff --git a/VpnHood.Common/Net/IpRange.cs b/VpnHood.Common/Net/IpRange.cs index 638e9657f..c624b2a2c 100644 --- a/VpnHood.Common/Net/IpRange.cs +++ b/VpnHood.Common/Net/IpRange.cs @@ -31,6 +31,7 @@ public IpRange(IPAddress firstIpAddress, IPAddress lastIpAddress) LastIpAddress = lastIpAddress; } + public static IpRange FromIpAddress(IPAddress ipAddress) => new(ipAddress); public bool IsIPv4MappedToIPv6 => FirstIpAddress.IsIPv4MappedToIPv6; public IpRange MapToIPv4() => new (FirstIpAddress.MapToIPv4(), LastIpAddress.MapToIPv4()); public IpRange MapToIPv6() => new (FirstIpAddress.MapToIPv6(), LastIpAddress.MapToIPv6()); @@ -38,86 +39,7 @@ public IpRange(IPAddress firstIpAddress, IPAddress lastIpAddress) public IPAddress FirstIpAddress { get; } public IPAddress LastIpAddress { get; } public BigInteger Total => new BigInteger(LastIpAddress.GetAddressBytes(), true, true) - new BigInteger(FirstIpAddress.GetAddressBytes(), true, true) + 1; - - public static IOrderedEnumerable Sort(IEnumerable ipRanges, bool unify = true) - { - var sortedRanges = ipRanges.OrderBy(x => x.FirstIpAddress, new IPAddressComparer()); - return unify ? Unify(sortedRanges) : sortedRanges; - } - - private static IOrderedEnumerable Unify(IEnumerable sortedIpRanges) - { - List res = []; - foreach (var ipRange in sortedIpRanges) - { - if (res.Count > 0 && - ipRange.AddressFamily == res[^1].AddressFamily && - IPAddressUtil.Compare(IPAddressUtil.Decrement(ipRange.FirstIpAddress), res[^1].LastIpAddress) <= 0) - { - if (IPAddressUtil.Compare(ipRange.LastIpAddress, res[^1].LastIpAddress) > 0) - res[^1] = new IpRange(res[^1].FirstIpAddress, ipRange.LastIpAddress); - } - else - { - res.Add(ipRange); - } - } - - return res.OrderBy(x => x.FirstIpAddress, new IPAddressComparer()); - } - - public static IOrderedEnumerable Invert(IEnumerable ipRanges, bool includeIPv4 = true, bool includeIPv6 = true) - { - var list = new List(); - - // IP4 - if (includeIPv4) - { - var ipRanges2 = ipRanges.Where(x => x.AddressFamily == AddressFamily.InterNetwork); - if (ipRanges2.Any()) - list.AddRange(InvertInternal(ipRanges2)); - else - list.Add(new IpRange(IPAddressUtil.MinIPv4Value, IPAddressUtil.MaxIPv4Value)); - } - - // IP6 - if (includeIPv6) - { - var ipRanges2 = ipRanges.Where(x => x.AddressFamily == AddressFamily.InterNetworkV6); - if (ipRanges2.Any()) - list.AddRange(InvertInternal(ipRanges2)); - else - list.Add(new IpRange(IPAddressUtil.MinIPv6Value, IPAddressUtil.MaxIPv6Value)); - } - - return Sort(list); - } - - private static IEnumerable InvertInternal(IEnumerable ipRanges) - { - // sort - var ipRangesSorted = Sort(ipRanges).ToArray(); - - // extract - List res = []; - for (var i = 0; i < ipRangesSorted.Length; i++) - { - var ipRange = ipRangesSorted[i]; - var minIpValue = ipRange.AddressFamily == AddressFamily.InterNetworkV6 ? IPAddressUtil.MinIPv6Value : IPAddressUtil.MinIPv4Value; - var maxIpValue = ipRange.AddressFamily == AddressFamily.InterNetworkV6 ? IPAddressUtil.MaxIPv6Value : IPAddressUtil.MaxIPv4Value; - - if (i == 0 && !IPAddressUtil.IsMinValue(ipRange.FirstIpAddress)) - res.Add(new IpRange(minIpValue, IPAddressUtil.Decrement(ipRange.FirstIpAddress))); - - if (i > 0) - res.Add(new IpRange(IPAddressUtil.Increment(ipRangesSorted[i - 1].LastIpAddress), IPAddressUtil.Decrement(ipRange.FirstIpAddress))); - - if (i == ipRangesSorted.Length - 1 && !IPAddressUtil.IsMaxValue(ipRange.LastIpAddress)) - res.Add(new IpRange(IPAddressUtil.Increment(ipRange.LastIpAddress), maxIpValue)); - } - - return res; - } + public IEnumerable ToIpNetworks() => IpNetwork.FromRange(FirstIpAddress, LastIpAddress); public static IpRange Parse(string value) { @@ -155,102 +77,4 @@ public bool IsInRange(IPAddress ipAddress) IPAddressUtil.Compare(ipAddress, FirstIpAddress) >= 0 && IPAddressUtil.Compare(ipAddress, LastIpAddress) <= 0; } - - /// - /// Search in ipRanges using binary search - /// - /// a sorted ipRanges - /// search value - /// - public static bool IsInSortedRanges(IpRange[] sortedIpRanges, IPAddress ipAddress) - { - return FindInSortedRanges(sortedIpRanges, ipAddress) != null; - } - - /// a sorted ipRanges - /// search value - public static IpRange? FindInSortedRanges(IpRange[] sortedIpRanges, IPAddress ipAddress) - { - var res = Array.BinarySearch(sortedIpRanges, new IpRange(ipAddress, ipAddress), new IpRangeSearchComparer()); - return res >= 0 && res < sortedIpRanges.Length ? sortedIpRanges[res] : null; - } - - public static IOrderedEnumerable Union(IEnumerable ipRanges1, IEnumerable ipRanges2) - { - return Sort(ipRanges1.Concat(ipRanges2)); - } - - public static IOrderedEnumerable Exclude(IEnumerable ipRanges, IEnumerable excludeIpRanges) - { - return Intersect(ipRanges, Invert(excludeIpRanges)); - } - - public static IOrderedEnumerable Intersect(IEnumerable ipRanges1, IEnumerable ipRanges2) - { - // ReSharper disable once PossibleMultipleEnumeration - var v4SortedRanges1 = ipRanges1 - .Where(x => x.AddressFamily == AddressFamily.InterNetwork) - .OrderBy(x => x.FirstIpAddress, new IPAddressComparer()); - - // ReSharper disable once PossibleMultipleEnumeration - var v4SortedRanges2 = ipRanges2 - .Where(x => x.AddressFamily == AddressFamily.InterNetwork) - .OrderBy(x => x.FirstIpAddress, new IPAddressComparer()); - - // ReSharper disable once PossibleMultipleEnumeration - var v6SortedRanges1 = ipRanges1 - .Where(x => x.AddressFamily == AddressFamily.InterNetworkV6) - .OrderBy(x => x.FirstIpAddress, new IPAddressComparer()); - - // ReSharper disable once PossibleMultipleEnumeration - var v6SortedRanges2 = ipRanges2 - .Where(x => x.AddressFamily == AddressFamily.InterNetworkV6) - .OrderBy(x => x.FirstIpAddress, new IPAddressComparer()); - - - var ipRangesV4 = IntersectInternal(v4SortedRanges1, v4SortedRanges2); - var ipRangesV6 = IntersectInternal(v6SortedRanges1, v6SortedRanges2); - var ret = ipRangesV4.Concat(ipRangesV6); - return Sort(ret); - } - - private static IEnumerable IntersectInternal(IEnumerable ipRanges1, - IEnumerable ipRanges2) - { - ipRanges1 = Sort(ipRanges1); - ipRanges2 = Sort(ipRanges2); - - var ipRanges = new List(); - foreach (var ipRange1 in ipRanges1) - foreach (var ipRange2 in ipRanges2) - { - if (ipRange1.IsInRange(ipRange2.FirstIpAddress)) - ipRanges.Add(new IpRange(ipRange2.FirstIpAddress, - IPAddressUtil.Min(ipRange1.LastIpAddress, ipRange2.LastIpAddress))); - - else if (ipRange1.IsInRange(ipRange2.LastIpAddress)) - ipRanges.Add(new IpRange(IPAddressUtil.Max(ipRange1.FirstIpAddress, ipRange2.FirstIpAddress), - ipRange2.LastIpAddress)); - - else if (ipRange2.IsInRange(ipRange1.FirstIpAddress)) - ipRanges.Add(new IpRange(ipRange1.FirstIpAddress, - IPAddressUtil.Min(ipRange1.LastIpAddress, ipRange2.LastIpAddress))); - - else if (ipRange2.IsInRange(ipRange1.LastIpAddress)) - ipRanges.Add(new IpRange(IPAddressUtil.Max(ipRange1.FirstIpAddress, ipRange2.FirstIpAddress), - ipRange1.LastIpAddress)); - } - - return ipRanges; - } - - private class IpRangeSearchComparer : IComparer - { - public int Compare(IpRange x, IpRange y) - { - if (IPAddressUtil.Compare(x.FirstIpAddress, y.FirstIpAddress) <= 0 && IPAddressUtil.Compare(x.LastIpAddress, y.LastIpAddress) >= 0) return 0; - if (IPAddressUtil.Compare(x.FirstIpAddress, y.FirstIpAddress) < 0) return -1; - return +1; - } - } } \ No newline at end of file diff --git a/VpnHood.Common/Net/IpRangeExtension.cs b/VpnHood.Common/Net/IpRangeExtension.cs new file mode 100644 index 000000000..e6775ab79 --- /dev/null +++ b/VpnHood.Common/Net/IpRangeExtension.cs @@ -0,0 +1,15 @@ +namespace VpnHood.Common.Net; + +public static class IpRangeExtension +{ + public static IpRangeOrderedList ToOrderedList(this IEnumerable ipRanges) + { + return new IpRangeOrderedList(ipRanges); + } + + public static IEnumerable Intersect(this IEnumerable first, IEnumerable second) + { + // prevent use Linq.Intersect in mistake. it is bug prone. + throw new NotSupportedException($"Use {nameof(IpRangeOrderedList)}.Intersect."); + } +} \ No newline at end of file diff --git a/VpnHood.Common/Net/IpRangeOrderedList.cs b/VpnHood.Common/Net/IpRangeOrderedList.cs new file mode 100644 index 000000000..fdbc24032 --- /dev/null +++ b/VpnHood.Common/Net/IpRangeOrderedList.cs @@ -0,0 +1,285 @@ +using System.Collections; +using System.Net; +using System.Net.Sockets; + + +namespace VpnHood.Common.Net; + +public class IpRangeOrderedList : + IOrderedEnumerable, + IReadOnlyList +{ + private readonly List _orderedList; + public IOrderedEnumerable CreateOrderedEnumerable(Func keySelector, IComparer comparer, bool descending) => + descending ? _orderedList.OrderByDescending(keySelector, comparer) : _orderedList.OrderBy(keySelector, comparer); + + public IEnumerator GetEnumerator() => _orderedList.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public int Count => _orderedList.Count; + public static IpRangeOrderedList Empty { get; } = new([]); + + public IpRange this[int index] => _orderedList[index]; + + public IpRangeOrderedList() + { + _orderedList = []; + } + + public IpRangeOrderedList(IEnumerable ipRanges) + { + _orderedList = Sort(ipRanges); + } + + private IpRangeOrderedList(List orderedList) + { + _orderedList = orderedList; + } + + public void Serialize(Stream stream) + { + // serialize to binary + using var writer = new BinaryWriter(stream); + writer.Write(_orderedList.Count); + foreach (var range in _orderedList) + { + var firstIpBytes = range.FirstIpAddress.GetAddressBytes(); + var lastIpBytes = range.LastIpAddress.GetAddressBytes(); + + writer.Write((byte)firstIpBytes.Length); + writer.Write(firstIpBytes); + writer.Write((byte)lastIpBytes.Length); + writer.Write(lastIpBytes); + } + } + + public static IpRangeOrderedList Deserialize(Stream stream) + { + using var reader = new BinaryReader(stream); + var length = reader.ReadInt32(); + var ipRanges = new IpRange[length]; + for (var i = 0; i < length; i++) + { + var firstIpLength = reader.ReadByte(); + var firstIpBytes = reader.ReadBytes(firstIpLength); + var lastIpLength = reader.ReadByte(); + var lastIpBytes = reader.ReadBytes(lastIpLength); + ipRanges[i] = new IpRange(new IPAddress(firstIpBytes), new IPAddress(lastIpBytes)); + } + + return new IpRangeOrderedList(ipRanges); + } + + public bool IsAll() + { + // use ToIpRanges for All to improve performance + return IpNetwork.All + .ToIpRanges() + .SequenceEqual(this); + } + + public bool IsInRange(IPAddress ipAddress) + { + if (ipAddress.IsIPv4MappedToIPv6) ipAddress = ipAddress.MapToIPv4(); + var res = _orderedList.BinarySearch(new IpRange(ipAddress, ipAddress), new IpRangeSearchComparer()); + return res >= 0; + } + + public bool IsNone() + { + return _orderedList.Count == 0; + } + + public IOrderedEnumerable ToIpNetworks() + { + var ipNetworkList = new List(); + foreach (var ipRange in _orderedList) + ipNetworkList.AddRange(ipRange.ToIpNetworks()); + + return ipNetworkList.OrderBy(x => x.FirstIpAddress, new IPAddressComparer()); + } + + public IpRangeOrderedList Union(IEnumerable ipRanges) + { + // ReSharper disable PossibleMultipleEnumeration + return ipRanges.Any() ? new IpRangeOrderedList(_orderedList.Concat(ipRanges)) : this; + // ReSharper restore PossibleMultipleEnumeration + } + + public IpRangeOrderedList Exclude(IPAddress ipAddress) + { + return Exclude(new IpRange(ipAddress)); + } + + public IpRangeOrderedList Exclude(IpRange ipRange) + { + return Exclude(new[] { ipRange }); + } + + public IpRangeOrderedList Exclude(IEnumerable ipRanges) + { + return Exclude(ipRanges.ToOrderedList()); + } + + public IpRangeOrderedList Exclude(IpRangeOrderedList ipRanges) + { + return Intersect(ipRanges.Invert()); + } + + public IpRangeOrderedList Intersect(IEnumerable ipRanges) + { + return Intersect(ipRanges.ToOrderedList()); + } + + public IpRangeOrderedList Intersect(IpRangeOrderedList ipRanges) + { + return Intersect(this, ipRanges); + } + + public IpRangeOrderedList Invert(bool includeIPv4 = true, bool includeIPv6 = true) + { + return Invert(this, includeIPv4: includeIPv4, includeIPv6: includeIPv6); + } + + private static List Sort(IEnumerable ipRanges) + { + var sortedRanges = ipRanges.OrderBy(x => x.FirstIpAddress, new IPAddressComparer()); + return Unify(sortedRanges); + } + + private static List Unify(IOrderedEnumerable sortedIpRanges) + { + List res = []; + foreach (var ipRange in sortedIpRanges) + { + if (res.Count > 0 && + ipRange.AddressFamily == res[^1].AddressFamily && + IPAddressUtil.Compare(IPAddressUtil.Decrement(ipRange.FirstIpAddress), res[^1].LastIpAddress) <= 0) + { + if (IPAddressUtil.Compare(ipRange.LastIpAddress, res[^1].LastIpAddress) > 0) + res[^1] = new IpRange(res[^1].FirstIpAddress, ipRange.LastIpAddress); + } + else + { + res.Add(ipRange); + } + } + + return res; + } + + private static IpRangeOrderedList Intersect(IpRangeOrderedList ipRanges1, IpRangeOrderedList ipRanges2) + { + // performance optimization + if (ipRanges1.IsAll()) return ipRanges2; + if (ipRanges2.IsAll()) return ipRanges1; + + var v4SortedRanges1 = ipRanges1 + .Where(x => x.AddressFamily == AddressFamily.InterNetwork); + + var v4SortedRanges2 = ipRanges2 + .Where(x => x.AddressFamily == AddressFamily.InterNetwork); + + var v6SortedRanges1 = ipRanges1 + .Where(x => x.AddressFamily == AddressFamily.InterNetworkV6); + + var v6SortedRanges2 = ipRanges2 + .Where(x => x.AddressFamily == AddressFamily.InterNetworkV6); + + //all range are ordered as the following process does not change the order + var ipRangesV4 = IntersectInternal(v4SortedRanges1, v4SortedRanges2.ToArray()); + var ipRangesV6 = IntersectInternal(v6SortedRanges1, v6SortedRanges2.ToArray()); + var ret = ipRangesV4.Concat(ipRangesV6); + + return new IpRangeOrderedList(ret); + } + + private static IEnumerable IntersectInternal( + IEnumerable orderedIpRanges1, + IpRange[] orderedIpRanges2) + { + var ipRanges = new List(); + foreach (var ipRange1 in orderedIpRanges1) + foreach (var ipRange2 in orderedIpRanges2) + { + if (ipRange1.IsInRange(ipRange2.FirstIpAddress)) + ipRanges.Add(new IpRange(ipRange2.FirstIpAddress, + IPAddressUtil.Min(ipRange1.LastIpAddress, ipRange2.LastIpAddress))); + + else if (ipRange1.IsInRange(ipRange2.LastIpAddress)) + ipRanges.Add(new IpRange(IPAddressUtil.Max(ipRange1.FirstIpAddress, ipRange2.FirstIpAddress), + ipRange2.LastIpAddress)); + + else if (ipRange2.IsInRange(ipRange1.FirstIpAddress)) + ipRanges.Add(new IpRange(ipRange1.FirstIpAddress, + IPAddressUtil.Min(ipRange1.LastIpAddress, ipRange2.LastIpAddress))); + + else if (ipRange2.IsInRange(ipRange1.LastIpAddress)) + ipRanges.Add(new IpRange(IPAddressUtil.Max(ipRange1.FirstIpAddress, ipRange2.FirstIpAddress), + ipRange1.LastIpAddress)); + } + + return ipRanges; + } + + private static IpRangeOrderedList Invert(IpRangeOrderedList ipRanges, + bool includeIPv4 = true, bool includeIPv6 = true) + { + //it is ordered as the following process does not change the order + var newIpRanges = new List(); + + // IP4 + if (includeIPv4) + { + var ipRanges2 = ipRanges.Where(x => x.AddressFamily == AddressFamily.InterNetwork).ToArray(); + if (ipRanges2.Any()) + newIpRanges.AddRange(InvertInternal(ipRanges2)); + else + newIpRanges.Add(new IpRange(IPAddressUtil.MinIPv4Value, IPAddressUtil.MaxIPv4Value)); + } + + // IP6 + if (includeIPv6) + { + var ipRanges2 = ipRanges.Where(x => x.AddressFamily == AddressFamily.InterNetworkV6).ToArray(); + if (ipRanges2.Any()) + newIpRanges.AddRange(InvertInternal(ipRanges2)); + else + newIpRanges.Add(new IpRange(IPAddressUtil.MinIPv6Value, IPAddressUtil.MaxIPv6Value)); + } + + return new IpRangeOrderedList(newIpRanges); + } + + private static IEnumerable InvertInternal(IpRange[] orderedIpRanges) + { + // extract + List res = []; + for (var i = 0; i < orderedIpRanges.Length; i++) + { + var ipRange = orderedIpRanges[i]; + var minIpValue = ipRange.AddressFamily == AddressFamily.InterNetworkV6 ? IPAddressUtil.MinIPv6Value : IPAddressUtil.MinIPv4Value; + var maxIpValue = ipRange.AddressFamily == AddressFamily.InterNetworkV6 ? IPAddressUtil.MaxIPv6Value : IPAddressUtil.MaxIPv4Value; + + if (i == 0 && !IPAddressUtil.IsMinValue(ipRange.FirstIpAddress)) + res.Add(new IpRange(minIpValue, IPAddressUtil.Decrement(ipRange.FirstIpAddress))); + + if (i > 0) + res.Add(new IpRange(IPAddressUtil.Increment(orderedIpRanges[i - 1].LastIpAddress), IPAddressUtil.Decrement(ipRange.FirstIpAddress))); + + if (i == orderedIpRanges.Length - 1 && !IPAddressUtil.IsMaxValue(ipRange.LastIpAddress)) + res.Add(new IpRange(IPAddressUtil.Increment(ipRange.LastIpAddress), maxIpValue)); + } + + return res; + } + + private class IpRangeSearchComparer : IComparer + { + public int Compare(IpRange x, IpRange y) + { + if (IPAddressUtil.Compare(x.FirstIpAddress, y.FirstIpAddress) <= 0 && IPAddressUtil.Compare(x.LastIpAddress, y.LastIpAddress) >= 0) return 0; + if (IPAddressUtil.Compare(x.FirstIpAddress, y.FirstIpAddress) < 0) return -1; + return +1; + } + } +} \ No newline at end of file diff --git a/VpnHood.Common/ServerLocationInfo.cs b/VpnHood.Common/ServerLocationInfo.cs index f0e9a00eb..7dcdddb93 100644 --- a/VpnHood.Common/ServerLocationInfo.cs +++ b/VpnHood.Common/ServerLocationInfo.cs @@ -14,7 +14,7 @@ public class ServerLocationInfo : IComparable public int CompareTo(ServerLocationInfo other) { - var countryComparison = string.Compare(CountryCode, other.CountryCode, StringComparison.OrdinalIgnoreCase); + var countryComparison = string.Compare(CountryName, other.CountryName, StringComparison.OrdinalIgnoreCase); return countryComparison != 0 ? countryComparison : string.Compare(RegionName, other.RegionName, StringComparison.OrdinalIgnoreCase); } diff --git a/VpnHood.Common/Utils/AsyncLock.cs b/VpnHood.Common/Utils/AsyncLock.cs index 0fedba14a..93ad93ce3 100644 --- a/VpnHood.Common/Utils/AsyncLock.cs +++ b/VpnHood.Common/Utils/AsyncLock.cs @@ -43,13 +43,13 @@ public void Dispose() public async Task LockAsync(CancellationToken cancellationToken = default) { - await _semaphoreSlimEx.WaitAsync(cancellationToken); + await _semaphoreSlimEx.WaitAsync(cancellationToken).VhConfigureAwait(); return new SemaphoreLock(_semaphoreSlimEx, true, null); } public async Task LockAsync(TimeSpan timeout, CancellationToken cancellationToken = default) { - var succeeded = await _semaphoreSlimEx.WaitAsync(timeout, cancellationToken); + var succeeded = await _semaphoreSlimEx.WaitAsync(timeout, cancellationToken).VhConfigureAwait(); return new SemaphoreLock(_semaphoreSlimEx, succeeded, null); } @@ -69,7 +69,7 @@ public static async Task LockAsync(string name, TimeSpan timeo try { - var succeeded = await semaphoreSlim.WaitAsync(timeout, cancellationToken); + var succeeded = await semaphoreSlim.WaitAsync(timeout, cancellationToken).VhConfigureAwait(); return new SemaphoreLock(semaphoreSlim, succeeded, name); } catch diff --git a/VpnHood.Common/Utils/ReadCacheStream.cs b/VpnHood.Common/Utils/ReadCacheStream.cs index 089033943..db761b3f4 100644 --- a/VpnHood.Common/Utils/ReadCacheStream.cs +++ b/VpnHood.Common/Utils/ReadCacheStream.cs @@ -34,12 +34,12 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, { // read directly to user buffer if there is no buffer, and it is larger than cache if (_cacheRemain == 0 && count > _cache.Length) - return await base.ReadAsync(buffer, offset, count, cancellationToken); + return await base.ReadAsync(buffer, offset, count, cancellationToken).VhConfigureAwait(); // fill cache if (_cacheRemain == 0 && count <= _cache.Length) { - _cacheRemain = await base.ReadAsync(_cache, 0, _cache.Length, cancellationToken); + _cacheRemain = await base.ReadAsync(_cache, 0, _cache.Length, cancellationToken).VhConfigureAwait(); _cacheOffset = 0; } diff --git a/VpnHood.Common/Utils/VhTaskExtensions.cs b/VpnHood.Common/Utils/VhTaskExtensions.cs new file mode 100644 index 000000000..11b2d23cd --- /dev/null +++ b/VpnHood.Common/Utils/VhTaskExtensions.cs @@ -0,0 +1,25 @@ +namespace VpnHood.Common.Utils; + +public static class VhTaskExtensions +{ + public static async Task VhConfigureAwait(this Task task) + { + await task.ConfigureAwait(false); + } + + public static async Task VhConfigureAwait(this Task task) + { + return await task.ConfigureAwait(false); + } + + public static async ValueTask VhConfigureAwait(this ValueTask task) + { + await task.ConfigureAwait(false); + } + + public static async ValueTask VhConfigureAwait(this ValueTask task) + { + return await task.ConfigureAwait(false); + } + +} \ No newline at end of file diff --git a/VpnHood.Common/Utils/VhTestUtil.cs b/VpnHood.Common/Utils/VhTestUtil.cs index 542c5466f..c7d9a8ff7 100644 --- a/VpnHood.Common/Utils/VhTestUtil.cs +++ b/VpnHood.Common/Utils/VhTestUtil.cs @@ -28,7 +28,7 @@ private static async Task WaitForValue(TValue expectedValue, Fun if (Equals(expectedValue, actualValue)) return actualValue; - await Task.Delay(waitTime); + await Task.Delay(waitTime).VhConfigureAwait(); actualValue = valueFactory(); } @@ -38,14 +38,14 @@ private static async Task WaitForValue(TValue expectedValue, Fun private static async Task WaitForValue(TValue expectedValue, Func> valueFactory, int timeout = 5000) { const int waitTime = 100; - var actualValue = await valueFactory(); + var actualValue = await valueFactory().VhConfigureAwait(); for (var elapsed = 0; elapsed < timeout; elapsed += waitTime) { if (Equals(expectedValue, actualValue)) return actualValue; - await Task.Delay(waitTime); - actualValue = await valueFactory(); + await Task.Delay(waitTime).VhConfigureAwait(); + actualValue = await valueFactory().VhConfigureAwait(); } return actualValue; @@ -62,21 +62,21 @@ private static void AssertEquals(object? expected, object? actual, string? messa public static async Task AssertEqualsWait(TValue expectedValue, Func valueFactory, string? message = null, int timeout = 5000) { - var actualValue = await WaitForValue(expectedValue, valueFactory, timeout); + var actualValue = await WaitForValue(expectedValue, valueFactory, timeout).VhConfigureAwait(); AssertEquals(expectedValue, actualValue, message); } public static async Task AssertEqualsWait(TValue expectedValue, Func> valueFactory, string? message = null, int timeout = 5000) { - var actualValue = await WaitForValue(expectedValue, valueFactory, timeout); + var actualValue = await WaitForValue(expectedValue, valueFactory, timeout).VhConfigureAwait(); AssertEquals(expectedValue, actualValue, message); } public static async Task AssertEqualsWait(TValue expectedValue, Task task, string? message = null, int timeout = 5000) { - var actualValue = await WaitForValue(expectedValue, () => task, timeout); + var actualValue = await WaitForValue(expectedValue, () => task, timeout).VhConfigureAwait(); AssertEquals(expectedValue, actualValue, message); } @@ -96,7 +96,7 @@ public static async Task AssertApiException(int expectedStatusCode, Task task, { try { - await task; + await task.VhConfigureAwait(); throw new AssertException($"Expected {expectedStatusCode} but the actual was OK. {message}"); } catch (ApiException ex) @@ -119,7 +119,7 @@ public static async Task AssertApiException(string expectedExceptionType, Task t { try { - await task; + await task.VhConfigureAwait(); throw new AssertException($"Expected {expectedExceptionType} exception but was OK. {message}"); } catch (ApiException ex) @@ -142,7 +142,7 @@ public static async Task AssertNotExistsException(Task task, string? message = n { try { - await task; + await task.VhConfigureAwait(); throw new AssertException($"Expected kind of {nameof(NotExistsException)} but was OK. {message}"); } catch (ApiException ex) @@ -165,7 +165,7 @@ public static async Task AssertAlreadyExistsException(Task task, string? message { try { - await task; + await task.VhConfigureAwait(); throw new AssertException($"Expected kind of {nameof(AlreadyExistsException)} but was OK. {message}"); } catch (ApiException ex) diff --git a/VpnHood.Common/Utils/VhUtil.cs b/VpnHood.Common/Utils/VhUtil.cs index 214920a9b..79e8a9512 100644 --- a/VpnHood.Common/Utils/VhUtil.cs +++ b/VpnHood.Common/Utils/VhUtil.cs @@ -117,8 +117,8 @@ public static T[] SafeToArray(object lockObject, IEnumerable collection) public static async Task RunTask(Task task, TimeSpan timeout = default, CancellationToken cancellationToken = default) { - await RunTask((Task)task, timeout, cancellationToken); - return await task; + await RunTask((Task)task, timeout, cancellationToken).VhConfigureAwait(); + return await task.VhConfigureAwait(); } public static async Task RunTask(Task task, TimeSpan timeout = default, CancellationToken cancellationToken = default) @@ -127,13 +127,13 @@ public static async Task RunTask(Task task, TimeSpan timeout = default, Cancella timeout = Timeout.InfiniteTimeSpan; var timeoutTask = Task.Delay(timeout, cancellationToken); - await Task.WhenAny(task, timeoutTask); + await Task.WhenAny(task, timeoutTask).VhConfigureAwait(); cancellationToken.ThrowIfCancellationRequested(); if (timeoutTask.IsCompleted) throw new TimeoutException(); - await task; + await task.VhConfigureAwait(); } public static bool IsNullOrEmpty([NotNullWhen(false)] IEnumerable? array) @@ -225,7 +225,7 @@ public static bool JsonEquals(object? obj1, object? obj2) return JsonSerializer.Serialize(obj1) == JsonSerializer.Serialize(obj2); } - public static T JsonClone(object obj, JsonSerializerOptions? options = null) + public static T JsonClone(T obj, JsonSerializerOptions? options = null) { var json = JsonSerializer.Serialize(obj, options); return JsonDeserialize(json, options); @@ -417,12 +417,12 @@ public static async Task ParallelForEachAsync(IEnumerable source, Func x.IsCompleted).ToArray()) tasks.Remove(completedTask); } } - await Task.WhenAll(tasks); + await Task.WhenAll(tasks).VhConfigureAwait(); } public static bool TryDeleteFile(string filePath) diff --git a/VpnHood.Common/VpnHood.Common.csproj b/VpnHood.Common/VpnHood.Common.csproj index 5c600a659..24af32b30 100644 --- a/VpnHood.Common/VpnHood.Common.csproj +++ b/VpnHood.Common/VpnHood.Common.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Server.Access/Configurations/NetFilterOptions.cs b/VpnHood.Server.Access/Configurations/NetFilterOptions.cs index 2dcf5c015..c44d6cf98 100644 --- a/VpnHood.Server.Access/Configurations/NetFilterOptions.cs +++ b/VpnHood.Server.Access/Configurations/NetFilterOptions.cs @@ -31,7 +31,7 @@ public void ApplyDefaults() BlockIpV6 = BlockIpV6Value; } - public IEnumerable GetFinalIncludeIpRanges() + public IpRangeOrderedList GetFinalIncludeIpRanges() { var includeIpRanges = IpNetwork.All.ToIpRanges(); if (!VhUtil.IsNullOrEmpty(IncludeIpRanges)) @@ -58,11 +58,11 @@ public IEnumerable GetFinalPacketCaptureIncludeIpRanges() return packetCaptureIncludeIpRanges; } - public IEnumerable GetBlockedIpRanges() + public IpRangeOrderedList GetBlockedIpRanges() { var includeIpRanges = GetFinalIncludeIpRanges().Intersect(GetFinalPacketCaptureIncludeIpRanges()); if (BlockIpV6Value) - includeIpRanges = includeIpRanges.Exclude(new[] { IpNetwork.AllV6.ToIpRange() }); + includeIpRanges = includeIpRanges.Exclude(IpNetwork.AllV6.ToIpRange()); return includeIpRanges.Invert(); } diff --git a/VpnHood.Server.Access/Managers/File/FileAccessManager.cs b/VpnHood.Server.Access/Managers/File/FileAccessManager.cs index 20c030d65..30da18471 100644 --- a/VpnHood.Server.Access/Managers/File/FileAccessManager.cs +++ b/VpnHood.Server.Access/Managers/File/FileAccessManager.cs @@ -128,9 +128,9 @@ private byte[] LoadServerSecret() if (string.IsNullOrEmpty(serverLocation) && ServerConfig.UseExternalLocationService) { var ipLocationProvider = new IpLocationProviderFactory().CreateDefault("VpnHood-Server"); - var ipLocation = await ipLocationProvider.GetLocation(new HttpClient()); + var ipLocation = await ipLocationProvider.GetLocation(new HttpClient()).VhConfigureAwait(); serverLocation = IpLocationProviderFactory.GetPath(ipLocation.CountryCode, ipLocation.RegionName, ipLocation.CityName); - await System.IO.File.WriteAllTextAsync(serverCountryFile, serverLocation); + await System.IO.File.WriteAllTextAsync(serverCountryFile, serverLocation).VhConfigureAwait(); } VhLogger.Instance.LogInformation("ServerLocation: {ServerLocation}", serverLocation ?? "Unknown"); @@ -180,7 +180,7 @@ public virtual Task Server_Configure(ServerInfo serverInfo) public virtual async Task Session_Create(SessionRequestEx sessionRequestEx) { - var accessItem = await AccessItem_Read(sessionRequestEx.TokenId); + var accessItem = await AccessItem_Read(sessionRequestEx.TokenId).VhConfigureAwait(); if (accessItem == null) return new SessionResponseEx { @@ -214,7 +214,7 @@ public virtual async Task Session_Get(ulong sessionId, IPEndP }; // read accessItem - var accessItem = await AccessItem_Read(tokenId); + var accessItem = await AccessItem_Read(tokenId).VhConfigureAwait(); if (accessItem == null) return new SessionResponseEx { @@ -256,7 +256,7 @@ private async Task Session_AddUsage(ulong sessionId, Traffic tr }; // read accessItem - var accessItem = await AccessItem_Read(tokenId); + var accessItem = await AccessItem_Read(tokenId).VhConfigureAwait(); if (accessItem == null) return new SessionResponse { @@ -265,7 +265,7 @@ private async Task Session_AddUsage(ulong sessionId, Traffic tr }; accessItem.AccessUsage.Traffic += traffic; - await WriteAccessItemUsage(accessItem); + await WriteAccessItemUsage(accessItem).VhConfigureAwait(); if (closeSession) SessionController.CloseSession(sessionId); @@ -323,7 +323,7 @@ public async Task AccessItem_LoadAll() foreach (var file in files) { - var accessItem = await AccessItem_Read(Path.GetFileNameWithoutExtension(file)); + var accessItem = await AccessItem_Read(Path.GetFileNameWithoutExtension(file)).VhConfigureAwait(); if (accessItem != null) accessItems.Add(accessItem); } @@ -384,7 +384,7 @@ public AccessItem AccessItem_Create( public async Task AccessItem_Delete(string tokenId) { // remove index - _ = await AccessItem_Read(tokenId) + _ = await AccessItem_Read(tokenId).VhConfigureAwait() ?? throw new KeyNotFoundException("Could not find tokenId"); // delete files @@ -398,14 +398,14 @@ public async Task AccessItem_Delete(string tokenId) { // read access item var fileName = GetAccessItemFileName(tokenId); - using var fileLock = await AsyncLock.LockAsync(fileName); + using var fileLock = await AsyncLock.LockAsync(fileName).VhConfigureAwait(); if (!System.IO.File.Exists(fileName)) return null; - var json = await System.IO.File.ReadAllTextAsync(fileName); + var json = await System.IO.File.ReadAllTextAsync(fileName).VhConfigureAwait(); var accessItem = VhUtil.JsonDeserialize(json); accessItem.Token.ServerToken = _serverToken; // update server token - await ReadAccessItemUsage(accessItem); + await ReadAccessItemUsage(accessItem).VhConfigureAwait(); return accessItem; } @@ -424,10 +424,10 @@ private async Task ReadAccessItemUsage(AccessItem accessItem) try { var fileName = GetUsageFileName(accessItem.Token.TokenId); - using var fileLock = await AsyncLock.LockAsync(fileName); + using var fileLock = await AsyncLock.LockAsync(fileName).VhConfigureAwait(); if (System.IO.File.Exists(fileName)) { - var json = await System.IO.File.ReadAllTextAsync(fileName); + var json = await System.IO.File.ReadAllTextAsync(fileName).VhConfigureAwait(); var accessItemUsage = JsonSerializer.Deserialize(json) ?? new AccessItemUsage(); accessItem.AccessUsage.Traffic = new Traffic { Sent = accessItemUsage.SentTraffic, Received = accessItemUsage.ReceivedTraffic }; } @@ -451,8 +451,8 @@ private async Task WriteAccessItemUsage(AccessItem accessItem) // write accessItem var fileName = GetUsageFileName(accessItem.Token.TokenId); - using var fileLock = await AsyncLock.LockAsync(fileName); - await System.IO.File.WriteAllTextAsync(fileName, json); + using var fileLock = await AsyncLock.LockAsync(fileName).VhConfigureAwait(); + await System.IO.File.WriteAllTextAsync(fileName, json).VhConfigureAwait(); } public class AccessItem diff --git a/VpnHood.Server.Access/Managers/Http/HttpAccessManager.cs b/VpnHood.Server.Access/Managers/Http/HttpAccessManager.cs index 71896e69d..0b414c2a8 100644 --- a/VpnHood.Server.Access/Managers/Http/HttpAccessManager.cs +++ b/VpnHood.Server.Access/Managers/Http/HttpAccessManager.cs @@ -47,7 +47,7 @@ protected override async Task> HttpSendAsync(string urlPart, Di { try { - return await base.HttpSendAsync(urlPart, parameters, request, cancellationToken); + return await base.HttpSendAsync(urlPart, parameters, request, cancellationToken).VhConfigureAwait(); } catch (Exception ex) when (VhUtil.IsConnectionRefusedException(ex)) { diff --git a/VpnHood.Server.Access/VpnHood.Server.Access.csproj b/VpnHood.Server.Access/VpnHood.Server.Access.csproj index d5b42a3b1..4bb5a44d6 100644 --- a/VpnHood.Server.Access/VpnHood.Server.Access.csproj +++ b/VpnHood.Server.Access/VpnHood.Server.Access.csproj @@ -20,7 +20,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Server.App.Net/FileAccessManagerCommand.cs b/VpnHood.Server.App.Net/FileAccessManagerCommand.cs index f73e1b1ec..5210b0800 100644 --- a/VpnHood.Server.App.Net/FileAccessManagerCommand.cs +++ b/VpnHood.Server.App.Net/FileAccessManagerCommand.cs @@ -1,6 +1,7 @@ using System.Text.Json; using McMaster.Extensions.CommandLineUtils; using VpnHood.Common; +using VpnHood.Common.Utils; using VpnHood.Server.Access.Managers.File; namespace VpnHood.Server.App; @@ -20,14 +21,14 @@ private void PrintToken(CommandLineApplication cmdApp) var tokenIdArg = cmdApp.Argument("tokenId", "tokenId to print"); cmdApp.OnExecuteAsync(async _ => { - await PrintToken(tokenIdArg.Value!); + await PrintToken(tokenIdArg.Value!).VhConfigureAwait(); return 0; }); } private async Task PrintToken(string tokenId) { - var accessItem = await fileAccessManager.AccessItem_Read(tokenId); + var accessItem = await fileAccessManager.AccessItem_Read(tokenId).VhConfigureAwait(); if (accessItem == null) throw new KeyNotFoundException($"Token does not exist! tokenId: {tokenId}"); var hostName = accessItem.Token.ServerToken.HostName + (accessItem.Token.ServerToken.IsValidHostName ? "" : " (Fake)"); var endPoints = accessItem.Token.ServerToken.HostEndPoints?.Select(x => x.ToString()) ?? Array.Empty(); @@ -71,7 +72,7 @@ private void GenerateToken(CommandLineApplication cmdApp) ); Console.WriteLine("The following token has been generated: "); - await PrintToken(accessItem.Token.TokenId); + await PrintToken(accessItem.Token.TokenId).VhConfigureAwait(); Console.WriteLine($"Store Token Count: {accessManager.AccessItem_Count()}"); return 0; }); diff --git a/VpnHood.Server.App.Net/Program.cs b/VpnHood.Server.App.Net/Program.cs index 59be2bf0c..6c24cff3f 100644 --- a/VpnHood.Server.App.Net/Program.cs +++ b/VpnHood.Server.App.Net/Program.cs @@ -1,4 +1,6 @@ -namespace VpnHood.Server.App; +using VpnHood.Common.Utils; + +namespace VpnHood.Server.App; internal class Program { @@ -7,7 +9,7 @@ private static async Task Main(string[] args) try { using var serverApp = new ServerApp(); - await serverApp.Start(args); + await serverApp.Start(args).VhConfigureAwait(); } catch (Exception ex) { diff --git a/VpnHood.Server.App.Net/ServerApp.cs b/VpnHood.Server.App.Net/ServerApp.cs index fec5fdb58..d4db42fb7 100644 --- a/VpnHood.Server.App.Net/ServerApp.cs +++ b/VpnHood.Server.App.Net/ServerApp.cs @@ -246,9 +246,9 @@ private void StartServer(CommandLineApplication cmdApp) _commandListener.Start(); // start server - await _vpnHoodServer.Start(); + await _vpnHoodServer.Start().VhConfigureAwait(); while (_vpnHoodServer.State != ServerState.Disposed) - await Task.Delay(1000, cancellationToken); + await Task.Delay(1000, cancellationToken).VhConfigureAwait(); return 0; }); } @@ -289,6 +289,6 @@ public async Task Start(string[] args) new FileAccessManagerCommand(FileAccessManager) .AddCommands(cmdApp); - await cmdApp.ExecuteAsync(args); + await cmdApp.ExecuteAsync(args).VhConfigureAwait(); } } \ No newline at end of file diff --git a/VpnHood.Server.App.Net/VpnHood.Server.App.Net.csproj b/VpnHood.Server.App.Net/VpnHood.Server.App.Net.csproj index 3765a32e0..b75d1b30e 100644 --- a/VpnHood.Server.App.Net/VpnHood.Server.App.Net.csproj +++ b/VpnHood.Server.App.Net/VpnHood.Server.App.Net.csproj @@ -22,7 +22,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) VpnHoodServer diff --git a/VpnHood.Server/Http01ChallengeService.cs b/VpnHood.Server/Http01ChallengeService.cs index 0fa6e73f7..be074b212 100644 --- a/VpnHood.Server/Http01ChallengeService.cs +++ b/VpnHood.Server/Http01ChallengeService.cs @@ -2,6 +2,7 @@ using System.Net.Sockets; using Microsoft.Extensions.Logging; using VpnHood.Common.Logging; +using VpnHood.Common.Utils; using VpnHood.Tunneling; using VpnHood.Tunneling.Utils; @@ -44,10 +45,10 @@ private async Task AcceptTcpClient(TcpListener tcpListener, CancellationToken ca { while (IsStarted && !cancellationToken.IsCancellationRequested) { - using var client = await tcpListener.AcceptTcpClientAsync(); + using var client = await tcpListener.AcceptTcpClientAsync().VhConfigureAwait(); try { - await HandleRequest(client, token, keyAuthorization, cancellationToken); + await HandleRequest(client, token, keyAuthorization, cancellationToken).VhConfigureAwait(); } catch (Exception ex) { @@ -59,7 +60,7 @@ private async Task AcceptTcpClient(TcpListener tcpListener, CancellationToken ca private static async Task HandleRequest(TcpClient client, string token, string keyAuthorization, CancellationToken cancellationToken) { await using var stream = client.GetStream(); - var headers = await HttpUtil.ParseHeadersAsync(stream, cancellationToken) + var headers = await HttpUtil.ParseHeadersAsync(stream, cancellationToken).VhConfigureAwait() ?? throw new Exception("Connection has been closed before receiving any request."); if (!headers.Any()) return; @@ -74,8 +75,8 @@ private static async Task HandleRequest(TcpClient client, string token, string k ? HttpResponseBuilder.Http01(keyAuthorization) : HttpResponseBuilder.NotFound(); - await stream.WriteAsync(response, 0, response.Length, cancellationToken); - await stream.FlushAsync(cancellationToken); + await stream.WriteAsync(response, 0, response.Length, cancellationToken).VhConfigureAwait(); + await stream.FlushAsync(cancellationToken).VhConfigureAwait(); } // use dispose diff --git a/VpnHood.Server/INetFilter.cs b/VpnHood.Server/INetFilter.cs index 5f0ce5da8..3d9ed3c07 100644 --- a/VpnHood.Server/INetFilter.cs +++ b/VpnHood.Server/INetFilter.cs @@ -6,7 +6,7 @@ namespace VpnHood.Server; public interface INetFilter { - public IpRange[] BlockedIpRanges { get; set; } + public IpRangeOrderedList BlockedIpRanges { get; set; } public IPPacket? ProcessRequest(IPPacket ipPacket); public IPEndPoint? ProcessRequest(ProtocolType protocol, IPEndPoint requestEndPoint); public IPPacket ProcessReply(IPPacket ipPacket); diff --git a/VpnHood.Server/NetFilter.cs b/VpnHood.Server/NetFilter.cs index 3cb426465..801d11edc 100644 --- a/VpnHood.Server/NetFilter.cs +++ b/VpnHood.Server/NetFilter.cs @@ -6,23 +6,23 @@ namespace VpnHood.Server; public class NetFilter : INetFilter { - private readonly IpRange[] _loopbackIpRange = IpNetwork.ToIpRange(IpNetwork.LoopbackNetworksV4.Concat(IpNetwork.LoopbackNetworksV6)).ToArray(); - private IpRange[] _sortedBlockedIpRanges = []; + private readonly IpRangeOrderedList _loopbackIpRange = IpNetwork.LoopbackNetworks.ToIpRanges(); + private IpRangeOrderedList _blockedIpRanges = new([]); public NetFilter() { BlockedIpRanges = _loopbackIpRange; } - public IpRange[] BlockedIpRanges + public IpRangeOrderedList BlockedIpRanges { - get => _sortedBlockedIpRanges; - set => _sortedBlockedIpRanges = value.Concat(_loopbackIpRange).Sort().ToArray(); + get => _blockedIpRanges; + set => _blockedIpRanges = _loopbackIpRange.Union(value); } - public virtual bool IsIpAddressBlocked(IPAddress ipAddress) + private bool IsIpAddressBlocked(IPAddress ipAddress) { - return IpRange.IsInSortedRanges(BlockedIpRanges, ipAddress); + return BlockedIpRanges.IsInRange(ipAddress); } // ReSharper disable once ReturnTypeCanBeNotNullable diff --git a/VpnHood.Server/ServerHost.cs b/VpnHood.Server/ServerHost.cs index ad99a53ef..1be45621a 100644 --- a/VpnHood.Server/ServerHost.cs +++ b/VpnHood.Server/ServerHost.cs @@ -55,14 +55,14 @@ public async Task Configure(IPEndPoint[] tcpEndPoints, IPEndPoint[] udpEndPoints throw new ObjectDisposedException(GetType().Name); // wait for last configure to finish - using var lockResult = await _configureLock.LockAsync(_cancellationTokenSource.Token); + using var lockResult = await _configureLock.LockAsync(_cancellationTokenSource.Token).VhConfigureAwait(); // reconfigure DnsServers = dnsServers; Certificates = certificates.Select(x => new CertificateHostName(x)).ToArray(); // Configure - await Task.WhenAll(ConfigureUdpListeners(udpEndPoints), ConfigureTcpListeners(tcpEndPoints)); + await Task.WhenAll(ConfigureUdpListeners(udpEndPoints), ConfigureTcpListeners(tcpEndPoints)).VhConfigureAwait(); _tcpListenerTasks.RemoveAll(x => x.IsCompleted); } @@ -163,7 +163,7 @@ private async Task ListenTask(TcpListener tcpListener, CancellationToken cancell { try { - var tcpClient = await tcpListener.AcceptTcpClientAsync(); + var tcpClient = await tcpListener.AcceptTcpClientAsync().VhConfigureAwait(); if (_disposed) throw new ObjectDisposedException("ServerHost has been stopped."); @@ -221,7 +221,8 @@ await sslStream.AuthenticateAsServerAsync( CertificateRevocationCheckMode = X509RevocationMode.NoCheck, ServerCertificateSelectionCallback = ServerCertificateSelectionCallback }, - cancellationToken); + cancellationToken) + .VhConfigureAwait(); return sslStream; } catch (Exception ex) @@ -257,7 +258,7 @@ private async Task CreateClientStream(TcpClient tcpClient, Stream try { var headers = - await HttpUtil.ParseHeadersAsync(sslStream, cancellationToken) + await HttpUtil.ParseHeadersAsync(sslStream, cancellationToken).VhConfigureAwait() ?? throw new Exception("Connection has been closed before receiving any request."); // int.TryParse(headers.GetValueOrDefault("X-Version", "0"), out var xVersion); @@ -272,12 +273,12 @@ await HttpUtil.ParseHeadersAsync(sslStream, cancellationToken) if (authorization != "ApiKey") throw new UnauthorizedAccessException(); - await sslStream.WriteAsync(HttpResponseBuilder.Unauthorized(), cancellationToken); + await sslStream.WriteAsync(HttpResponseBuilder.Unauthorized(), cancellationToken).VhConfigureAwait(); return new TcpClientStream(tcpClient, sslStream, streamId); } // use binary stream only for authenticated clients - await sslStream.WriteAsync(HttpResponseBuilder.Ok(), cancellationToken); + await sslStream.WriteAsync(HttpResponseBuilder.Ok(), cancellationToken).VhConfigureAwait(); switch (binaryStreamType) { @@ -297,7 +298,7 @@ await HttpUtil.ParseHeadersAsync(sslStream, cancellationToken) //always return BadRequest if (!VhUtil.IsTcpClientHealthy(tcpClient)) throw; var response = ex is UnauthorizedAccessException ? HttpResponseBuilder.Unauthorized() : HttpResponseBuilder.BadRequest(); - await sslStream.WriteAsync(response, cancellationToken); + await sslStream.WriteAsync(response, cancellationToken).VhConfigureAwait(); throw; } } @@ -313,13 +314,13 @@ private async Task ProcessTcpClient(TcpClient tcpClient, CancellationToken cance try { // establish SSL - var sslStream = await AuthenticateAsServerAsync(tcpClient.GetStream(), cancellationToken); + var sslStream = await AuthenticateAsServerAsync(tcpClient.GetStream(), cancellationToken).VhConfigureAwait(); // create client stream - clientStream = await CreateClientStream(tcpClient, sslStream, cancellationToken); + clientStream = await CreateClientStream(tcpClient, sslStream, cancellationToken).VhConfigureAwait(); lock (_clientStreams) _clientStreams.Add(clientStream); - await ProcessClientStream(clientStream, cancellationToken); + await ProcessClientStream(clientStream, cancellationToken).VhConfigureAwait(); } catch (TlsAuthenticateException ex) when (ex.InnerException is OperationCanceledException) { @@ -337,7 +338,7 @@ private async Task ProcessTcpClient(TcpClient tcpClient, CancellationToken cance else VhLogger.LogError(GeneralEventId.Request, ex, "ServerHost could not process this request. ClientStreamId: {ClientStreamId}", clientStream?.ClientStreamId); - if (clientStream != null) await clientStream.DisposeAsync(false); + if (clientStream != null) await clientStream.DisposeAsync(false).VhConfigureAwait(); tcpClient.Dispose(); } } @@ -363,7 +364,7 @@ private async Task ReuseClientStream(IClientStream clientStream) "ServerHost.ReuseClientStream: A shared ClientStream is pending for reuse. ClientStreamId: {ClientStreamId}", clientStream.ClientStreamId); - await ProcessClientStream(clientStream, cancellationToken); + await ProcessClientStream(clientStream, cancellationToken).VhConfigureAwait(); VhLogger.Instance.LogTrace(GeneralEventId.TcpLife, "ServerHost.ReuseClientStream: A shared ClientStream has been reused. ClientStreamId: {ClientStreamId}", @@ -375,7 +376,7 @@ private async Task ReuseClientStream(IClientStream clientStream) "ServerHost.ReuseClientStream: Could not reuse a ClientStream. ClientStreamId: {ClientStreamId}", clientStream.ClientStreamId); - await clientStream.DisposeAsync(false); + await clientStream.DisposeAsync(false).VhConfigureAwait(); } } @@ -384,14 +385,13 @@ private async Task ProcessClientStream(IClientStream clientStream, CancellationT using var scope = VhLogger.Instance.BeginScope($"RemoteEp: {VhLogger.Format(clientStream.IpEndPointPair.RemoteEndPoint)}"); try { - await ProcessRequest(clientStream, cancellationToken); + await ProcessRequest(clientStream, cancellationToken).VhConfigureAwait(); } catch (SessionException ex) { // reply the error to caller if it is SessionException // Should not reply anything when user is unknown - await StreamUtil.WriteJsonAsync(clientStream.Stream, ex.SessionResponse, - cancellationToken); + await StreamUtil.WriteJsonAsync(clientStream.Stream, ex.SessionResponse, cancellationToken).VhConfigureAwait(); if (ex is ISelfLog loggable) loggable.Log(); @@ -399,7 +399,7 @@ await StreamUtil.WriteJsonAsync(clientStream.Stream, ex.SessionResponse, VhLogger.Instance.LogInformation(ex.SessionResponse.ErrorCode == SessionErrorCode.GeneralError ? GeneralEventId.Tcp : GeneralEventId.Session, ex, "Could not process the request. SessionErrorCode: {SessionErrorCode}", ex.SessionResponse.ErrorCode); - await clientStream.DisposeAsync(); + await clientStream.DisposeAsync().VhConfigureAwait(); } catch (Exception ex) when (VhLogger.IsSocketCloseException(ex)) { @@ -407,12 +407,12 @@ await StreamUtil.WriteJsonAsync(clientStream.Stream, ex.SessionResponse, "Connection has been closed. ClientStreamId: {ClientStreamId}.", clientStream.ClientStreamId); - await clientStream.DisposeAsync(); + await clientStream.DisposeAsync().VhConfigureAwait(); } catch (Exception ex) { // return 401 for ANY non SessionException to keep server's anonymity - await clientStream.Stream.WriteAsync(HttpResponseBuilder.Unauthorized(), cancellationToken); + await clientStream.Stream.WriteAsync(HttpResponseBuilder.Unauthorized(), cancellationToken).VhConfigureAwait(); if (ex is ISelfLog loggable) loggable.Log(); @@ -420,7 +420,7 @@ await StreamUtil.WriteJsonAsync(clientStream.Stream, ex.SessionResponse, VhLogger.Instance.LogInformation(GeneralEventId.Tcp, ex, "Could not process the request and return 401. ClientStreamId: {ClientStreamId}", clientStream.ClientStreamId); - await clientStream.DisposeAsync(false); + await clientStream.DisposeAsync(false).VhConfigureAwait(); } finally { @@ -434,7 +434,7 @@ private async Task ProcessRequest(IClientStream clientStream, CancellationToken var buffer = new byte[1]; // read request version - var rest = await clientStream.Stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken); + var rest = await clientStream.Stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).VhConfigureAwait(); if (rest == 0) throw new Exception("ClientStream has been closed before receiving any request."); @@ -443,7 +443,7 @@ private async Task ProcessRequest(IClientStream clientStream, CancellationToken throw new NotSupportedException($"The request version is not supported. Version: {version}"); // read request code - var res = await clientStream.Stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken); + var res = await clientStream.Stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).VhConfigureAwait(); if (res == 0) throw new Exception("ClientStream has been closed before receiving any request."); @@ -451,35 +451,35 @@ private async Task ProcessRequest(IClientStream clientStream, CancellationToken switch (requestCode) { case RequestCode.ServerStatus: - await ProcessServerStatus(clientStream, cancellationToken); + await ProcessServerStatus(clientStream, cancellationToken).VhConfigureAwait(); break; case RequestCode.SessionStatus: - await ProcessSessionStatus(clientStream, cancellationToken); + await ProcessSessionStatus(clientStream, cancellationToken).VhConfigureAwait(); break; case RequestCode.Hello: - await ProcessHello(clientStream, cancellationToken); + await ProcessHello(clientStream, cancellationToken).VhConfigureAwait(); break; case RequestCode.TcpDatagramChannel: - await ProcessTcpDatagramChannel(clientStream, cancellationToken); + await ProcessTcpDatagramChannel(clientStream, cancellationToken).VhConfigureAwait(); break; case RequestCode.StreamProxyChannel: - await ProcessStreamProxyChannel(clientStream, cancellationToken); + await ProcessStreamProxyChannel(clientStream, cancellationToken).VhConfigureAwait(); break; case RequestCode.UdpPacket: - await ProcessUdpPacketRequest(clientStream, cancellationToken); + await ProcessUdpPacketRequest(clientStream, cancellationToken).VhConfigureAwait(); break; case RequestCode.AdReward: - await ProcessAdRewardRequest(clientStream, cancellationToken); + await ProcessAdRewardRequest(clientStream, cancellationToken).VhConfigureAwait(); break; case RequestCode.Bye: - await ProcessBye(clientStream, cancellationToken); + await ProcessBye(clientStream, cancellationToken).VhConfigureAwait(); break; default: @@ -494,7 +494,7 @@ private static async Task ReadRequest(IClientStream clientStream, Cancella "Processing a request. RequestType: {RequestType}.", VhLogger.FormatType()); - var request = await StreamUtil.ReadJsonAsync(clientStream.Stream, cancellationToken); + var request = await StreamUtil.ReadJsonAsync(clientStream.Stream, cancellationToken).VhConfigureAwait(); request.RequestId = request.RequestId.Replace(":client", ":server"); clientStream.ClientStreamId = request.RequestId; @@ -513,12 +513,12 @@ private async Task ProcessHello(IClientStream clientStream, CancellationToken ca VhLogger.Instance.LogTrace(GeneralEventId.Session, "Processing hello request... ClientIp: {ClientIp}", VhLogger.Format(ipEndPointPair.RemoteEndPoint.Address)); - var request = await ReadRequest(clientStream, cancellationToken); + var request = await ReadRequest(clientStream, cancellationToken).VhConfigureAwait(); // creating a session VhLogger.Instance.LogTrace(GeneralEventId.Session, "Creating a session... TokenId: {TokenId}, ClientId: {ClientId}, ClientVersion: {ClientVersion}, UserAgent: {UserAgent}", VhLogger.FormatId(request.TokenId), VhLogger.FormatId(request.ClientInfo.ClientId), request.ClientInfo.ClientVersion, request.ClientInfo.UserAgent); - var sessionResponse = await _sessionManager.CreateSession(request, ipEndPointPair); + var sessionResponse = await _sessionManager.CreateSession(request, ipEndPointPair).VhConfigureAwait(); var session = _sessionManager.GetSessionById(sessionResponse.SessionId) ?? throw new InvalidOperationException("Session is lost!"); // check client version; unfortunately it must be after CreateSession to preserve server anonymity @@ -584,94 +584,95 @@ private async Task ProcessHello(IClientStream clientStream, CancellationToken ca AdRequirement = sessionResponse.AdRequirement, ServerLocation = sessionResponse.ServerLocation }; - await StreamUtil.WriteJsonAsync(clientStream.Stream, helloResponse, cancellationToken); - await clientStream.DisposeAsync(); + await StreamUtil.WriteJsonAsync(clientStream.Stream, helloResponse, cancellationToken).VhConfigureAwait(); + await clientStream.DisposeAsync().VhConfigureAwait(); } private async Task ProcessAdRewardRequest(IClientStream clientStream, CancellationToken cancellationToken) { VhLogger.Instance.LogTrace(GeneralEventId.Session, "Reading the RewardAd request..."); - var request = await ReadRequest(clientStream, cancellationToken); - var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair); - await session.ProcessAdRewardRequest(request, clientStream, cancellationToken); + var request = await ReadRequest(clientStream, cancellationToken).VhConfigureAwait(); + var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair).VhConfigureAwait(); + await session.ProcessAdRewardRequest(request, clientStream, cancellationToken).VhConfigureAwait(); } private static async Task ProcessServerStatus(IClientStream clientStream, CancellationToken cancellationToken) { VhLogger.Instance.LogTrace(GeneralEventId.Session, "Reading the ServerStatus request..."); - var request = await ReadRequest(clientStream, cancellationToken); + var request = await ReadRequest(clientStream, cancellationToken).VhConfigureAwait(); // Before calling CloseSession. Session must be validated by GetSession - await StreamUtil.WriteJsonAsync(clientStream.Stream, + await StreamUtil.WriteJsonAsync(clientStream.Stream, new ServerStatusResponse { ErrorCode = SessionErrorCode.Ok, Message = request.Message == "Hi, How are you?" ? "I am OK. How are you?" : "OK. Who are you?" - },cancellationToken); + }, cancellationToken) + .VhConfigureAwait(); - await clientStream.DisposeAsync(); + await clientStream.DisposeAsync().VhConfigureAwait(); } private async Task ProcessSessionStatus(IClientStream clientStream, CancellationToken cancellationToken) { VhLogger.Instance.LogTrace(GeneralEventId.Session, "Reading the SessionStatus request..."); - var request = await ReadRequest(clientStream, cancellationToken); + var request = await ReadRequest(clientStream, cancellationToken).VhConfigureAwait(); // finding session - var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair); + var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair).VhConfigureAwait(); // processing request - await session.ProcessSessionStatusRequest(request, clientStream, cancellationToken); + await session.ProcessSessionStatusRequest(request, clientStream, cancellationToken).VhConfigureAwait(); } private async Task ProcessBye(IClientStream clientStream, CancellationToken cancellationToken) { VhLogger.Instance.LogTrace(GeneralEventId.Session, "Reading the Bye request..."); - var request = await ReadRequest(clientStream, cancellationToken); + var request = await ReadRequest(clientStream, cancellationToken).VhConfigureAwait(); // finding session using var scope = VhLogger.Instance.BeginScope($"SessionId: {VhLogger.FormatSessionId(request.SessionId)}"); - var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair); + var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair).VhConfigureAwait(); // Before calling CloseSession. Session must be validated by GetSession - await StreamUtil.WriteJsonAsync(clientStream.Stream, new SessionResponse { ErrorCode = SessionErrorCode.Ok }, cancellationToken); - await clientStream.DisposeAsync(false); + await StreamUtil.WriteJsonAsync(clientStream.Stream, new SessionResponse { ErrorCode = SessionErrorCode.Ok }, cancellationToken).VhConfigureAwait(); + await clientStream.DisposeAsync(false).VhConfigureAwait(); // must be last - await _sessionManager.CloseSession(session.SessionId); + await _sessionManager.CloseSession(session.SessionId).VhConfigureAwait(); } private async Task ProcessUdpPacketRequest(IClientStream clientStream, CancellationToken cancellationToken) { VhLogger.Instance.LogTrace(GeneralEventId.Session, "Reading a UdpPacket request..."); - var request = await ReadRequest(clientStream, cancellationToken); + var request = await ReadRequest(clientStream, cancellationToken).VhConfigureAwait(); using var scope = VhLogger.Instance.BeginScope($"SessionId: {VhLogger.FormatSessionId(request.SessionId)}"); - var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair); - await session.ProcessUdpPacketRequest(request, clientStream, cancellationToken); + var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair).VhConfigureAwait(); + await session.ProcessUdpPacketRequest(request, clientStream, cancellationToken).VhConfigureAwait(); } private async Task ProcessTcpDatagramChannel(IClientStream clientStream, CancellationToken cancellationToken) { VhLogger.Instance.LogTrace(GeneralEventId.StreamProxyChannel, "Reading the TcpDatagramChannelRequest..."); - var request = await ReadRequest(clientStream, cancellationToken); + var request = await ReadRequest(clientStream, cancellationToken).VhConfigureAwait(); // finding session using var scope = VhLogger.Instance.BeginScope($"SessionId: {VhLogger.FormatSessionId(request.SessionId)}"); - var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair); - await session.ProcessTcpDatagramChannelRequest(request, clientStream, cancellationToken); + var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair).VhConfigureAwait(); + await session.ProcessTcpDatagramChannelRequest(request, clientStream, cancellationToken).VhConfigureAwait(); } private async Task ProcessStreamProxyChannel(IClientStream clientStream, CancellationToken cancellationToken) { VhLogger.Instance.LogInformation(GeneralEventId.StreamProxyChannel, "Reading the StreamProxyChannelRequest..."); - var request = await ReadRequest(clientStream, cancellationToken); + var request = await ReadRequest(clientStream, cancellationToken).VhConfigureAwait(); // find session using var scope = VhLogger.Instance.BeginScope($"SessionId: {VhLogger.FormatSessionId(request.SessionId)}"); - var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair); - await session.ProcessTcpProxyRequest(request, clientStream, cancellationToken); + var session = await _sessionManager.GetSession(request, clientStream.IpEndPointPair).VhConfigureAwait(); + await session.ProcessTcpProxyRequest(request, clientStream, cancellationToken).VhConfigureAwait(); } public Task RunJob() @@ -704,13 +705,13 @@ private async Task Stop() Task[] disposeTasks; lock (_clientStreams) disposeTasks = _clientStreams.Select(x => x.DisposeAsync(false).AsTask()).ToArray(); - await Task.WhenAll(disposeTasks); + await Task.WhenAll(disposeTasks).VhConfigureAwait(); // wait for finalizing all listener tasks VhLogger.Instance.LogTrace("Disposing current processing requests..."); try { - await VhUtil.RunTask(Task.WhenAll(_tcpListenerTasks), TimeSpan.FromSeconds(15)); + await VhUtil.RunTask(Task.WhenAll(_tcpListenerTasks), TimeSpan.FromSeconds(15)).VhConfigureAwait(); } catch (Exception ex) { @@ -723,10 +724,10 @@ private async Task Stop() private readonly AsyncLock _disposeLock = new(); public async ValueTask DisposeAsync() { - using var lockResult = await _disposeLock.LockAsync(); + using var lockResult = await _disposeLock.LockAsync().VhConfigureAwait(); if (_disposed) return; _disposed = true; - await Stop(); + await Stop().VhConfigureAwait(); } // cache HostName for performance diff --git a/VpnHood.Server/Session.cs b/VpnHood.Server/Session.cs index aed73f202..99a9df2ae 100644 --- a/VpnHood.Server/Session.cs +++ b/VpnHood.Server/Session.cs @@ -166,7 +166,7 @@ public Task Sync() private async Task Sync(bool force, bool closeSession, string? adData = null) { - using var syncLock = await _syncLock.LockAsync(); + using var syncLock = await _syncLock.LockAsync().VhConfigureAwait(); if (SessionResponse.ErrorCode != SessionErrorCode.Ok) return; @@ -191,18 +191,18 @@ private async Task Sync(bool force, bool closeSession, string? adData = null) try { SessionResponse = closeSession - ? await _accessManager.Session_Close(SessionId, traffic) - : await _accessManager.Session_AddUsage(SessionId, traffic, adData); + ? await _accessManager.Session_Close(SessionId, traffic).VhConfigureAwait() + : await _accessManager.Session_AddUsage(SessionId, traffic, adData).VhConfigureAwait(); // dispose for any error if (SessionResponse.ErrorCode != SessionErrorCode.Ok) - await DisposeAsync(false, false); + await DisposeAsync(false, false).VhConfigureAwait(); } catch (ApiException ex) when (ex.StatusCode == (int)HttpStatusCode.NotFound) { SessionResponse.ErrorCode = SessionErrorCode.AccessError; SessionResponse.ErrorMessage = "Session Not Found."; - await DisposeAsync(false, false); + await DisposeAsync(false, false).VhConfigureAwait(); } catch (Exception ex) { @@ -253,7 +253,7 @@ public void LogTrack(string protocol, IPEndPoint? localEndPoint, IPEndPoint? des public async Task ProcessTcpDatagramChannelRequest(TcpDatagramChannelRequest request, IClientStream clientStream, CancellationToken cancellationToken) { // send OK reply - await StreamUtil.WriteJsonAsync(clientStream.Stream, SessionResponse, cancellationToken); + await StreamUtil.WriteJsonAsync(clientStream.Stream, SessionResponse, cancellationToken).VhConfigureAwait(); // Disable UdpChannel UseUdpChannel = false; @@ -269,7 +269,7 @@ public async Task ProcessTcpDatagramChannelRequest(TcpDatagramChannelRequest req } catch { - await channel.DisposeAsync(); + await channel.DisposeAsync().VhConfigureAwait(); throw; } } @@ -284,15 +284,15 @@ public Task ProcessUdpPacketRequest(UdpPacketRequest request, IClientStream clie public async Task ProcessSessionStatusRequest(SessionStatusRequest request, IClientStream clientStream, CancellationToken cancellationToken) { - await StreamUtil.WriteJsonAsync(clientStream.Stream, SessionResponse, cancellationToken); - await clientStream.DisposeAsync(); + await StreamUtil.WriteJsonAsync(clientStream.Stream, SessionResponse, cancellationToken).VhConfigureAwait(); + await clientStream.DisposeAsync().VhConfigureAwait(); } public async Task ProcessAdRewardRequest(AdRewardRequest request, IClientStream clientStream, CancellationToken cancellationToken) { - await Sync(force: true, closeSession: false, adData: request.AdData); - await StreamUtil.WriteJsonAsync(clientStream.Stream, SessionResponse, cancellationToken); - await clientStream.DisposeAsync(); + await Sync(force: true, closeSession: false, adData: request.AdData).VhConfigureAwait(); + await StreamUtil.WriteJsonAsync(clientStream.Stream, SessionResponse, cancellationToken).VhConfigureAwait(); + await clientStream.DisposeAsync().VhConfigureAwait(); } public async Task ProcessTcpProxyRequest(StreamProxyChannelRequest request, IClientStream clientStream, CancellationToken cancellationToken) @@ -325,7 +325,8 @@ public async Task ProcessTcpProxyRequest(StreamProxyChannelRequest request, ICli isRequestedEpException = true; await VhUtil.RunTask( tcpClientHost.ConnectAsync(request.DestinationEndPoint.Address, request.DestinationEndPoint.Port), - _tcpConnectTimeout, cancellationToken); + _tcpConnectTimeout, cancellationToken) + .VhConfigureAwait(); isRequestedEpException = false; //tracking @@ -333,7 +334,7 @@ await VhUtil.RunTask( true, true, null); // send response - await StreamUtil.WriteJsonAsync(clientStream.Stream, SessionResponse, cancellationToken); + await StreamUtil.WriteJsonAsync(clientStream.Stream, SessionResponse, cancellationToken).VhConfigureAwait(); // add the connection VhLogger.Instance.LogTrace(GeneralEventId.StreamProxyChannel, @@ -347,8 +348,8 @@ await VhUtil.RunTask( catch (Exception ex) { tcpClientHost?.Dispose(); - if (tcpClientStreamHost != null) await tcpClientStreamHost.DisposeAsync(); - if (streamProxyChannel != null) await streamProxyChannel.DisposeAsync(); + if (tcpClientStreamHost != null) await tcpClientStreamHost.DisposeAsync().VhConfigureAwait(); + if (streamProxyChannel != null) await streamProxyChannel.DisposeAsync().VhConfigureAwait(); if (isRequestedEpException) throw new ServerSessionException(clientStream.IpEndPointPair.RemoteEndPoint, @@ -424,7 +425,7 @@ private async ValueTask DisposeAsync(bool sync, bool byUser) // Sync must before dispose, Some dispose may take time if (sync) - await Sync(true, byUser); + await Sync(true, byUser).VhConfigureAwait(); Tunnel.PacketReceived -= Tunnel_OnPacketReceived; _ = Tunnel.DisposeAsync(); @@ -455,7 +456,7 @@ public override Task OnPacketReceived(IPPacket ipPacket) PacketUtil.LogPacket(ipPacket, "Delegating packet to client via proxy."); ipPacket = session._netFilter.ProcessReply(ipPacket); - return session.Tunnel.SendPacket(ipPacket); + return session.Tunnel.SendPacketAsync(ipPacket, CancellationToken.None); } public override void OnNewEndPoint(ProtocolType protocolType, IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, diff --git a/VpnHood.Server/SessionManager.cs b/VpnHood.Server/SessionManager.cs index 295a41307..7f18c0f9c 100644 --- a/VpnHood.Server/SessionManager.cs +++ b/VpnHood.Server/SessionManager.cs @@ -72,7 +72,7 @@ public async Task SyncSessions() { try { - await syncTask.Task; + await syncTask.Task.VhConfigureAwait(); } catch (Exception ex) { @@ -100,7 +100,7 @@ private async Task CreateSessionInternal( session.SessionResponse.ErrorMessage = "Could not add session to collection."; session.SessionResponse.ErrorCode = SessionErrorCode.SessionError; - await session.DisposeAsync(); + await session.DisposeAsync().VhConfigureAwait(); throw new ServerSessionException(ipEndPointPair.RemoteEndPoint, session, session.SessionResponse, requestId); @@ -123,7 +123,7 @@ public async Task CreateSession(HelloRequest helloRequest, IP TokenId = helloRequest.TokenId, ServerLocation = helloRequest.ServerLocation, AllowRedirect = helloRequest.AllowRedirect - }); + }).VhConfigureAwait(); // Access Error should not pass to the client in create session if (sessionResponseEx.ErrorCode is SessionErrorCode.AccessError) @@ -134,7 +134,7 @@ public async Task CreateSession(HelloRequest helloRequest, IP throw new ServerSessionException(ipEndPointPair.RemoteEndPoint, sessionResponseEx, helloRequest); // create the session and add it to list - var session = await CreateSessionInternal(sessionResponseEx, ipEndPointPair, helloRequest.RequestId); + var session = await CreateSessionInternal(sessionResponseEx, ipEndPointPair, helloRequest.RequestId).VhConfigureAwait(); // Anonymous Report to GA _ = GaTrackNewSession(helloRequest.ClientInfo); @@ -166,7 +166,7 @@ private Task GaTrackNewSession(ClientInfo clientInfo) private async Task RecoverSession(RequestBase sessionRequest, IPEndPointPair ipEndPointPair) { - using var recoverLock = await AsyncLock.LockAsync($"Recover_session_{sessionRequest.SessionId}"); + using var recoverLock = await AsyncLock.LockAsync($"Recover_session_{sessionRequest.SessionId}").VhConfigureAwait(); var session = GetSessionById(sessionRequest.SessionId); if (session != null) return session; @@ -179,7 +179,8 @@ private async Task RecoverSession(RequestBase sessionRequest, IPEndPoin try { var sessionResponse = await _accessManager.Session_Get(sessionRequest.SessionId, - ipEndPointPair.LocalEndPoint, ipEndPointPair.RemoteEndPoint.Address); + ipEndPointPair.LocalEndPoint, ipEndPointPair.RemoteEndPoint.Address) + .VhConfigureAwait(); // Check session key for recovery if (!sessionRequest.SessionKey.SequenceEqual(sessionResponse.SessionKey)) @@ -191,7 +192,7 @@ private async Task RecoverSession(RequestBase sessionRequest, IPEndPoin throw new ServerSessionException(ipEndPointPair.RemoteEndPoint, sessionResponse, sessionRequest); // create the session even if it contains error to prevent many calls - session = await CreateSessionInternal(sessionResponse, ipEndPointPair, "recovery"); + session = await CreateSessionInternal(sessionResponse, ipEndPointPair, "recovery").VhConfigureAwait(); VhLogger.Instance.LogInformation(GeneralEventId.Session, "Session has been recovered. SessionId: {SessionId}", VhLogger.FormatSessionId(sessionRequest.SessionId)); @@ -212,8 +213,8 @@ private async Task RecoverSession(RequestBase sessionRequest, IPEndPoin SessionKey = sessionRequest.SessionKey, CreatedTime = DateTime.UtcNow, ErrorMessage = ex.Message - }, ipEndPointPair, "dead-recovery"); - await session.DisposeAsync(); + }, ipEndPointPair, "dead-recovery").VhConfigureAwait(); + await session.DisposeAsync().VhConfigureAwait(); throw; } } @@ -230,7 +231,7 @@ internal async Task GetSession(RequestBase requestBase, IPEndPointPair // try to restore session if not found else { - session = await RecoverSession(requestBase, ipEndPointPair); + session = await RecoverSession(requestBase, ipEndPointPair).VhConfigureAwait(); } if (session.SessionResponse.ErrorCode != SessionErrorCode.Ok) @@ -256,10 +257,10 @@ internal async Task GetSession(RequestBase requestBase, IPEndPointPair public async Task RunJob() { // anonymous heart_beat reporter - await _heartbeatSection.Enter(SendHeartbeat); + await _heartbeatSection.Enter(SendHeartbeat).VhConfigureAwait(); // clean disposed sessions - await Cleanup(); + await Cleanup().VhConfigureAwait(); } private Task SendHeartbeat() @@ -284,7 +285,7 @@ private async Task CloseExpiredSessions() .Where(x => !x.IsDisposed && x.SessionResponse.AccessUsage?.ExpirationTime < utcNow); foreach (var session in timeoutSessions) - await session.Sync(); + await session.Sync().VhConfigureAwait(); } private async Task RemoveTimeoutSession() @@ -298,15 +299,15 @@ private async Task RemoveTimeoutSession() foreach (var session in timeoutSessions) { Sessions.Remove(session.Key, out _); - await session.Value.DisposeAsync(); + await session.Value.DisposeAsync().VhConfigureAwait(); } } private async Task Cleanup() { - await CloseExpiredSessions(); - await RemoveTimeoutSession(); + await CloseExpiredSessions().VhConfigureAwait(); + await RemoveTimeoutSession().VhConfigureAwait(); } public Session? GetSessionById(ulong sessionId) @@ -323,18 +324,18 @@ public async Task CloseSession(ulong sessionId) { // find in session if (Sessions.TryGetValue(sessionId, out var session)) - await session.Close(); + await session.Close().VhConfigureAwait(); } private bool _disposed; private readonly AsyncLock _disposeLock = new(); public async ValueTask DisposeAsync() { - using var lockResult = await _disposeLock.LockAsync(); + using var lockResult = await _disposeLock.LockAsync().VhConfigureAwait(); if (_disposed) return; _disposed = true; - await Task.WhenAll(Sessions.Values.Select(x => x.DisposeAsync().AsTask())); + await Task.WhenAll(Sessions.Values.Select(x => x.DisposeAsync().AsTask())).VhConfigureAwait(); } } diff --git a/VpnHood.Server/VpnHood.Server.csproj b/VpnHood.Server/VpnHood.Server.csproj index 5e2fe036e..f1be2621a 100644 --- a/VpnHood.Server/VpnHood.Server.csproj +++ b/VpnHood.Server/VpnHood.Server.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm")) diff --git a/VpnHood.Server/VpnHoodServer.cs b/VpnHood.Server/VpnHoodServer.cs index ac460ed5e..db21f6770 100644 --- a/VpnHood.Server/VpnHoodServer.cs +++ b/VpnHood.Server/VpnHoodServer.cs @@ -74,14 +74,14 @@ public async Task RunJob() if (State == ServerState.Waiting && _configureTask.IsCompleted) { _configureTask = Configure(); // configure does not throw any error - await _configureTask; + await _configureTask.VhConfigureAwait(); return; } if (State == ServerState.Ready && _sendStatusTask.IsCompleted) { _sendStatusTask = SendStatusToAccessManager(true); - await _sendStatusTask; + await _sendStatusTask.VhConfigureAwait(); } } @@ -110,7 +110,7 @@ public async Task Start() // Configure State = ServerState.Waiting; - await RunJob(); + await RunJob().VhConfigureAwait(); } private async Task Configure() @@ -129,8 +129,8 @@ private async Task Configure() { EnvironmentVersion = Environment.Version, Version = ServerVersion, - PrivateIpAddresses = await IPAddressUtil.GetPrivateIpAddresses(), - PublicIpAddresses = _publicIpDiscovery ? await IPAddressUtil.GetPublicIpAddresses() : [], + PrivateIpAddresses = await IPAddressUtil.GetPrivateIpAddresses().VhConfigureAwait(), + PublicIpAddresses = _publicIpDiscovery ? await IPAddressUtil.GetPublicIpAddresses().VhConfigureAwait() : [], Status = GetStatus(), MachineName = Environment.MachineName, OsInfo = providerSystemInfo.OsInfo, @@ -143,13 +143,13 @@ private async Task Configure() var publicIpV4 = serverInfo.PublicIpAddresses.SingleOrDefault(x => x.AddressFamily == AddressFamily.InterNetwork); var publicIpV6 = serverInfo.PublicIpAddresses.SingleOrDefault(x => x.AddressFamily == AddressFamily.InterNetworkV6); - var isIpV6Supported = publicIpV6 != null || await IPAddressUtil.IsIpv6Supported(); + var isIpV6Supported = publicIpV6 != null || await IPAddressUtil.IsIpv6Supported().VhConfigureAwait(); VhLogger.Instance.LogInformation("Public IPv4: {IPv4}, Public IPv6: {IpV6}, IsV6Supported: {IsV6Supported}", VhLogger.Format(publicIpV4), VhLogger.Format(publicIpV6), isIpV6Supported); // get configuration from access server VhLogger.Instance.LogTrace("Sending config request to the Access Server..."); - var serverConfig = await ReadConfig(serverInfo); + var serverConfig = await ReadConfig(serverInfo).VhConfigureAwait(); VhLogger.IsAnonymousMode = serverConfig.LogAnonymizerValue; SessionManager.TrackingOptions = serverConfig.TrackingOptions; SessionManager.SessionOptions = serverConfig.SessionOptions; @@ -165,8 +165,10 @@ private async Task Configure() privateAddresses: allServerIps, isIpV6Supported, dnsServers: serverConfig.DnsServersValue); // Reconfigure server host - await _serverHost.Configure(serverConfig.TcpEndPointsValue, serverConfig.UdpEndPointsValue, - serverConfig.DnsServersValue, serverConfig.Certificates.Select(x => new X509Certificate2(x.RawData)).ToArray()); + await _serverHost.Configure( + serverConfig.TcpEndPointsValue, serverConfig.UdpEndPointsValue, + serverConfig.DnsServersValue, serverConfig.Certificates.Select(x => new X509Certificate2(x.RawData)) + .ToArray()).VhConfigureAwait(); // Reconfigure dns challenge StartDnsChallenge(serverConfig.TcpEndPointsValue.Select(x => x.Address), serverConfig.DnsChallenge); @@ -179,7 +181,7 @@ await _serverHost.Configure(serverConfig.TcpEndPointsValue, serverConfig.UdpEndP VhLogger.Instance.LogInformation("Server is ready!"); // set status after successful configuration - await SendStatusToAccessManager(false); + await SendStatusToAccessManager(false).VhConfigureAwait(); } catch (Exception ex) { @@ -190,7 +192,7 @@ await _serverHost.Configure(serverConfig.TcpEndPointsValue, serverConfig.UdpEndP _ = SessionManager.GaTracker?.TrackErrorByTag("configure", ex.Message); VhLogger.Instance.LogError(ex, "Could not configure server! Retrying after {TotalSeconds} seconds.", JobSection.Interval.TotalSeconds); - await SendStatusToAccessManager(false); + await SendStatusToAccessManager(false).VhConfigureAwait(); } } @@ -215,17 +217,17 @@ private void StartDnsChallenge(IEnumerable ipAddresses, DnsChallenge? private static void ConfigNetFilter(INetFilter netFilter, ServerHost serverHost, NetFilterOptions netFilterOptions, IEnumerable privateAddresses, bool isIpV6Supported, IEnumerable dnsServers) { - var dnsServerIpRanges = dnsServers.Select(x => new IpRange(x)).ToArray(); + var dnsServerIpRanges = dnsServers.Select(x => new IpRange(x)).ToOrderedList(); // assign to workers serverHost.NetFilterIncludeIpRanges = netFilterOptions.GetFinalIncludeIpRanges().Union(dnsServerIpRanges).ToArray(); serverHost.NetFilterPacketCaptureIncludeIpRanges = netFilterOptions.GetFinalPacketCaptureIncludeIpRanges().Union(dnsServerIpRanges).ToArray(); serverHost.IsIpV6Supported = isIpV6Supported && !netFilterOptions.BlockIpV6Value; - netFilter.BlockedIpRanges = netFilterOptions.GetBlockedIpRanges().Exclude(dnsServerIpRanges).ToArray(); + netFilter.BlockedIpRanges = netFilterOptions.GetBlockedIpRanges().Exclude(dnsServerIpRanges); // exclude listening ip if (!netFilterOptions.IncludeLocalNetworkValue) - netFilter.BlockedIpRanges = netFilter.BlockedIpRanges.Union(privateAddresses.Select(x => new IpRange(x))).ToArray(); + netFilter.BlockedIpRanges = netFilter.BlockedIpRanges.Union(privateAddresses.Select(x => new IpRange(x))); } private static int GetBestTcpBufferSize(long? totalMemory, int? configValue) @@ -244,7 +246,7 @@ private static int GetBestTcpBufferSize(long? totalMemory, int? configValue) private async Task ReadConfig(ServerInfo serverInfo) { - var serverConfig = await ReadConfigImpl(serverInfo); + var serverConfig = await ReadConfigImpl(serverInfo).VhConfigureAwait(); serverConfig.SessionOptions.TcpBufferSize = GetBestTcpBufferSize(serverInfo.TotalMemory, serverConfig.SessionOptions.TcpBufferSize); serverConfig.ApplyDefaults(); VhLogger.Instance.LogInformation("RemoteConfig: {RemoteConfig}", GetServerConfigReport(serverConfig)); @@ -278,8 +280,8 @@ private async Task ReadConfigImpl(ServerInfo serverInfo) { try { - var serverConfig = await AccessManager.Server_Configure(serverInfo); - try { await File.WriteAllTextAsync(_lastConfigFilePath, JsonSerializer.Serialize(serverConfig)); } + var serverConfig = await AccessManager.Server_Configure(serverInfo).VhConfigureAwait(); + try { await File.WriteAllTextAsync(_lastConfigFilePath, JsonSerializer.Serialize(serverConfig)).VhConfigureAwait(); } catch { /* Ignore */ } return serverConfig; } @@ -290,7 +292,8 @@ private async Task ReadConfigImpl(ServerInfo serverInfo) { if (File.Exists(_lastConfigFilePath)) { - var ret = VhUtil.JsonDeserialize(await File.ReadAllTextAsync(_lastConfigFilePath)); + var configJson = await File.ReadAllTextAsync(_lastConfigFilePath).VhConfigureAwait(); + var ret = VhUtil.JsonDeserialize(configJson); VhLogger.Instance.LogWarning("Last configuration has been loaded to report Maintenance mode."); return ret; } @@ -333,13 +336,13 @@ private async Task SendStatusToAccessManager(bool allowConfigure) { var status = GetStatus(); VhLogger.Instance.LogTrace("Sending status to Access... ConfigCode: {ConfigCode}", status.ConfigCode); - var res = await AccessManager.Server_UpdateStatus(status); + var res = await AccessManager.Server_UpdateStatus(status).VhConfigureAwait(); // reconfigure if (allowConfigure && (res.ConfigCode != _lastConfigCode)) { VhLogger.Instance.LogInformation("Reconfiguration was requested."); - await Configure(); + await Configure().VhConfigureAwait(); } } catch (Exception ex) @@ -372,7 +375,8 @@ private Task GaTrackStart() public void Dispose() { - Task.Run(async () => await DisposeAsync(), CancellationToken.None) + DisposeAsync() + .VhConfigureAwait() .GetAwaiter() .GetResult(); } @@ -380,7 +384,7 @@ public void Dispose() private readonly AsyncLock _disposeLock = new(); public async ValueTask DisposeAsync() { - using var lockResult = await _disposeLock.LockAsync(); + using var lockResult = await _disposeLock.LockAsync().VhConfigureAwait(); if (_disposed) return; _disposed = true; @@ -388,10 +392,10 @@ public async ValueTask DisposeAsync() VhLogger.Instance.LogInformation("Shutting down..."); // wait for configuration - try { await _configureTask; } catch {/* no error */ } - try { await _sendStatusTask; } catch {/* no error*/ } - await _serverHost.DisposeAsync(); // before disposing session manager to prevent recovering sessions - await SessionManager.DisposeAsync(); + try { await _configureTask.VhConfigureAwait(); } catch {/* no error */ } + try { await _sendStatusTask.VhConfigureAwait(); } catch {/* no error*/ } + await _serverHost.DisposeAsync().VhConfigureAwait(); // before disposing session manager to prevent recovering sessions + await SessionManager.DisposeAsync().VhConfigureAwait(); _http01ChallengeService?.Dispose(); if (_autoDisposeAccessManager) diff --git a/VpnHood.Tunneling/Channels/StreamDatagramChannel.cs b/VpnHood.Tunneling/Channels/StreamDatagramChannel.cs index 00fab48b7..515880d8c 100644 --- a/VpnHood.Tunneling/Channels/StreamDatagramChannel.cs +++ b/VpnHood.Tunneling/Channels/StreamDatagramChannel.cs @@ -60,8 +60,8 @@ public async Task StartInternal() Connected = true; try { - await ReadTask(_cancellationTokenSource.Token); - await SendClose(); + await ReadTask(_cancellationTokenSource.Token).VhConfigureAwait(); + await SendClose().VhConfigureAwait(); } finally { @@ -82,7 +82,7 @@ public async Task SendPacket(IPPacket[] ipPackets, bool disconnect) try { - await _sendSemaphore.WaitAsync(_cancellationTokenSource.Token); + await _sendSemaphore.WaitAsync(_cancellationTokenSource.Token).VhConfigureAwait(); // check channel connectivity _cancellationTokenSource.Token.ThrowIfCancellationRequested(); @@ -105,7 +105,7 @@ public async Task SendPacket(IPPacket[] ipPackets, bool disconnect) bufferIndex += ipPacket.TotalLength; } - await _clientStream.Stream.WriteAsync(buffer, 0, bufferIndex, _cancellationTokenSource.Token); + await _clientStream.Stream.WriteAsync(buffer, 0, bufferIndex, _cancellationTokenSource.Token).VhConfigureAwait(); LastActivityTime = FastDateTime.Now; Traffic.Sent += bufferIndex; } @@ -125,7 +125,7 @@ private async Task ReadTask(CancellationToken cancellationToken) await using var streamPacketReader = new StreamPacketReader(stream); while (!cancellationToken.IsCancellationRequested && !_isCloseReceived) { - var ipPackets = await streamPacketReader.ReadAsync(cancellationToken); + var ipPackets = await streamPacketReader.ReadAsync(cancellationToken).VhConfigureAwait(); if (ipPackets == null || _disposed) break; @@ -228,13 +228,13 @@ public ValueTask DisposeAsync() private readonly AsyncLock _disposeLock = new(); public async ValueTask DisposeAsync(bool graceful) { - using var lockResult = await _disposeLock.LockAsync(); + using var lockResult = await _disposeLock.LockAsync().VhConfigureAwait(); if (_disposed) return; if (graceful) - await SendClose(); // this won't throw any error + await SendClose().VhConfigureAwait(); // this won't throw any error - await _clientStream.DisposeAsync(graceful); + await _clientStream.DisposeAsync(graceful).VhConfigureAwait(); _disposed = true; } } \ No newline at end of file diff --git a/VpnHood.Tunneling/Channels/StreamProxyChannel.cs b/VpnHood.Tunneling/Channels/StreamProxyChannel.cs index 75ab174f8..3c2297bc9 100644 --- a/VpnHood.Tunneling/Channels/StreamProxyChannel.cs +++ b/VpnHood.Tunneling/Channels/StreamProxyChannel.cs @@ -73,7 +73,7 @@ private async Task StartInternal() _hostTcpClientStream.Stream, _tunnelTcpClientStream.Stream, true, _orgStreamBufferSize, CancellationToken.None, CancellationToken.None); // host => tunnel - await Task.WhenAny(tunnelCopyTask, hostCopyTask); + await Task.WhenAny(tunnelCopyTask, hostCopyTask).VhConfigureAwait(); } finally { @@ -110,7 +110,7 @@ private async Task CopyToAsync(Stream source, Stream destination, bool isDestina try { await CopyToInternalAsync(source, destination, isDestinationTunnel, bufferSize, - sourceCancellationToken, destinationCancellationToken); + sourceCancellationToken, destinationCancellationToken).VhConfigureAwait(); _isFinished = true; } catch (Exception ex) @@ -151,14 +151,14 @@ private async Task CopyToInternalAsync(Stream source, Stream destination, bool i while (!sourceCancellationToken.IsCancellationRequested && !destinationCancellationToken.IsCancellationRequested) { // read from source - var bytesRead = await source.ReadAsync(readBuffer, preserveCount, readBuffer.Length - preserveCount, sourceCancellationToken); + var bytesRead = await source.ReadAsync(readBuffer, preserveCount, readBuffer.Length - preserveCount, sourceCancellationToken).VhConfigureAwait(); // check end of the stream if (bytesRead == 0) break; // write to destination - await destination.WriteAsync(readBuffer, preserveCount, bytesRead, destinationCancellationToken); + await destination.WriteAsync(readBuffer, preserveCount, bytesRead, destinationCancellationToken).VhConfigureAwait(); // calculate transferred bytes if (isSendingToTunnel) @@ -180,13 +180,14 @@ public ValueTask DisposeAsync() private readonly AsyncLock _disposeLock = new(); public async ValueTask DisposeAsync(bool graceful) { - using var lockResult = await _disposeLock.LockAsync(); + using var lockResult = await _disposeLock.LockAsync().VhConfigureAwait(); if (_disposed) return; _disposed = true; Connected = false; await Task.WhenAll( _hostTcpClientStream.DisposeAsync(graceful).AsTask(), - _tunnelTcpClientStream.DisposeAsync(graceful).AsTask()); + _tunnelTcpClientStream.DisposeAsync(graceful).AsTask()) + .VhConfigureAwait(); } } \ No newline at end of file diff --git a/VpnHood.Tunneling/Channels/Streams/BinaryStreamStandard.cs b/VpnHood.Tunneling/Channels/Streams/BinaryStreamStandard.cs index e62ee287d..1ac144c71 100644 --- a/VpnHood.Tunneling/Channels/Streams/BinaryStreamStandard.cs +++ b/VpnHood.Tunneling/Channels/Streams/BinaryStreamStandard.cs @@ -40,7 +40,7 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, // Create CancellationToken using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(_readCts.Token, cancellationToken); _readTask = ReadInternalAsync(buffer, offset, count, tokenSource.Token); - return await _readTask; // await needed to dispose tokenSource + return await _readTask.VhConfigureAwait(); // await needed to dispose tokenSource } private async Task ReadInternalAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) @@ -54,7 +54,7 @@ private async Task ReadInternalAsync(byte[] buffer, int offset, int count, // If there are no more in the chunks read the next chunk if (_remainingChunkBytes == 0) { - _remainingChunkBytes = await ReadChunkHeaderAsync(cancellationToken); + _remainingChunkBytes = await ReadChunkHeaderAsync(cancellationToken).VhConfigureAwait(); _finished = _remainingChunkBytes == 0; // check last chunk @@ -63,7 +63,7 @@ private async Task ReadInternalAsync(byte[] buffer, int offset, int count, } var bytesToRead = Math.Min(_remainingChunkBytes, count); - var bytesRead = await SourceStream.ReadAsync(buffer, offset, bytesToRead, cancellationToken); + var bytesRead = await SourceStream.ReadAsync(buffer, offset, bytesToRead, cancellationToken).VhConfigureAwait(); if (bytesRead == 0 && count != 0) // count zero is used for checking the connection throw new Exception("BinaryStream has been closed unexpectedly."); @@ -91,7 +91,7 @@ private void CloseByError(Exception ex) private async Task ReadChunkHeaderAsync(CancellationToken cancellationToken) { // read chunk header by cryptor - if (!await StreamUtil.ReadWaitForFillAsync(SourceStream, _readChunkHeaderBuffer, 0, _readChunkHeaderBuffer.Length, cancellationToken)) + if (!await StreamUtil.ReadWaitForFillAsync(SourceStream, _readChunkHeaderBuffer, 0, _readChunkHeaderBuffer.Length, cancellationToken).VhConfigureAwait()) { if (!_finished && ReadChunkCount > 0) VhLogger.Instance.LogWarning(GeneralEventId.TcpLife, "BinaryStream has been closed unexpectedly."); @@ -123,7 +123,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc ? SourceStream.WriteAsync(buffer, offset, count, tokenSource.Token) : WriteInternalAsync(buffer, offset, count, tokenSource.Token); - await _writeTask; + await _writeTask.VhConfigureAwait(); } catch (Exception ex) { @@ -140,7 +140,7 @@ private async Task WriteInternalAsync(byte[] buffer, int offset, int count, Canc { // create the chunk header BitConverter.GetBytes(count).CopyTo(buffer, offset - ChunkHeaderLength); - await SourceStream.WriteAsync(buffer, offset - ChunkHeaderLength, ChunkHeaderLength + count, cancellationToken); + await SourceStream.WriteAsync(buffer, offset - ChunkHeaderLength, ChunkHeaderLength + count, cancellationToken).VhConfigureAwait(); } else { @@ -155,11 +155,11 @@ private async Task WriteInternalAsync(byte[] buffer, int offset, int count, Canc Buffer.BlockCopy(buffer, offset, _writeBuffer, ChunkHeaderLength, count); // Copy write buffer to output - await SourceStream.WriteAsync(_writeBuffer, 0, size, cancellationToken); + await SourceStream.WriteAsync(_writeBuffer, 0, size, cancellationToken).VhConfigureAwait(); } // make sure chunk is sent - await SourceStream.FlushAsync(cancellationToken); + await SourceStream.FlushAsync(cancellationToken).VhConfigureAwait(); WroteChunkCount++; } catch (Exception ex) @@ -184,14 +184,14 @@ public override async Task CreateReuse() // Dispose the stream but keep the original stream open _leaveOpen = true; - await DisposeAsync(); + await DisposeAsync().VhConfigureAwait(); // reuse if the stream has been closed gracefully if (_finished && !_hasError) return new BinaryStreamStandard(SourceStream, StreamId, ReusedCount + 1); // dispose and throw the ungraceful BinaryStream - await base.DisposeAsync(); + await base.DisposeAsync().VhConfigureAwait(); throw new InvalidOperationException($"Could not reuse a BinaryStream that has not been closed gracefully. StreamId: {StreamId}"); } @@ -204,7 +204,7 @@ private async Task CloseStream(CancellationToken cancellationToken) // wait for finishing current write try { - await _writeTask; + await _writeTask.VhConfigureAwait(); } catch (Exception ex) { @@ -219,9 +219,9 @@ private async Task CloseStream(CancellationToken cancellationToken) try { if (PreserveWriteBuffer) - await WriteInternalAsync(new byte[ChunkHeaderLength], ChunkHeaderLength, 0, cancellationToken); + await WriteInternalAsync(new byte[ChunkHeaderLength], ChunkHeaderLength, 0, cancellationToken).VhConfigureAwait(); else - await WriteInternalAsync([], 0, 0, cancellationToken); + await WriteInternalAsync([], 0, 0, cancellationToken).VhConfigureAwait(); } catch (Exception ex) @@ -235,7 +235,7 @@ private async Task CloseStream(CancellationToken cancellationToken) // make sure current caller read has been finished gracefully or wait for cancellation time try { - await _readTask; + await _readTask.VhConfigureAwait(); } catch (Exception ex) { @@ -258,7 +258,7 @@ private async Task CloseStream(CancellationToken cancellationToken) var trashedLength = 0; while (true) { - var read = await ReadInternalAsync(buffer, 0, buffer.Length, cancellationToken); + var read = await ReadInternalAsync(buffer, 0, buffer.Length, cancellationToken).VhConfigureAwait(); if (read == 0) break; @@ -283,7 +283,7 @@ private async Task CloseStream(CancellationToken cancellationToken) private readonly AsyncLock _disposeLock = new(); public override async ValueTask DisposeAsync() { - using var lockResult = await _disposeLock.LockAsync(); + using var lockResult = await _disposeLock.LockAsync().VhConfigureAwait(); if (_disposed) return; _disposed = true; @@ -292,10 +292,10 @@ public override async ValueTask DisposeAsync() { // create a new cancellation token for CloseStream using var cancellationTokenSource = new CancellationTokenSource(TunnelDefaults.TcpGracefulTimeout); - await CloseStream(cancellationTokenSource.Token); + await CloseStream(cancellationTokenSource.Token).VhConfigureAwait(); } if (!_leaveOpen) - await base.DisposeAsync(); + await base.DisposeAsync().VhConfigureAwait(); } } \ No newline at end of file diff --git a/VpnHood.Tunneling/Channels/Streams/HttpStream.cs b/VpnHood.Tunneling/Channels/Streams/HttpStream.cs index 44a2d9974..ce7c84d08 100644 --- a/VpnHood.Tunneling/Channels/Streams/HttpStream.cs +++ b/VpnHood.Tunneling/Channels/Streams/HttpStream.cs @@ -68,7 +68,7 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, // Create CancellationToken using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(_readCts.Token, cancellationToken); _readTask = ReadInternalAsync(buffer, offset, count, tokenSource.Token); - return await _readTask; // await needed to dispose tokenSource + return await _readTask.VhConfigureAwait(); // await needed to dispose tokenSource } private async Task ReadInternalAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) @@ -78,7 +78,7 @@ private async Task ReadInternalAsync(byte[] buffer, int offset, int count, // ignore header if (!_isHttpHeaderRead) { - using var headerBuffer = await HttpUtil.ReadHeadersAsync(SourceStream, cancellationToken); + using var headerBuffer = await HttpUtil.ReadHeadersAsync(SourceStream, cancellationToken).VhConfigureAwait(); if (headerBuffer.Length == 0) { _isFinished = true; @@ -95,7 +95,7 @@ private async Task ReadInternalAsync(byte[] buffer, int offset, int count, // If there are no more in the chunks read the next chunk if (_remainingChunkBytes == 0) - _remainingChunkBytes = await ReadChunkHeaderAsync(cancellationToken); + _remainingChunkBytes = await ReadChunkHeaderAsync(cancellationToken).VhConfigureAwait(); // check last chunk _isFinished = _remainingChunkBytes == 0; @@ -103,7 +103,7 @@ private async Task ReadInternalAsync(byte[] buffer, int offset, int count, return 0; var bytesToRead = Math.Min(_remainingChunkBytes, count); - var bytesRead = await SourceStream.ReadAsync(buffer, offset, bytesToRead, cancellationToken); + var bytesRead = await SourceStream.ReadAsync(buffer, offset, bytesToRead, cancellationToken).VhConfigureAwait(); if (bytesRead == 0 && count != 0) // count zero is used for checking the connection throw new Exception("HttpStream has been closed unexpectedly."); @@ -130,14 +130,14 @@ private async Task ReadChunkHeaderAsync(CancellationToken cancellationToken { // read the end of last chunk if it is not first chunk if (ReadChunkCount != 0) - await ReadNextLine(cancellationToken); + await ReadNextLine(cancellationToken).VhConfigureAwait(); // read chunk var bufferOffset = 0; while (!cancellationToken.IsCancellationRequested) { if (bufferOffset == _chunkHeaderBuffer.Length) throw new InvalidDataException("Chunk header exceeds the maximum size."); - var bytesRead = await SourceStream.ReadAsync(_chunkHeaderBuffer, bufferOffset, 1, cancellationToken); + var bytesRead = await SourceStream.ReadAsync(_chunkHeaderBuffer, bufferOffset, 1, cancellationToken).VhConfigureAwait(); if (bytesRead == 0) throw new InvalidDataException("Could not read HTTP Chunk header."); @@ -153,7 +153,7 @@ private async Task ReadChunkHeaderAsync(CancellationToken cancellationToken throw new InvalidDataException("Invalid HTTP chunk size."); if (chunkSize == 0) - await ReadNextLine(cancellationToken); //read the end of stream + await ReadNextLine(cancellationToken).VhConfigureAwait(); //read the end of stream else ReadChunkCount++; @@ -179,7 +179,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc ? SourceStream.WriteAsync(buffer, offset, count, tokenSource.Token) : WriteInternalAsync(buffer, offset, count, tokenSource.Token); - await _writeTask; + await _writeTask.VhConfigureAwait(); } catch (Exception ex) { @@ -196,20 +196,20 @@ private async Task WriteInternalAsync(byte[] buffer, int offset, int count, Canc // write header for first time if (!_isHttpHeaderSent) { - await SourceStream.WriteAsync(Encoding.UTF8.GetBytes(CreateHttpHeader()), cancellationToken); + await SourceStream.WriteAsync(Encoding.UTF8.GetBytes(CreateHttpHeader()), cancellationToken).VhConfigureAwait(); _isHttpHeaderSent = true; } // Write the chunk header var headerBytes = Encoding.ASCII.GetBytes(count.ToString("X") + "\r\n"); - await SourceStream.WriteAsync(headerBytes, cancellationToken); + await SourceStream.WriteAsync(headerBytes, cancellationToken).VhConfigureAwait(); // Write the chunk data - await SourceStream.WriteAsync(buffer, offset, count, cancellationToken); + await SourceStream.WriteAsync(buffer, offset, count, cancellationToken).VhConfigureAwait(); // Write the chunk footer - await SourceStream.WriteAsync(_newLineBytes, cancellationToken); - await FlushAsync(cancellationToken); + await SourceStream.WriteAsync(_newLineBytes, cancellationToken).VhConfigureAwait(); + await FlushAsync(cancellationToken).VhConfigureAwait(); WroteChunkCount++; } @@ -222,7 +222,7 @@ private async Task WriteInternalAsync(byte[] buffer, int offset, int count, Canc private async Task ReadNextLine(CancellationToken cancellationToken) { - var bytesRead = await SourceStream.ReadAsync(_nextLineBuffer, 0, 2, cancellationToken); + var bytesRead = await SourceStream.ReadAsync(_nextLineBuffer, 0, 2, cancellationToken).VhConfigureAwait(); if (bytesRead < 2 || _nextLineBuffer[0] != '\r' || _nextLineBuffer[1] != '\n') throw new InvalidDataException("Could not find expected line feed in HTTP chunk header."); } @@ -242,14 +242,14 @@ public override async Task CreateReuse() // Dispose the stream but keep the original stream open _keepOpen = true; - await DisposeAsync(); + await DisposeAsync().VhConfigureAwait(); // reuse if the stream has been closed gracefully if (_isFinished && !_hasError) return new HttpStream(SourceStream, StreamId, _host); // dispose and throw the ungraceful HttpStream - await base.DisposeAsync(); + await base.DisposeAsync().VhConfigureAwait(); throw new InvalidOperationException($"Could not reuse a HttpStream that has not been closed gracefully. StreamId: {StreamId}"); } @@ -259,14 +259,14 @@ private async Task CloseStream(CancellationToken cancellationToken) _readCts.CancelAfter(TunnelDefaults.TcpGracefulTimeout); _writeCts.CancelAfter(TunnelDefaults.TcpGracefulTimeout); - try { await _writeTask; } + try { await _writeTask.VhConfigureAwait(); } catch { /* Ignore */ } _writeCts.Dispose(); // finish writing current HttpStream gracefully try { - await WriteInternalAsync([], 0, 0, cancellationToken); + await WriteInternalAsync([], 0, 0, cancellationToken).VhConfigureAwait(); } catch (Exception ex) { @@ -276,7 +276,7 @@ private async Task CloseStream(CancellationToken cancellationToken) } // make sure current caller read has been finished gracefully or wait for cancellation time - try { await _readTask; } + try { await _readTask.VhConfigureAwait(); } catch { /* Ignore */ } try @@ -290,7 +290,7 @@ private async Task CloseStream(CancellationToken cancellationToken) if (!_isFinished) { var buffer = new byte[10]; - var read = await ReadInternalAsync(buffer, 0, buffer.Length, cancellationToken); + var read = await ReadInternalAsync(buffer, 0, buffer.Length, cancellationToken).VhConfigureAwait(); if (read != 0) throw new InvalidDataException("HttpStream read unexpected data on end."); } @@ -307,7 +307,7 @@ private async Task CloseStream(CancellationToken cancellationToken) private readonly AsyncLock _disposeLock = new(); public override async ValueTask DisposeAsync() { - using var lockResult = await _disposeLock.LockAsync(); + using var lockResult = await _disposeLock.LockAsync().VhConfigureAwait(); if (_disposed) return; _disposed = true; @@ -316,10 +316,10 @@ public override async ValueTask DisposeAsync() { // create a new cancellation token for CloseStream using var cancellationTokenSource = new CancellationTokenSource(TunnelDefaults.TcpGracefulTimeout); - await CloseStream(cancellationTokenSource.Token); + await CloseStream(cancellationTokenSource.Token).VhConfigureAwait(); } if (!_keepOpen) - await base.DisposeAsync(); + await base.DisposeAsync().VhConfigureAwait(); } } \ No newline at end of file diff --git a/VpnHood.Tunneling/Channels/UdpChannel.cs b/VpnHood.Tunneling/Channels/UdpChannel.cs index fbe16cb61..d8c131726 100644 --- a/VpnHood.Tunneling/Channels/UdpChannel.cs +++ b/VpnHood.Tunneling/Channels/UdpChannel.cs @@ -47,7 +47,7 @@ public async Task SendPacket(IPPacket[] ipPackets) { // this is shared buffer and client, so we need to sync // Using multiple UdpClient will not increase performance - await _semaphore.WaitAsync(); + await _semaphore.WaitAsync().VhConfigureAwait(); var bufferIndex = UdpChannelTransmitter.HeaderLength; @@ -66,7 +66,7 @@ public async Task SendPacket(IPPacket[] ipPackets) if (_lastRemoteEp == null) throw new InvalidOperationException("RemoveEndPoint has not been initialized yet in UdpChannel."); if (_udpChannelTransmitter == null) throw new InvalidOperationException("UdpChannelTransmitter has not been initialized yet in UdpChannel."); var ret = await _udpChannelTransmitter.SendAsync(_lastRemoteEp, sessionId, - sessionCryptoPosition, _buffer, bufferIndex, protocolVersion); + sessionCryptoPosition, _buffer, bufferIndex, protocolVersion).VhConfigureAwait(); Traffic.Sent += ret; LastActivityTime = FastDateTime.Now; @@ -74,7 +74,7 @@ public async Task SendPacket(IPPacket[] ipPackets) catch (Exception ex) { if (IsInvalidState(ex)) - await DisposeAsync(); + await DisposeAsync().VhConfigureAwait(); } finally { diff --git a/VpnHood.Tunneling/Channels/UdpChannelTransmitter.cs b/VpnHood.Tunneling/Channels/UdpChannelTransmitter.cs index a951b4459..ffdb8c7a3 100644 --- a/VpnHood.Tunneling/Channels/UdpChannelTransmitter.cs +++ b/VpnHood.Tunneling/Channels/UdpChannelTransmitter.cs @@ -3,6 +3,7 @@ using System.Security.Cryptography; using Microsoft.Extensions.Logging; using VpnHood.Common.Logging; +using VpnHood.Common.Utils; namespace VpnHood.Tunneling.Channels; @@ -34,7 +35,7 @@ public async Task SendAsync(IPEndPoint? ipEndPoint, ulong sessionId, long s { try { - await _semaphore.WaitAsync(); + await _semaphore.WaitAsync().VhConfigureAwait(); // add random packet iv _randomGenerator.GetBytes(_sendIv); @@ -55,8 +56,8 @@ public async Task SendAsync(IPEndPoint? ipEndPoint, ulong sessionId, long s buffer[_sendIv.Length + i] ^= _sendHeadKeyBuffer[i]; //simple XOR with generated unique key var ret = ipEndPoint != null - ? await _udpClient.SendAsync(buffer, bufferLength, ipEndPoint) - : await _udpClient.SendAsync(buffer, bufferLength); + ? await _udpClient.SendAsync(buffer, bufferLength, ipEndPoint).VhConfigureAwait() + : await _udpClient.SendAsync(buffer, bufferLength).VhConfigureAwait(); if (ret != bufferLength) throw new Exception($"UdpClient: Send {ret} bytes instead {buffer.Length} bytes."); @@ -90,7 +91,7 @@ private async Task ReadTask() try { remoteEndPoint = null; - var udpResult = await _udpClient.ReceiveAsync(); + var udpResult = await _udpClient.ReceiveAsync().VhConfigureAwait(); remoteEndPoint = udpResult.RemoteEndPoint; var buffer = udpResult.Buffer; if (buffer.Length < HeaderLength) diff --git a/VpnHood.Tunneling/ClientStreams/TcpClientStream.cs b/VpnHood.Tunneling/ClientStreams/TcpClientStream.cs index 0e833d554..3f49b2814 100644 --- a/VpnHood.Tunneling/ClientStreams/TcpClientStream.cs +++ b/VpnHood.Tunneling/ClientStreams/TcpClientStream.cs @@ -67,7 +67,7 @@ public ValueTask DisposeAsync() public bool Disposed { get; private set; } public async ValueTask DisposeAsync(bool graceful) { - using var lockResult = await _disposeLock.LockAsync(); + using var lockResult = await _disposeLock.LockAsync().VhConfigureAwait(); if (Disposed) return; Disposed = true; @@ -77,7 +77,7 @@ public async ValueTask DisposeAsync(bool graceful) Stream? newStream = null; try { - newStream = await chunkStream.CreateReuse(); + newStream = await chunkStream.CreateReuse().VhConfigureAwait(); _ = _reuseCallback.Invoke(new TcpClientStream(TcpClient, newStream, ClientStreamId, _reuseCallback, false)); VhLogger.Instance.LogTrace(GeneralEventId.TcpLife, @@ -88,15 +88,15 @@ public async ValueTask DisposeAsync(bool graceful) VhLogger.LogError(GeneralEventId.TcpLife, ex, "Could not reuse the TcpClientStream. ClientStreamId: {ClientStreamId}", ClientStreamId); - if (newStream != null) await newStream.DisposeAsync(); - await Stream.DisposeAsync(); + if (newStream != null) await newStream.DisposeAsync().VhConfigureAwait(); + await Stream.DisposeAsync().VhConfigureAwait(); TcpClient.Dispose(); } } else { // close streams - await Stream.DisposeAsync(); // first close stream 2 + await Stream.DisposeAsync().VhConfigureAwait(); // first close stream 2 TcpClient.Dispose(); VhLogger.Instance.LogTrace(GeneralEventId.TcpLife, diff --git a/VpnHood.Tunneling/PingProxy.cs b/VpnHood.Tunneling/PingProxy.cs index ae6ca5e66..36253a71b 100644 --- a/VpnHood.Tunneling/PingProxy.cs +++ b/VpnHood.Tunneling/PingProxy.cs @@ -23,13 +23,13 @@ public async Task Send(IPPacket ipPacket) try { - await _finishSemaphore.WaitAsync(); + await _finishSemaphore.WaitAsync().VhConfigureAwait(); IsBusy = true; LastUsedTime = FastDateTime.Now; return ipPacket.Version == IPVersion.IPv4 - ? await SendIpV4(ipPacket.Extract()) - : await SendIpV6(ipPacket.Extract()); + ? await SendIpV4(ipPacket.Extract()).VhConfigureAwait() + : await SendIpV6(ipPacket.Extract()).VhConfigureAwait(); } finally { @@ -50,7 +50,7 @@ private async Task SendIpV4(IPv4Packet ipPacket) var noFragment = (ipPacket.FragmentFlags & 0x2) != 0; var pingOptions = new PingOptions(ipPacket.TimeToLive - 1, noFragment); - var pingReply = await _ping.SendPingAsync(ipPacket.DestinationAddress, (int)IcmpTimeout.TotalMilliseconds, icmpPacket.Data, pingOptions); + var pingReply = await _ping.SendPingAsync(ipPacket.DestinationAddress, (int)IcmpTimeout.TotalMilliseconds, icmpPacket.Data, pingOptions).VhConfigureAwait(); if (pingReply.Status != IPStatus.Success) throw new Exception($"Ping Reply has been failed! Status: {pingReply.Status}"); @@ -79,7 +79,7 @@ private async Task SendIpV6(IPv6Packet ipPacket) var pingOptions = new PingOptions(ipPacket.TimeToLive - 1, true); var pingData = icmpPacket.Bytes[8..]; - var pingReply = await _ping.SendPingAsync(ipPacket.DestinationAddress, (int)IcmpTimeout.TotalMilliseconds, pingData, pingOptions); + var pingReply = await _ping.SendPingAsync(ipPacket.DestinationAddress, (int)IcmpTimeout.TotalMilliseconds, pingData, pingOptions).VhConfigureAwait(); // IcmpV6 packet generation is not fully implemented by packetNet // So create all packet in buffer diff --git a/VpnHood.Tunneling/PingProxyPool.cs b/VpnHood.Tunneling/PingProxyPool.cs index baff8c713..af70de689 100644 --- a/VpnHood.Tunneling/PingProxyPool.cs +++ b/VpnHood.Tunneling/PingProxyPool.cs @@ -3,6 +3,7 @@ using VpnHood.Common.Collections; using VpnHood.Common.Jobs; using VpnHood.Common.Logging; +using VpnHood.Common.Utils; namespace VpnHood.Tunneling; @@ -92,8 +93,8 @@ public async Task SendPacket(IPPacket ipPacket) new IPEndPoint(ipPacket.SourceAddress, 0), new IPEndPoint(ipPacket.DestinationAddress, 0), isNewLocalEndPoint, isNewRemoteEndPoint); - var result = await sendTask; - await _packetProxyReceiver.OnPacketReceived(result); + var result = await sendTask.VhConfigureAwait(); + await _packetProxyReceiver.OnPacketReceived(result).VhConfigureAwait(); } public Task RunJob() diff --git a/VpnHood.Tunneling/ProxyManager.cs b/VpnHood.Tunneling/ProxyManager.cs index ce1b71830..f0ae94274 100644 --- a/VpnHood.Tunneling/ProxyManager.cs +++ b/VpnHood.Tunneling/ProxyManager.cs @@ -3,6 +3,7 @@ using Microsoft.Extensions.Logging; using PacketDotNet; using VpnHood.Common.Logging; +using VpnHood.Common.Utils; using VpnHood.Tunneling.Channels; using VpnHood.Tunneling.Factory; using ProtocolType = PacketDotNet.ProtocolType; @@ -42,7 +43,7 @@ protected ProxyManager(ISocketFactory socketFactory, ProxyManagerOptions options public async Task SendPackets(IEnumerable ipPackets) { foreach (var ipPacket in ipPackets) - await SendPacket(ipPacket); + await SendPacket(ipPacket).VhConfigureAwait(); } public async Task SendPacket(IPPacket ipPacket) @@ -63,14 +64,14 @@ public async Task SendPacket(IPPacket ipPacket) switch (ipPacket.Protocol) { case ProtocolType.Udp: - await _udpProxyPool.SendPacket(ipPacket); + await _udpProxyPool.SendPacket(ipPacket).VhConfigureAwait(); break; case ProtocolType.Icmp or ProtocolType.IcmpV6: if (!IsPingSupported) throw new NotSupportedException("Ping is not supported by this proxy."); - await _pingProxyPool.SendPacket(ipPacket); + await _pingProxyPool.SendPacket(ipPacket).VhConfigureAwait(); break; default: @@ -108,6 +109,6 @@ public async ValueTask DisposeAsync() lock (_channels) disposeTasks.AddRange(_channels.Select(channel => channel.DisposeAsync(false).AsTask())); - await Task.WhenAll(disposeTasks); + await Task.WhenAll(disposeTasks).VhConfigureAwait(); } } \ No newline at end of file diff --git a/VpnHood.Tunneling/StreamCryptor.cs b/VpnHood.Tunneling/StreamCryptor.cs index 54c50d08c..c87d5e239 100644 --- a/VpnHood.Tunneling/StreamCryptor.cs +++ b/VpnHood.Tunneling/StreamCryptor.cs @@ -76,7 +76,7 @@ public void Encrypt(byte[] buffer, int offset, int count) public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - var readCount = await _stream.ReadAsync(buffer, offset, count, cancellationToken); + var readCount = await _stream.ReadAsync(buffer, offset, count, cancellationToken).VhConfigureAwait(); Decrypt(buffer, offset, readCount); return readCount; } @@ -100,8 +100,8 @@ public override async ValueTask DisposeAsync() _bufferCryptor.Dispose(); if (!_leaveOpen) - await _stream.DisposeAsync(); + await _stream.DisposeAsync().VhConfigureAwait(); - await base.DisposeAsync(); + await base.DisposeAsync().VhConfigureAwait(); } } \ No newline at end of file diff --git a/VpnHood.Tunneling/StreamPacketReader.cs b/VpnHood.Tunneling/StreamPacketReader.cs index 7c2f417f9..e6fb818c2 100644 --- a/VpnHood.Tunneling/StreamPacketReader.cs +++ b/VpnHood.Tunneling/StreamPacketReader.cs @@ -25,7 +25,7 @@ public class StreamPacketReader(Stream stream) : IAsyncDisposable if (_packetBufferCount < minPacketSize) { var toRead = minPacketSize - _packetBufferCount; - var read = await _stream.ReadAsync(_packetBuffer, _packetBufferCount, toRead, cancellationToken); + var read = await _stream.ReadAsync(_packetBuffer, _packetBufferCount, toRead, cancellationToken).VhConfigureAwait(); _packetBufferCount += read; // is eof? @@ -57,7 +57,7 @@ public class StreamPacketReader(Stream stream) : IAsyncDisposable } var toRead = packetLength - _packetBufferCount; - var read = await _stream.ReadAsync(_packetBuffer, _packetBufferCount, toRead, cancellationToken); + var read = await _stream.ReadAsync(_packetBuffer, _packetBufferCount, toRead, cancellationToken).VhConfigureAwait(); _packetBufferCount += read; if (read == 0) throw new Exception("Stream has been unexpectedly closed before reading the rest of packet."); diff --git a/VpnHood.Tunneling/StreamUtil.cs b/VpnHood.Tunneling/StreamUtil.cs index 91489172e..c683528f2 100644 --- a/VpnHood.Tunneling/StreamUtil.cs +++ b/VpnHood.Tunneling/StreamUtil.cs @@ -1,5 +1,6 @@ using System.Text; using System.Text.Json; +using VpnHood.Common.Utils; namespace VpnHood.Tunneling; @@ -15,7 +16,7 @@ public static class StreamUtil CancellationToken cancellationToken) { var buffer = new byte[count]; - if (!await ReadWaitForFillAsync(stream, buffer, 0, buffer.Length, cancellationToken)) + if (!await ReadWaitForFillAsync(stream, buffer, 0, buffer.Length, cancellationToken).VhConfigureAwait()) return null; return buffer; } @@ -41,7 +42,7 @@ public static async Task ReadWaitForFillAsync(Stream stream, byte[] buffer while (totalRead != count) { var read = await stream.ReadAsync(buffer, startIndex + totalRead, count - totalRead, - cancellationToken); + cancellationToken).VhConfigureAwait(); totalRead += read; if (read == 0) return false; @@ -76,7 +77,7 @@ public static T ReadJson(Stream stream, int maxLength = 0xFFFF) public static async Task ReadJsonAsync(Stream stream, CancellationToken cancellationToken, int maxLength = 0xFFFF) { - var message = await ReadMessage(stream, cancellationToken, maxLength); + var message = await ReadMessage(stream, cancellationToken, maxLength).VhConfigureAwait(); var ret = JsonSerializer.Deserialize(message) ?? throw new Exception("Could not read Message!"); return ret; } @@ -85,7 +86,7 @@ public static async Task ReadMessage(Stream stream, CancellationToken ca int maxLength = 0xFFFF) { // read length - var buffer = await ReadWaitForFillAsync(stream, 4, cancellationToken) + var buffer = await ReadWaitForFillAsync(stream, 4, cancellationToken).VhConfigureAwait() ?? throw new Exception("Could not read message."); // check unauthorized exception @@ -102,7 +103,7 @@ public static async Task ReadMessage(Stream stream, CancellationToken ca $"json length is too big! It should be less than {maxLength} bytes but it was {messageSize} bytes"); // read json body... - buffer = await ReadWaitForFillAsync(stream, messageSize, cancellationToken); + buffer = await ReadWaitForFillAsync(stream, messageSize, cancellationToken).VhConfigureAwait(); if (buffer == null) throw new Exception("Could not read Message Length!"); diff --git a/VpnHood.Tunneling/Tunnel.cs b/VpnHood.Tunneling/Tunnel.cs index 8f564c1b7..39b1f9cde 100644 --- a/VpnHood.Tunneling/Tunnel.cs +++ b/VpnHood.Tunneling/Tunnel.cs @@ -237,15 +237,49 @@ private void Channel_OnPacketReceived(object sender, ChannelPacketReceivedEventA } } - public Task SendPacket(IPPacket ipPacket) + public Task SendPacketAsync(IPPacket ipPacket, CancellationToken cancellationToken) { - return SendPackets(new[] { ipPacket }); + return SendPacketsAsync([ipPacket], cancellationToken); } - public async Task SendPackets(IEnumerable ipPackets) + public async Task SendPacketsAsync(IList ipPackets, CancellationToken cancellationToken) + { + if (_disposed) throw new ObjectDisposedException(nameof(Tunnel)); + await WaitForQueueAsync(cancellationToken); + EnqueuePackets(ipPackets); + } + + public void SendPackets(IList ipPackets, CancellationToken cancellationToke) { - var dateTime = FastDateTime.Now; if (_disposed) throw new ObjectDisposedException(nameof(Tunnel)); + WaitForQueue(cancellationToke); + EnqueuePackets(ipPackets); + } + + private async Task WaitForQueueAsync(CancellationToken cancellationToken) + { + var dateTime = FastDateTime.Now; + + // waiting for a space in the packetQueue; the Inconsistently is not important. synchronization may lead to deadlock + // ReSharper disable once InconsistentlySynchronizedField + while (_packetQueue.Count > MaxQueueLength) + { + var releaseCount = DatagramChannelCount - _packetSenderSemaphore.CurrentCount; + if (releaseCount > 0) + _packetSenderSemaphore.Release(releaseCount); // there is some packet + + await _packetSentEvent.WaitAsync(1000, cancellationToken).VhConfigureAwait(); //Wait 1 seconds to prevent deadlock. + if (_disposed) return; + + // check timeout + if (FastDateTime.Now - dateTime > _datagramPacketTimeout) + throw new TimeoutException("Could not send datagram packets."); + } + } + + private void WaitForQueue(CancellationToken cancellationToken) + { + var dateTime = FastDateTime.Now; // waiting for a space in the packetQueue; the Inconsistently is not important. synchronization may lead to deadlock // ReSharper disable once InconsistentlySynchronizedField @@ -254,14 +288,18 @@ public async Task SendPackets(IEnumerable ipPackets) var releaseCount = DatagramChannelCount - _packetSenderSemaphore.CurrentCount; if (releaseCount > 0) _packetSenderSemaphore.Release(releaseCount); // there is some packet - await _packetSentEvent.WaitAsync(1000); //Wait 1000 to prevent deadlock. + + _packetSentEvent.Wait(1000, cancellationToken); //Wait 1 seconds to prevent deadlock. if (_disposed) return; // check timeout if (FastDateTime.Now - dateTime > _datagramPacketTimeout) throw new TimeoutException("Could not send datagram packets."); } + } + private void EnqueuePackets(IList ipPackets) + { // add all packets to the queue lock (_packetQueue) { @@ -347,12 +385,12 @@ private async Task SendPacketTask(IDatagramChannel channel) try { - await channel.SendPacket(packets.ToArray()); + await channel.SendPacket(packets.ToArray()).VhConfigureAwait(); } catch { if (!_disposed) - _ = SendPackets(packets); //resend packets + _ = SendPacketsAsync(packets, CancellationToken.None); //resend packets if (!channel.Connected && !_disposed) throw; // this channel has error @@ -361,7 +399,7 @@ private async Task SendPacketTask(IDatagramChannel channel) // wait for next new packets else { - await _packetSenderSemaphore.WaitAsync(); + await _packetSenderSemaphore.WaitAsync().VhConfigureAwait(); } } // while } @@ -407,7 +445,7 @@ public Task RunJob() private bool _disposed; public async ValueTask DisposeAsync() { - using var lockResult = await _disposeLock.LockAsync(); + using var lockResult = await _disposeLock.LockAsync().VhConfigureAwait(); if (_disposed) return; _disposed = true; @@ -420,7 +458,7 @@ public async ValueTask DisposeAsync() } // Stop speed monitor - await _speedMonitorTimer.DisposeAsync(); + await _speedMonitorTimer.DisposeAsync().VhConfigureAwait(); Speed.Sent = 0; Speed.Received = 0; @@ -429,6 +467,6 @@ public async ValueTask DisposeAsync() _packetSentEvent.Release(); // dispose all channels - await Task.WhenAll(disposeTasks); + await Task.WhenAll(disposeTasks).VhConfigureAwait(); } } \ No newline at end of file diff --git a/VpnHood.Tunneling/UdpProxy.cs b/VpnHood.Tunneling/UdpProxy.cs index 91e2e33cf..607285bd5 100644 --- a/VpnHood.Tunneling/UdpProxy.cs +++ b/VpnHood.Tunneling/UdpProxy.cs @@ -46,7 +46,7 @@ public async Task SendPacket(IPEndPoint ipEndPoint, byte[] datagram, bool? noFra try { - await _sendSemaphore.WaitAsync(); + await _sendSemaphore.WaitAsync().VhConfigureAwait(); if (VhLogger.IsDiagnoseMode) VhLogger.Instance.Log(LogLevel.Information, GeneralEventId.Udp, @@ -56,7 +56,7 @@ public async Task SendPacket(IPEndPoint ipEndPoint, byte[] datagram, bool? noFra if (noFragment != null && ipEndPoint.AddressFamily == AddressFamily.InterNetwork) _udpClient.DontFragment = noFragment.Value; // Never call this for IPv6, it will throw exception for any value - var sentBytes = await _udpClient.SendAsync(datagram, datagram.Length, ipEndPoint); + var sentBytes = await _udpClient.SendAsync(datagram, datagram.Length, ipEndPoint).VhConfigureAwait(); if (sentBytes != datagram.Length) VhLogger.Instance.LogWarning( $"Couldn't send all udp bytes. Requested: {datagram.Length}, Sent: {sentBytes}"); @@ -80,14 +80,14 @@ public async Task Listen() { while (!Disposed) { - var udpResult = await _udpClient.ReceiveAsync(); + var udpResult = await _udpClient.ReceiveAsync().VhConfigureAwait(); LastUsedTime = FastDateTime.Now; // create packet for audience var ipPacket = PacketUtil.CreateUdpPacket(udpResult.RemoteEndPoint, SourceEndPoint, udpResult.Buffer); // send packet to audience - await _packetReceiver.OnPacketReceived(ipPacket); + await _packetReceiver.OnPacketReceived(ipPacket).VhConfigureAwait(); } } diff --git a/VpnHood.Tunneling/UdpProxyEx.cs b/VpnHood.Tunneling/UdpProxyEx.cs index 28aa693f6..55fce935e 100644 --- a/VpnHood.Tunneling/UdpProxyEx.cs +++ b/VpnHood.Tunneling/UdpProxyEx.cs @@ -51,7 +51,7 @@ public async Task SendPacket(IPEndPoint ipEndPoint, byte[] datagram, bool? noFra try { - await _sendSemaphore.WaitAsync(); + await _sendSemaphore.WaitAsync().VhConfigureAwait(); if (VhLogger.IsDiagnoseMode) VhLogger.Instance.Log(LogLevel.Information, GeneralEventId.Udp, @@ -61,7 +61,7 @@ public async Task SendPacket(IPEndPoint ipEndPoint, byte[] datagram, bool? noFra if (noFragment != null && ipEndPoint.AddressFamily == AddressFamily.InterNetwork) _udpClient.DontFragment = noFragment.Value; // Never call this for IPv6, it will throw exception for any value - var sentBytes = await _udpClient.SendAsync(datagram, datagram.Length, ipEndPoint); + var sentBytes = await _udpClient.SendAsync(datagram, datagram.Length, ipEndPoint).VhConfigureAwait(); if (sentBytes != datagram.Length) VhLogger.Instance.LogWarning( $"Couldn't send all udp bytes. Requested: {datagram.Length}, Sent: {sentBytes}"); @@ -85,7 +85,7 @@ public async Task Listen() { while (!Disposed) { - var udpResult = await _udpClient.ReceiveAsync(); + var udpResult = await _udpClient.ReceiveAsync().VhConfigureAwait(); LastUsedTime = FastDateTime.Now; // find the audience @@ -106,7 +106,7 @@ public async Task Listen() PacketUtil.UpdateIpPacket(ipPacket); // send packet to audience - await _packetReceiver.OnPacketReceived(ipPacket); + await _packetReceiver.OnPacketReceived(ipPacket).VhConfigureAwait(); } } diff --git a/VpnHood.Tunneling/Utils/HttpUtil.cs b/VpnHood.Tunneling/Utils/HttpUtil.cs index 17a33f84a..60767b71c 100644 --- a/VpnHood.Tunneling/Utils/HttpUtil.cs +++ b/VpnHood.Tunneling/Utils/HttpUtil.cs @@ -1,5 +1,6 @@ using System.Security.Cryptography; using System.Text; +using VpnHood.Common.Utils; namespace VpnHood.Tunneling.Utils; @@ -18,7 +19,7 @@ public static async Task ReadHeadersAsync(Stream stream, var lfCounter = 0; while (lfCounter < 4) { - var bytesRead = await stream.ReadAsync(readBuffer, 0, 1, cancellationToken); + var bytesRead = await stream.ReadAsync(readBuffer, 0, 1, cancellationToken).VhConfigureAwait(); if (bytesRead == 0) return memStream.Length == 0 ? memStream // connection has been closed gracefully before sending anything @@ -29,7 +30,7 @@ public static async Task ReadHeadersAsync(Stream stream, else lfCounter = 0; - await memStream.WriteAsync(readBuffer, 0, 1, cancellationToken); + await memStream.WriteAsync(readBuffer, 0, 1, cancellationToken).VhConfigureAwait(); if (memStream.Length > maxLength) throw new Exception("HTTP header is too big."); @@ -40,7 +41,7 @@ public static async Task ReadHeadersAsync(Stream stream, } catch { - await memStream.DisposeAsync(); + await memStream.DisposeAsync().VhConfigureAwait(); throw; } } diff --git a/VpnHood.Tunneling/VpnHood.Tunneling.csproj b/VpnHood.Tunneling/VpnHood.Tunneling.csproj index 37e66f91a..21ae14e71 100644 --- a/VpnHood.Tunneling/VpnHood.Tunneling.csproj +++ b/VpnHood.Tunneling/VpnHood.Tunneling.csproj @@ -19,7 +19,7 @@ VpnHood.png https://github.com/vpnhood/vpnhood https://github.com/vpnhood/vpnhood - 4.5.520 + 4.5.522 $([System.DateTime]::Now.ToString("yyyy.M.d.HHmm"))