Skip to content

Commit

Permalink
Cache NodeRefs in grammar_builder (#70)
Browse files Browse the repository at this point in the history
Cache NodeRefs by hashing serialized grammar nodes; add hashing to repeat function
  • Loading branch information
hudson-ai authored Nov 26, 2024
1 parent e79e538 commit 65cbcbe
Showing 1 changed file with 36 additions and 4 deletions.
40 changes: 36 additions & 4 deletions parser/src/grammar_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ pub struct GrammarBuilder {
placeholder: Node,
strings: HashMap<String, NodeRef>,
curr_grammar_id: u32,
node_refs: HashMap<String, NodeRef>,
nodes: Vec<Node>,
pub regex: RegexBuilder,
at_most_cache: HashMap<(NodeRef, usize), NodeRef>,
repeat_exact_cache: HashMap<(NodeRef, usize), NodeRef>,
}

pub struct RegexBuilder {
Expand Down Expand Up @@ -136,8 +139,11 @@ impl GrammarBuilder {
},
strings: HashMap::new(),
curr_grammar_id: 0,
node_refs: HashMap::new(),
nodes: vec![],
regex: RegexBuilder::new(),
at_most_cache: HashMap::new(),
repeat_exact_cache: HashMap::new(),
}
}

Expand Down Expand Up @@ -174,11 +180,27 @@ impl GrammarBuilder {
}

/// Adds `node` to the node list of the grammar currently being built and
/// returns a reference to it.
///
/// Nodes other than the placeholder are deduplicated: if a node with the
/// same serialized (JSON) form has already been added under the current
/// `curr_grammar_id`, the previously issued `NodeRef` is returned and
/// nothing is pushed. Placeholder nodes, and nodes that fail to serialize,
/// always get a fresh slot and are never recorded in the cache.
pub fn add_node(&mut self, node: Node) -> NodeRef {
    // Build the dedup key. The key is scoped by the current grammar id
    // because a `NodeRef` stores both `idx` (a position in `self.nodes`)
    // and `grammar_id`; a cache hit recorded under a different grammar id
    // would hand back a stale reference, and nothing in this function
    // resets `node_refs` when `curr_grammar_id` changes.
    let key = if node == self.placeholder {
        None
    } else {
        serde_json::to_string(&node)
            .ok()
            .map(|json| format!("{}:{}", self.curr_grammar_id, json))
    };

    // Reuse the existing reference when an identical node is already known.
    if let Some(existing) = key.as_ref().and_then(|k| self.node_refs.get(k)) {
        return *existing;
    }

    // Otherwise the node goes at the end of the current node list.
    let r = NodeRef {
        idx: self.nodes.len(),
        grammar_id: self.curr_grammar_id,
    };
    self.nodes.push(node);

    // Remember the reference for future identical nodes (never for the
    // placeholder or for nodes that could not be serialized).
    if let Some(key) = key {
        self.node_refs.insert(key, r);
    }
    r
}

Expand Down Expand Up @@ -321,7 +343,10 @@ impl GrammarBuilder {
// at_most() recursively factors the sequence into K-size pieces,
// in an attempt to keep grammar size O(log(n)).
fn at_most(&mut self, elt: NodeRef, n: usize) -> NodeRef {
if n == 0 {
if let Some(r) = self.at_most_cache.get(&(elt, n)) {
return *r;
}
let r = if n == 0 {
// If the max ('n') is 0, return an empty rule
self.empty()
} else if n == 1 {
Expand Down Expand Up @@ -378,7 +403,9 @@ impl GrammarBuilder {
// (inclusive) in 'elt_n'. Clearly, the sequences of length at most 'n'
// are the alternation of 'elt_max_nk' and 'elt_n'.
self.select(&[elt_n, elt_max_nk])
}
};
self.at_most_cache.insert((elt, n), r);
r
}

// simple_repeat() "simply" repeats the element ('elt') 'n' times.
Expand All @@ -393,7 +420,10 @@ impl GrammarBuilder {
// Repeat element 'elt' exactly 'n' times, using factoring
// in an attempt to keep grammar size O(log(n)).
fn repeat_exact(&mut self, elt: NodeRef, n: usize) -> NodeRef {
if n > 2 * K {
if let Some(r) = self.repeat_exact_cache.get(&(elt, n)) {
return *r;
}
let r = if n > 2 * K {
// For large 'n', try to keep the number of rules O(log(n))
// by "factoring" the sequence into K-sized pieces

Expand All @@ -418,7 +448,9 @@ impl GrammarBuilder {
// For small 'n' (currently, 8 or less), simply
// repeat 'elt' 'n' times.
self.simple_repeat(elt, n)
}
};
self.repeat_exact_cache.insert((elt, n), r);
r
}

// at_least() accepts a sequence of at least 'n' copies of
Expand Down

0 comments on commit 65cbcbe

Please sign in to comment.