Skip to content

Commit

Permalink
Remove Hash and Eq from AstNodeRef for types not implementing `…
Browse files Browse the repository at this point in the history
…Eq` or `Hash` (#16100)

## Summary

This is a follow up to
#15763 (comment)

It reverts the change to using ptr equality for `AstNodeRef`s, which in
turn removes the `Eq`, `PartialEq`, and `Hash` implementations for
`AstNodeRef`s parametrized with AST nodes.
Cheap comparisons shouldn't be needed because the node field is
generally marked as `[#tracked]` and `#[no_eq]` and removing the
implementations even enforces that those
attributes are set on all `AstNodeRef` fields (which is good).

The only downside this has is that we technically wouldn't have to mark
the `Unpack::target` as `#[tracked]` because
the `target` field is accessed in every query accepting `Unpack` as an
argument.

Overall, enforcing the use of `#[tracked]` seems like a good trade off,
espacially considering that it's very likely that
we'd probably forget to mark the `Unpack::target` field as tracked if we
add a new `Unpack` query that doesn't access the target.

## Test Plan

`cargo test`
  • Loading branch information
MichaReiser authored Feb 11, 2025
1 parent ce31c26 commit 9c17931
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 20 deletions.
47 changes: 35 additions & 12 deletions crates/red_knot_python_semantic/src/ast_node_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,13 @@ use ruff_db::parsed::ParsedModule;
/// This means that changes to expressions in other scopes don't invalidate the expression's id, giving
/// us some form of scope-stable identity for expressions. Only queries accessing the node field
/// run on every AST change. All other queries only run when the expression's identity changes.
///
/// The one exception to this is if it is known that all queries tacking the tracked struct
/// as argument or returning it as part of their result are known to access the node field.
/// Marking the field tracked is then unnecessary.
#[derive(Clone)]
pub struct AstNodeRef<T> {
/// Owned reference to the node's [`ParsedModule`].
///
/// The node's reference is guaranteed to remain valid as long as it's enclosing
/// [`ParsedModule`] is alive.
_parsed: ParsedModule,
parsed: ParsedModule,

/// Pointer to the referenced node.
node: std::ptr::NonNull<T>,
Expand All @@ -59,7 +55,7 @@ impl<T> AstNodeRef<T> {
/// the invariant `node belongs to parsed` is upheld.
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
Self {
_parsed: parsed,
parsed,
node: std::ptr::NonNull::from(node),
}
}
Expand Down Expand Up @@ -89,17 +85,44 @@ where
}
}

impl<T> PartialEq for AstNodeRef<T> {
impl<T> PartialEq for AstNodeRef<T>
where
T: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.node.eq(&other.node)
if self.parsed == other.parsed {
// Comparing the pointer addresses is sufficient to determine equality
// if the parsed are the same.
self.node.eq(&other.node)
} else {
// Otherwise perform a deep comparison.
self.node().eq(other.node())
}
}
}

impl<T> Eq for AstNodeRef<T> {}
impl<T> Eq for AstNodeRef<T> where T: Eq {}

impl<T> Hash for AstNodeRef<T> {
impl<T> Hash for AstNodeRef<T>
where
T: Hash,
{
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.node.hash(state);
self.node().hash(state);
}
}

#[allow(unsafe_code)]
unsafe impl<T> salsa::Update for AstNodeRef<T> {
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
let old_ref = &mut (*old_pointer);

if old_ref.parsed == new_value.parsed && old_ref.node.eq(&new_value.node) {
false
} else {
*old_ref = new_value;
true
}
}
}

Expand Down Expand Up @@ -133,7 +156,7 @@ mod tests {
let stmt_cloned = &cloned.syntax().body[0];
let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) };

assert_ne!(node1, cloned_node);
assert_eq!(node1, cloned_node);

let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
let other = ParsedModule::new(other_raw);
Expand Down
16 changes: 8 additions & 8 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ impl DefinitionCategory {
/// [`DefinitionKind`] fields in salsa tracked structs should be tracked (attributed with `#[tracked]`)
/// because the kind is a thin wrapper around [`AstNodeRef`]. See the [`AstNodeRef`] documentation
/// for an in-depth explanation of why this is necessary.
#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub enum DefinitionKind<'db> {
Import(AstNodeRef<ast::Alias>),
ImportFrom(ImportFromDefinitionKind),
Expand Down Expand Up @@ -559,7 +559,7 @@ impl<'db> From<Option<Unpack<'db>>> for TargetKind<'db> {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct MatchPatternDefinitionKind {
pattern: AstNodeRef<ast::Pattern>,
Expand All @@ -577,7 +577,7 @@ impl MatchPatternDefinitionKind {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct ComprehensionDefinitionKind {
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
Expand All @@ -603,7 +603,7 @@ impl ComprehensionDefinitionKind {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct ImportFromDefinitionKind {
node: AstNodeRef<ast::StmtImportFrom>,
alias_index: usize,
Expand All @@ -619,7 +619,7 @@ impl ImportFromDefinitionKind {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct AssignmentDefinitionKind<'db> {
target: TargetKind<'db>,
value: AstNodeRef<ast::Expr>,
Expand All @@ -645,7 +645,7 @@ impl<'db> AssignmentDefinitionKind<'db> {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct WithItemDefinitionKind {
node: AstNodeRef<ast::WithItem>,
target: AstNodeRef<ast::ExprName>,
Expand All @@ -666,7 +666,7 @@ impl WithItemDefinitionKind {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct ForStmtDefinitionKind<'db> {
target: TargetKind<'db>,
iterable: AstNodeRef<ast::Expr>,
Expand Down Expand Up @@ -697,7 +697,7 @@ impl<'db> ForStmtDefinitionKind<'db> {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct ExceptHandlerDefinitionKind {
handler: AstNodeRef<ast::ExceptHandlerExceptHandler>,
is_star: bool,
Expand Down
1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/src/unpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub(crate) struct Unpack<'db> {
/// expression is `(a, b)`.
#[no_eq]
#[return_ref]
#[tracked]
pub(crate) target: AstNodeRef<ast::Expr>,

/// The ingredient representing the value expression of the unpacking. For example, in
Expand Down

0 comments on commit 9c17931

Please sign in to comment.