diff --git a/graylog2-server/src/main/java/org/graylog2/lookup/adapters/CSVFileDataAdapter.java b/graylog2-server/src/main/java/org/graylog2/lookup/adapters/CSVFileDataAdapter.java index fbdf096d9378..5e2a9d53d6fc 100644 --- a/graylog2-server/src/main/java/org/graylog2/lookup/adapters/CSVFileDataAdapter.java +++ b/graylog2-server/src/main/java/org/graylog2/lookup/adapters/CSVFileDataAdapter.java @@ -43,7 +43,7 @@ import org.graylog2.plugin.lookup.LookupDataAdapterConfiguration; import org.graylog2.plugin.lookup.LookupResult; import org.graylog2.plugin.utilities.FileInfo; -import org.graylog2.utilities.CIDRLookupTrie; +import org.graylog2.utilities.CIDRPatriciaTrie; import org.graylog2.utilities.IpSubnet; import org.graylog2.utilities.ReservedIpChecker; import org.joda.time.Duration; @@ -86,7 +86,7 @@ public class CSVFileDataAdapter extends LookupDataAdapter { private final Config config; private final AllowedAuxiliaryPathChecker pathChecker; private final AtomicReference> lookupRef = new AtomicReference<>(ImmutableMap.of()); - private final AtomicReference cidrLookupRef = new AtomicReference<>(new CIDRLookupTrie()); + private final AtomicReference cidrLookupRef = new AtomicReference<>(new CIDRPatriciaTrie()); private final String name; private FileInfo fileInfo = FileInfo.empty(); @@ -168,7 +168,7 @@ private void setLookupRefFromCSV() throws IOException { final InputStream inputStream = Files.newInputStream(Paths.get(config.path())); final InputStreamReader fileReader = new InputStreamReader(inputStream, StandardCharsets.UTF_8); final ImmutableMap.Builder newLookupBuilder = ImmutableMap.builder(); - final CIDRLookupTrie cidrLookupTrie = new CIDRLookupTrie(); + final CIDRPatriciaTrie cidrLookupTrie = new CIDRPatriciaTrie(); try (final CSVReader csvReader = new CSVReader(fileReader, config.separatorAsChar(), config.quotecharAsChar())) { int line = 0; diff --git a/graylog2-server/src/main/java/org/graylog2/utilities/CIDRPatriciaTrie.java b/graylog2-server/src/main/java/org/graylog2/utilities/CIDRPatriciaTrie.java new file mode 100644 index 000000000000..1ac51bcb8c37 --- /dev/null +++ b/graylog2-server/src/main/java/org/graylog2/utilities/CIDRPatriciaTrie.java @@ -0,0 +1,235 @@ +/* + * Copyright (C) 2020 Graylog, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + */ +package org.graylog2.utilities; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.commons.collections4.trie.PatriciaTrie; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; + +import java.util.Locale; +import java.util.Map; + +/** + * PatriciaTrie used to for efficient lookups in CIDR data adapters. + * NOTE: This class is NOT thread-safe. Use {@link #cleanCopy()} to clone the trie, make modifications, and then atomically + * replace the in-use copy if needed. + */ +public class CIDRPatriciaTrie { + private record Node( + // The lookup value of the range + String rangeName, + // Whether the range is an IPv4 or IPv6 CIDR range + boolean rangeIsIPv6, + // Time in millis after which the node is considered expired + long expireAfter) { + } + + private PatriciaTrie trie = new PatriciaTrie<>(); + private int shortestV4Prefix = -1; + private int shortestV6Prefix = -1; + + @VisibleForTesting + boolean isEmpty() { + return trie.isEmpty(); + } + + /** + * Returns a deep copy of this CIDRPatriciaTrie with any expired nodes removed. + * + * @return deep copy of this trie + */ + public CIDRPatriciaTrie cleanCopy() { + final long now = DateTime.now(DateTimeZone.UTC).getMillis(); + final PatriciaTrie cleanTrie = new PatriciaTrie<>(); + int shortestV6 = -1; + int shortestV4 = -1; + for (Map.Entry entry : trie.entrySet()) { + final Node data = entry.getValue(); + if (data.expireAfter == 0L || data.expireAfter > now) { + final int prefixLength = entry.getKey().length(); + if (data.rangeIsIPv6 && (shortestV6 == -1 || prefixLength < shortestV6)) { + shortestV6 = prefixLength; + } else if (!data.rangeIsIPv6 && (shortestV4 == -1 || prefixLength < shortestV4)) { + shortestV4 = prefixLength; + } + cleanTrie.put(entry.getKey(), new Node(data.rangeName, data.rangeIsIPv6, data.expireAfter)); + } + } + final CIDRPatriciaTrie copy = new CIDRPatriciaTrie(); + copy.trie = cleanTrie; + copy.shortestV4Prefix = shortestV4; + copy.shortestV6Prefix = shortestV6; + return copy; + } + + public void insertCIDR(String cidr, String rangeName) { + insertCIDR(cidr, rangeName, 0L); + } + + /** + * Insert a CIDR range into the trie with a time-to-live + * + * @param cidr properly formatted CIDR address (must include '/rangePrefix' even if it is a single address + * @param rangeName the name of the CIDR range + * @param expireAfter epoch time in millis after which the CIDR should be expired + */ + public void insertCIDR(String cidr, String rangeName, long expireAfter) { + final String[] parts = cidr.split("/"); + final String ip = parts[0]; + final int prefixLength; + try { + prefixLength = Integer.parseInt(parts[1]); + } catch (ArrayIndexOutOfBoundsException e) { + throw new IllegalArgumentException("Unable to parse invalid CIDR range: " + cidr); + } + + // Get binary representation of the IP + final String binaryIP = toBinaryString(ip, prefixLength); + final boolean isIPV6 = ip.contains(":"); + final Node node = new Node(rangeName, isIPV6, expireAfter); + trie.put(binaryIP, node); + if (isIPV6 && (shortestV6Prefix == -1 || prefixLength < shortestV6Prefix)) { + this.shortestV6Prefix = prefixLength; + } else if (!isIPV6 && (shortestV4Prefix == -1 || prefixLength < shortestV4Prefix)) { + this.shortestV4Prefix = prefixLength; + } + } + + public String longestPrefixRangeLookup(String ip) { + return longestPrefixRangeLookupWithTtl(ip, 0L); + } + + /** + * Returns the rangeName of the range with the longest prefix that contains the IP address or null if one does not + * exist. + * + * @param ip IP address to check against the collection of ranges + * @param lookupTimeMillis time lookup was performed in epoch time milliseconds or 0 if node expiry is not a concern + * @return the name of the range with the longest prefix that contains the IP if it exists, null otherwise + */ + public String longestPrefixRangeLookupWithTtl(String ip, long lookupTimeMillis) { + if (isEmpty()) { + return null; + } + final String binaryIP = toBinaryString(ip, -1); + final boolean lookupIsIPv6 = ip.contains(":"); + + final int shortestPrefixForType = lookupIsIPv6 ? shortestV6Prefix : shortestV4Prefix; + for (int i = binaryIP.length(); i >= shortestPrefixForType; i--) { + final String lookupPrefix = binaryIP.substring(0, i); + final Map prefixTrie = trie.prefixMap(lookupPrefix); + for (Map.Entry entry : prefixTrie.entrySet()) { + final Node rangeData = entry.getValue(); + final String binaryCidr = entry.getKey(); + if (lookupIsIPv6 == rangeData.rangeIsIPv6 && + (rangeData.expireAfter == 0L || rangeData.expireAfter > lookupTimeMillis) && + binaryIP.startsWith(binaryCidr)) { + return rangeData.rangeName; + } + } + } + + return null; + } + + /** + * Remove a CIDR range from the trie and cleanup any empty nodes after removal. + * + * @param cidr range to remove + */ + public void removeCIDR(String cidr) { + final String[] parts = cidr.split("/"); + final String ip = parts[0]; + final int prefixLength = Integer.parseInt(parts[1]); + final String binaryIP = toBinaryString(ip, prefixLength); + final Node removedNode = trie.remove(binaryIP); + if (removedNode != null) { + final boolean isIPV6 = ip.contains(":"); + if ((isIPV6 && prefixLength == shortestV6Prefix) || (!isIPV6 && prefixLength == shortestV4Prefix)) { + recalculateShortestPrefix(isIPV6); + } + } + } + + public void recalculateShortestPrefix(boolean isIPV6) { + if (trie.isEmpty()) { + this.shortestV4Prefix = -1; + this.shortestV6Prefix = -1; + } else { + int shortest = -1; + long now = DateTime.now(DateTimeZone.UTC).getMillis(); + for (Map.Entry entry : trie.entrySet()) { + final Node rangeData = entry.getValue(); + if (rangeData.rangeIsIPv6 == isIPV6 // same type of address + && (shortest == -1 || entry.getKey().length() < shortest) // shorter prefix than we've seen + && (rangeData.expireAfter == 0L || rangeData.expireAfter > now)) { // not expired + shortest = entry.getKey().length(); + } + } + if (isIPV6) { + this.shortestV6Prefix = shortest; + } else { + this.shortestV4Prefix = shortest; + } + } + } + + // Convert an IP address to a binary string (supports both IPv4 and IPv6) + static String toBinaryString(String ip, int prefixLength) { + final boolean isIPv6 = ip.contains(":"); + try { + StringBuilder binary = new StringBuilder(); + if (!isIPv6) { + // IPv4 + String[] octets = ip.split("\\."); + for (String octet : octets) { + String binaryOctet = String.format(Locale.ROOT, "%8s", Integer.toBinaryString(Integer.parseInt(octet))).replace(' ', '0'); + binary.append(binaryOctet); + } + return prefixLength > 0 ? binary.substring(0, prefixLength) : binary.toString(); + } else { + // IPv6 + String[] hextets = ip.split(":"); + for (String hextet : hextets) { + if (!hextet.isEmpty()) { + String binaryHextet = String.format(Locale.ROOT, "%16s", Integer.toBinaryString(Integer.parseInt(hextet, 16))).replace(' ', '0'); + binary.append(binaryHextet); + } else { + // Handle "::" shorthand for consecutive zero groups + int missingGroups = 8 - hextets.length + 1; + binary.append("0000000000000000".repeat(Math.max(0, missingGroups))); + } + } + // If the prefix length is larger than the resulting binary string, append 0 until the length matches. This + // will avoid index out of range exceptions when inserting the range into the trie. + if (binary.length() < prefixLength) { + binary.append("0".repeat(prefixLength - binary.length())); + } + // When getting binary string of individual IPv6 addresses, ensure binary string is complete 128 digits. + if (prefixLength == -1) { + binary.append("0".repeat(Math.max(0, 128 - binary.length()))); + } else if (binary.length() > prefixLength) { + return binary.substring(0, prefixLength); + } + return binary.toString(); + } + } catch (Exception e) { + throw new IllegalArgumentException("Invalid IP address format: " + ip); + } + } +} diff --git a/graylog2-server/src/test/java/org/graylog2/utilities/CIDRPatriciaTrieTest.java b/graylog2-server/src/test/java/org/graylog2/utilities/CIDRPatriciaTrieTest.java new file mode 100644 index 000000000000..eadef91f896a --- /dev/null +++ b/graylog2-server/src/test/java/org/graylog2/utilities/CIDRPatriciaTrieTest.java @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2020 Graylog, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + */ +package org.graylog2.utilities; + +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +public class CIDRPatriciaTrieTest { + + @Test + public void testLookups() { + final CIDRPatriciaTrie trie = buildTrie(); + + assertThat(trie).satisfies(t -> { + assertThat(t.longestPrefixRangeLookup("192.168.1.100")).isEqualTo("IPv4 Range 1"); + assertThat(t.longestPrefixRangeLookup("10.0.5.1")).isEqualTo("IPv4 Range 2"); + assertThat(t.longestPrefixRangeLookup("35.139.253.123")).isEqualTo("IPv4 Range 3"); + assertThat(t.longestPrefixRangeLookup("192.168.102.8")).isEqualTo("HR Subnet 1"); + assertThat(t.longestPrefixRangeLookup("192.168.102.22")).isEqualTo("HR Subnet 2"); + assertThat(t.longestPrefixRangeLookup("192.168.102.40")).isEqualTo("HR Subnet 3"); + assertThat(t.longestPrefixRangeLookup("172.16.5.4")).isNull(); + assertThat(t.longestPrefixRangeLookup("2001:db8:abcd::1")).isEqualTo("IPv6 Range 1"); + assertThat(t.longestPrefixRangeLookup("2404:6800:4001:abcd::1")).isEqualTo("IPv6 Range 2"); + assertThat(t.longestPrefixRangeLookup("8dbf:88a6:2000:4ddc:f708:cf8d:f2a5:a420")).isEqualTo("IPv6 Range 3"); + assertThat(t.longestPrefixRangeLookup("77f:8b7a:3e82:6fb3:ba15:9b68:7fe0:a695")).isEqualTo("IPv6 Range 4"); + assertThat(t.longestPrefixRangeLookup("2001:db7::")).isEqualTo("Single IPv6"); + assertThat(t.longestPrefixRangeLookup("2607:f8b0:4001:c01::")).isNull(); + }); + } + + @Test + public void testRemoval() { + final CIDRPatriciaTrie trie = buildTrie(); + + assertThat(trie.longestPrefixRangeLookup("8dbf:88a6:2000:4ddc:f708:cf8d:f2a5:a420")).isEqualTo("IPv6 Range 3"); + trie.removeCIDR("8dbf:8000::/19"); + // CIDR is no longer in lookup + assertThat(trie.longestPrefixRangeLookup("8dbf:88a6:2000:4ddc:f708:cf8d:f2a5:a420")).isNull(); + // Confirm other lookups still work + assertThat(trie.longestPrefixRangeLookup("77f:8b7a:3e82:6fb3:ba15:9b68:7fe0:a695")).isEqualTo("IPv6 Range 4"); + assertThat(trie.longestPrefixRangeLookup("192.168.1.100")).isEqualTo("IPv4 Range 1"); + assertThat(trie.longestPrefixRangeLookup("10.0.5.1")).isEqualTo("IPv4 Range 2"); + assertThat(trie.longestPrefixRangeLookup("35.139.253.123")).isEqualTo("IPv4 Range 3"); + assertThat(trie.longestPrefixRangeLookup("2001:db8:abcd::1")).isEqualTo("IPv6 Range 1"); + assertThat(trie.longestPrefixRangeLookup("2404:6800:4001:abcd::1")).isEqualTo("IPv6 Range 2"); + assertThat(trie.longestPrefixRangeLookup("77f:8b7a:3e82:6fb3:ba15:9b68:7fe0:a695")).isEqualTo("IPv6 Range 4"); + } + + @Test + public void testBadEntry() { + final CIDRPatriciaTrie trie = new CIDRPatriciaTrie(); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> trie.insertCIDR("127.a.3.21/12", "Bad Range 1")); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> trie.insertCIDR("not.an.ip.address/12", "Bad Range 2")); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> trie.insertCIDR("127.0.0.0", "Bad Range 3")); + } + + @Test + public void testCleanCopy() throws InterruptedException { + final CIDRPatriciaTrie trie = new CIDRPatriciaTrie(); + final long expireAt = DateTime.now(DateTimeZone.UTC).getMillis() + 500L; + trie.insertCIDR("192.168.1.0/24", "IPv4 Range 1", expireAt); + trie.insertCIDR("10.0.0.0/8", "IPv4 Range 2", expireAt); + trie.insertCIDR("35.138.0.0/15", "IPv4 Range 3", expireAt); + Thread.sleep(1000L); + final CIDRPatriciaTrie copy = trie.cleanCopy(); + assertThat(copy.isEmpty()).isTrue(); + } + + @Test + public void testToBinaryIP() { + String cidrBinary = CIDRPatriciaTrie.toBinaryString("2002:0000:0000:1234:0000:0000:0000:0000", 64); + assertThat(cidrBinary).isEqualTo("0010000000000010000000000000000000000000000000000001001000110100"); + cidrBinary = CIDRPatriciaTrie.toBinaryString("192.168.103.16", -1); + assertThat(cidrBinary).isEqualTo("11000000101010000110011100010000"); + cidrBinary = CIDRPatriciaTrie.toBinaryString("2001:db7::", -1); + assertThat(cidrBinary).isEqualTo("00100000000000010000110110110111000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"); + cidrBinary = CIDRPatriciaTrie.toBinaryString("192.168.102.0", 24); + assertThat(cidrBinary).isEqualTo("110000001010100001100110"); + + } + + private static CIDRPatriciaTrie buildTrie() { + final CIDRPatriciaTrie trie = new CIDRPatriciaTrie(); + trie.insertCIDR("192.168.1.0/24", "IPv4 Range 1"); + trie.insertCIDR("10.0.0.0/8", "IPv4 Range 2"); + trie.insertCIDR("35.138.0.0/15", "IPv4 Range 3"); + trie.insertCIDR("192.168.102.0/24", "HR"); + trie.insertCIDR("192.168.102.0/28", "HR Subnet 1"); + trie.insertCIDR("192.168.102.16/28", "HR Subnet 2"); + trie.insertCIDR("192.168.102.32/28", "HR Subnet 3"); + trie.insertCIDR("2001:db8::/32", "IPv6 Range 1"); + trie.insertCIDR("2404:6800:4001::/48", "IPv6 Range 2"); + trie.insertCIDR("8dbf:8000::/19", "IPv6 Range 3"); + trie.insertCIDR("77f::/16", "IPv6 Range 4"); + trie.insertCIDR("17c5:b180::/35", "IPv6 Range 5"); + trie.insertCIDR("2001:db7::/128","Single IPv6"); + return trie; + } +}