Skip to content

Commit

Permalink
Refactor to use PatriciaTrie instead of simple binary tree
Browse files Browse the repository at this point in the history
  • Loading branch information
kingzacko1 committed Feb 6, 2025
1 parent 3670bc2 commit f9433fd
Show file tree
Hide file tree
Showing 3 changed files with 355 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import org.graylog2.plugin.lookup.LookupDataAdapterConfiguration;
import org.graylog2.plugin.lookup.LookupResult;
import org.graylog2.plugin.utilities.FileInfo;
import org.graylog2.utilities.CIDRLookupTrie;
import org.graylog2.utilities.CIDRPatriciaTrie;
import org.graylog2.utilities.IpSubnet;
import org.graylog2.utilities.ReservedIpChecker;
import org.joda.time.Duration;
Expand Down Expand Up @@ -86,7 +86,7 @@ public class CSVFileDataAdapter extends LookupDataAdapter {
private final Config config;
private final AllowedAuxiliaryPathChecker pathChecker;
private final AtomicReference<Map<String, String>> lookupRef = new AtomicReference<>(ImmutableMap.of());
private final AtomicReference<CIDRLookupTrie> cidrLookupRef = new AtomicReference<>(new CIDRLookupTrie());
private final AtomicReference<CIDRPatriciaTrie> cidrLookupRef = new AtomicReference<>(new CIDRPatriciaTrie());
private final String name;

private FileInfo fileInfo = FileInfo.empty();
Expand Down Expand Up @@ -168,7 +168,7 @@ private void setLookupRefFromCSV() throws IOException {
final InputStream inputStream = Files.newInputStream(Paths.get(config.path()));
final InputStreamReader fileReader = new InputStreamReader(inputStream, StandardCharsets.UTF_8);
final ImmutableMap.Builder<String, String> newLookupBuilder = ImmutableMap.builder();
final CIDRLookupTrie cidrLookupTrie = new CIDRLookupTrie();
final CIDRPatriciaTrie cidrLookupTrie = new CIDRPatriciaTrie();

try (final CSVReader csvReader = new CSVReader(fileReader, config.separatorAsChar(), config.quotecharAsChar())) {
int line = 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
/*
* Copyright (C) 2020 Graylog, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the Server Side Public License, version 1,
* as published by MongoDB, Inc.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Server Side Public License for more details.
*
* You should have received a copy of the Server Side Public License
* along with this program. If not, see
* <http://www.mongodb.com/licensing/server-side-public-license>.
*/
package org.graylog2.utilities;

import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.collections4.trie.PatriciaTrie;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;

import java.util.Locale;
import java.util.Map;

/**
* PatriciaTrie used to for efficient lookups in CIDR data adapters.
* NOTE: This class is NOT thread-safe. Use {@link #cleanCopy()} to clone the trie, make modifications, and then atomically
* replace the in-use copy if needed.
*/
public class CIDRPatriciaTrie {
private record Node(
// The lookup value of the range
String rangeName,
// Whether the range is an IPv4 or IPv6 CIDR range
boolean rangeIsIPv6,
// Time in millis after which the node is considered expired
long expireAfter) {
}

private PatriciaTrie<Node> trie = new PatriciaTrie<>();
private int shortestV4Prefix = -1;
private int shortestV6Prefix = -1;

@VisibleForTesting
boolean isEmpty() {
return trie.isEmpty();
}

/**
* Returns a deep copy of this CIDRPatriciaTrie with any expired nodes removed.
*
* @return deep copy of this trie
*/
public CIDRPatriciaTrie cleanCopy() {
final long now = DateTime.now(DateTimeZone.UTC).getMillis();
final PatriciaTrie<Node> cleanTrie = new PatriciaTrie<>();
int shortestV6 = -1;
int shortestV4 = -1;
for (Map.Entry<String, Node> entry : trie.entrySet()) {
final Node data = entry.getValue();
if (data.expireAfter == 0L || data.expireAfter > now) {
final int prefixLength = entry.getKey().length();
if (data.rangeIsIPv6 && (shortestV6 == -1 || prefixLength < shortestV6)) {
shortestV6 = prefixLength;
} else if (!data.rangeIsIPv6 && (shortestV4 == -1 || prefixLength < shortestV4)) {
shortestV4 = prefixLength;
}
cleanTrie.put(entry.getKey(), new Node(data.rangeName, data.rangeIsIPv6, data.expireAfter));
}
}
final CIDRPatriciaTrie copy = new CIDRPatriciaTrie();
copy.trie = cleanTrie;
copy.shortestV4Prefix = shortestV4;
copy.shortestV6Prefix = shortestV6;
return copy;
}

public void insertCIDR(String cidr, String rangeName) {
insertCIDR(cidr, rangeName, 0L);
}

/**
* Insert a CIDR range into the trie with a time-to-live
*
* @param cidr properly formatted CIDR address (must include '/rangePrefix' even if it is a single address
* @param rangeName the name of the CIDR range
* @param expireAfter epoch time in millis after which the CIDR should be expired
*/
public void insertCIDR(String cidr, String rangeName, long expireAfter) {
final String[] parts = cidr.split("/");
final String ip = parts[0];
final int prefixLength;
try {
prefixLength = Integer.parseInt(parts[1]);
} catch (ArrayIndexOutOfBoundsException e) {
throw new IllegalArgumentException("Unable to parse invalid CIDR range: " + cidr);
}

// Get binary representation of the IP
final String binaryIP = toBinaryString(ip, prefixLength);
final boolean isIPV6 = ip.contains(":");
final Node node = new Node(rangeName, isIPV6, expireAfter);
trie.put(binaryIP, node);
if (isIPV6 && (shortestV6Prefix == -1 || prefixLength < shortestV6Prefix)) {
this.shortestV6Prefix = prefixLength;
} else if (!isIPV6 && (shortestV4Prefix == -1 || prefixLength < shortestV4Prefix)) {
this.shortestV4Prefix = prefixLength;
}
}

public String longestPrefixRangeLookup(String ip) {
return longestPrefixRangeLookupWithTtl(ip, 0L);
}

/**
* Returns the rangeName of the range with the longest prefix that contains the IP address or null if one does not
* exist.
*
* @param ip IP address to check against the collection of ranges
* @param lookupTimeMillis time lookup was performed in epoch time milliseconds or 0 if node expiry is not a concern
* @return the name of the range with the longest prefix that contains the IP if it exists, null otherwise
*/
public String longestPrefixRangeLookupWithTtl(String ip, long lookupTimeMillis) {
if (isEmpty()) {
return null;
}
final String binaryIP = toBinaryString(ip, -1);
final boolean lookupIsIPv6 = ip.contains(":");

final int shortestPrefixForType = lookupIsIPv6 ? shortestV6Prefix : shortestV4Prefix;
for (int i = binaryIP.length(); i >= shortestPrefixForType; i--) {
final String lookupPrefix = binaryIP.substring(0, i);
final Map<String, Node> prefixTrie = trie.prefixMap(lookupPrefix);
for (Map.Entry<String, Node> entry : prefixTrie.entrySet()) {
final Node rangeData = entry.getValue();
final String binaryCidr = entry.getKey();
if (lookupIsIPv6 == rangeData.rangeIsIPv6 &&
(rangeData.expireAfter == 0L || rangeData.expireAfter > lookupTimeMillis) &&
binaryIP.startsWith(binaryCidr)) {
return rangeData.rangeName;
}
}
}

return null;
}

/**
* Remove a CIDR range from the trie and cleanup any empty nodes after removal.
*
* @param cidr range to remove
*/
public void removeCIDR(String cidr) {
final String[] parts = cidr.split("/");
final String ip = parts[0];
final int prefixLength = Integer.parseInt(parts[1]);
final String binaryIP = toBinaryString(ip, prefixLength);
final Node removedNode = trie.remove(binaryIP);
if (removedNode != null) {
final boolean isIPV6 = ip.contains(":");
if ((isIPV6 && prefixLength == shortestV6Prefix) || (!isIPV6 && prefixLength == shortestV4Prefix)) {
recalculateShortestPrefix(isIPV6);
}
}
}

public void recalculateShortestPrefix(boolean isIPV6) {
if (trie.isEmpty()) {
this.shortestV4Prefix = -1;
this.shortestV6Prefix = -1;
} else {
int shortest = -1;
long now = DateTime.now(DateTimeZone.UTC).getMillis();
for (Map.Entry<String, Node> entry : trie.entrySet()) {
final Node rangeData = entry.getValue();
if (rangeData.rangeIsIPv6 == isIPV6 // same type of address
&& (shortest == -1 || entry.getKey().length() < shortest) // shorter prefix than we've seen
&& (rangeData.expireAfter == 0L || rangeData.expireAfter > now)) { // not expired
shortest = entry.getKey().length();
}
}
if (isIPV6) {
this.shortestV6Prefix = shortest;
} else {
this.shortestV4Prefix = shortest;
}
}
}

// Convert an IP address to a binary string (supports both IPv4 and IPv6)
static String toBinaryString(String ip, int prefixLength) {
final boolean isIPv6 = ip.contains(":");
try {
StringBuilder binary = new StringBuilder();
if (!isIPv6) {
// IPv4
String[] octets = ip.split("\\.");
for (String octet : octets) {
String binaryOctet = String.format(Locale.ROOT, "%8s", Integer.toBinaryString(Integer.parseInt(octet))).replace(' ', '0');
binary.append(binaryOctet);
}
return prefixLength > 0 ? binary.substring(0, prefixLength) : binary.toString();
} else {
// IPv6
String[] hextets = ip.split(":");
for (String hextet : hextets) {
if (!hextet.isEmpty()) {
String binaryHextet = String.format(Locale.ROOT, "%16s", Integer.toBinaryString(Integer.parseInt(hextet, 16))).replace(' ', '0');
binary.append(binaryHextet);
} else {
// Handle "::" shorthand for consecutive zero groups
int missingGroups = 8 - hextets.length + 1;
binary.append("0000000000000000".repeat(Math.max(0, missingGroups)));
}
}
// If the prefix length is larger than the resulting binary string, append 0 until the length matches. This
// will avoid index out of range exceptions when inserting the range into the trie.
if (binary.length() < prefixLength) {
binary.append("0".repeat(prefixLength - binary.length()));
}
// When getting binary string of individual IPv6 addresses, ensure binary string is complete 128 digits.
if (prefixLength == -1) {
binary.append("0".repeat(Math.max(0, 128 - binary.length())));
} else if (binary.length() > prefixLength) {
return binary.substring(0, prefixLength);
}
return binary.toString();
}
} catch (Exception e) {
throw new IllegalArgumentException("Invalid IP address format: " + ip);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright (C) 2020 Graylog, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the Server Side Public License, version 1,
* as published by MongoDB, Inc.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Server Side Public License for more details.
*
* You should have received a copy of the Server Side Public License
* along with this program. If not, see
* <http://www.mongodb.com/licensing/server-side-public-license>.
*/
package org.graylog2.utilities;

import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.junit.Test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class CIDRPatriciaTrieTest {

@Test
public void testLookups() {
final CIDRPatriciaTrie trie = buildTrie();

assertThat(trie).satisfies(t -> {
assertThat(t.longestPrefixRangeLookup("192.168.1.100")).isEqualTo("IPv4 Range 1");
assertThat(t.longestPrefixRangeLookup("10.0.5.1")).isEqualTo("IPv4 Range 2");
assertThat(t.longestPrefixRangeLookup("35.139.253.123")).isEqualTo("IPv4 Range 3");
assertThat(t.longestPrefixRangeLookup("192.168.102.8")).isEqualTo("HR Subnet 1");
assertThat(t.longestPrefixRangeLookup("192.168.102.22")).isEqualTo("HR Subnet 2");
assertThat(t.longestPrefixRangeLookup("192.168.102.40")).isEqualTo("HR Subnet 3");
assertThat(t.longestPrefixRangeLookup("172.16.5.4")).isNull();
assertThat(t.longestPrefixRangeLookup("2001:db8:abcd::1")).isEqualTo("IPv6 Range 1");
assertThat(t.longestPrefixRangeLookup("2404:6800:4001:abcd::1")).isEqualTo("IPv6 Range 2");
assertThat(t.longestPrefixRangeLookup("8dbf:88a6:2000:4ddc:f708:cf8d:f2a5:a420")).isEqualTo("IPv6 Range 3");
assertThat(t.longestPrefixRangeLookup("77f:8b7a:3e82:6fb3:ba15:9b68:7fe0:a695")).isEqualTo("IPv6 Range 4");
assertThat(t.longestPrefixRangeLookup("2001:db7::")).isEqualTo("Single IPv6");
assertThat(t.longestPrefixRangeLookup("2607:f8b0:4001:c01::")).isNull();
});
}

@Test
public void testRemoval() {
final CIDRPatriciaTrie trie = buildTrie();

assertThat(trie.longestPrefixRangeLookup("8dbf:88a6:2000:4ddc:f708:cf8d:f2a5:a420")).isEqualTo("IPv6 Range 3");
trie.removeCIDR("8dbf:8000::/19");
// CIDR is no longer in lookup
assertThat(trie.longestPrefixRangeLookup("8dbf:88a6:2000:4ddc:f708:cf8d:f2a5:a420")).isNull();
// Confirm other lookups still work
assertThat(trie.longestPrefixRangeLookup("77f:8b7a:3e82:6fb3:ba15:9b68:7fe0:a695")).isEqualTo("IPv6 Range 4");
assertThat(trie.longestPrefixRangeLookup("192.168.1.100")).isEqualTo("IPv4 Range 1");
assertThat(trie.longestPrefixRangeLookup("10.0.5.1")).isEqualTo("IPv4 Range 2");
assertThat(trie.longestPrefixRangeLookup("35.139.253.123")).isEqualTo("IPv4 Range 3");
assertThat(trie.longestPrefixRangeLookup("2001:db8:abcd::1")).isEqualTo("IPv6 Range 1");
assertThat(trie.longestPrefixRangeLookup("2404:6800:4001:abcd::1")).isEqualTo("IPv6 Range 2");
assertThat(trie.longestPrefixRangeLookup("77f:8b7a:3e82:6fb3:ba15:9b68:7fe0:a695")).isEqualTo("IPv6 Range 4");
}

@Test
public void testBadEntry() {
final CIDRPatriciaTrie trie = new CIDRPatriciaTrie();
assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> trie.insertCIDR("127.a.3.21/12", "Bad Range 1"));
assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> trie.insertCIDR("not.an.ip.address/12", "Bad Range 2"));
assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> trie.insertCIDR("127.0.0.0", "Bad Range 3"));
}

@Test
public void testCleanCopy() throws InterruptedException {
final CIDRPatriciaTrie trie = new CIDRPatriciaTrie();
final long expireAt = DateTime.now(DateTimeZone.UTC).getMillis() + 500L;
trie.insertCIDR("192.168.1.0/24", "IPv4 Range 1", expireAt);
trie.insertCIDR("10.0.0.0/8", "IPv4 Range 2", expireAt);
trie.insertCIDR("35.138.0.0/15", "IPv4 Range 3", expireAt);
Thread.sleep(1000L);
final CIDRPatriciaTrie copy = trie.cleanCopy();
assertThat(copy.isEmpty()).isTrue();
}

@Test
public void testToBinaryIP() {
String cidrBinary = CIDRPatriciaTrie.toBinaryString("2002:0000:0000:1234:0000:0000:0000:0000", 64);
assertThat(cidrBinary).isEqualTo("0010000000000010000000000000000000000000000000000001001000110100");
cidrBinary = CIDRPatriciaTrie.toBinaryString("192.168.103.16", -1);
assertThat(cidrBinary).isEqualTo("11000000101010000110011100010000");
cidrBinary = CIDRPatriciaTrie.toBinaryString("2001:db7::", -1);
assertThat(cidrBinary).isEqualTo("00100000000000010000110110110111000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000");
cidrBinary = CIDRPatriciaTrie.toBinaryString("192.168.102.0", 24);
assertThat(cidrBinary).isEqualTo("110000001010100001100110");

}

private static CIDRPatriciaTrie buildTrie() {
final CIDRPatriciaTrie trie = new CIDRPatriciaTrie();
trie.insertCIDR("192.168.1.0/24", "IPv4 Range 1");
trie.insertCIDR("10.0.0.0/8", "IPv4 Range 2");
trie.insertCIDR("35.138.0.0/15", "IPv4 Range 3");
trie.insertCIDR("192.168.102.0/24", "HR");
trie.insertCIDR("192.168.102.0/28", "HR Subnet 1");
trie.insertCIDR("192.168.102.16/28", "HR Subnet 2");
trie.insertCIDR("192.168.102.32/28", "HR Subnet 3");
trie.insertCIDR("2001:db8::/32", "IPv6 Range 1");
trie.insertCIDR("2404:6800:4001::/48", "IPv6 Range 2");
trie.insertCIDR("8dbf:8000::/19", "IPv6 Range 3");
trie.insertCIDR("77f::/16", "IPv6 Range 4");
trie.insertCIDR("17c5:b180::/35", "IPv6 Range 5");
trie.insertCIDR("2001:db7::/128","Single IPv6");
return trie;
}
}

0 comments on commit f9433fd

Please sign in to comment.