Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] [WIP] Combine terminal statement support with statically known branches #15817

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -452,12 +452,16 @@ def raise_in_both_branches(cond: bool):
# Exceptions can occur anywhere, so "before" and "raise" are valid possibilities
reveal_type(x) # revealed: Literal["before", "raise1", "raise2"]
else:
# This should not be included below, but we do not currently model the fact that control flow
# which terminates via `raise` cannot enter the `else` clause.
x = "unreachable"
finally:
# Exceptions can occur anywhere, so "before" and "raise" are valid possibilities
reveal_type(x) # revealed: Literal["before", "raise1", "raise2"]
# TODO: Literal["before", "raise1", "raise2"]
reveal_type(x) # revealed: Literal["before", "raise1", "raise2", "unreachable"]
# Exceptions can occur anywhere, so "before" and "raise" are valid possibilities
reveal_type(x) # revealed: Literal["before", "raise1", "raise2"]
# TODO: Literal["before", "raise1", "raise2"]
reveal_type(x) # revealed: Literal["before", "raise1", "raise2", "unreachable"]

def raise_in_nested_then_branch(cond1: bool, cond2: bool):
x = "before"
Expand Down Expand Up @@ -636,5 +640,5 @@ def _(cond: bool):
return

# TODO: Literal["a"]
reveal_type(x) # revealed: Literal["a", "b"]
reveal_type(x) # revealed: Literal["a"]
```
24 changes: 21 additions & 3 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl<'db> SemanticIndexBuilder<'db> {

let file_scope_id = self.scopes.push(scope);
self.symbol_tables.push(SymbolTableBuilder::default());
self.use_def_maps.push(UseDefMapBuilder::default());
self.use_def_maps.push(UseDefMapBuilder::new(self.db));
let ast_id_scope = self.ast_ids.push(AstIdsBuilder::default());

let scope_id = ScopeId::new(self.db, self.file, file_scope_id, countme::Count::default());
Expand Down Expand Up @@ -356,7 +356,7 @@ impl<'db> SemanticIndexBuilder<'db> {
constraint: ScopedVisibilityConstraintId,
) -> ScopedVisibilityConstraintId {
self.current_use_def_map_mut()
.record_visibility_constraint(VisibilityConstraint::VisibleIfNot(constraint))
.record_negated_visibility_constraint_id(constraint)
}

/// Records a visibility constraint by applying it to all live bindings and declarations.
Expand Down Expand Up @@ -706,7 +706,25 @@ where

builder.declare_parameters(parameters);

builder.visit_body(body);
// HACK: Visit the function body, but treat the last statement specially if
// it is a return. If it is, this would cause all definitions in the
// function to be marked as non-visible with our current treatment of
// terminal statements. Since we currently model the externally visible
// definitions in a function scope as the set of bindings that are visible
// at the end of the body, we then consider this function to have no
// externally visible definitions. To get around this, we take a flow
// snapshot just before processing the return statement, and use _that_ as
// the "end-of-body" state that we resolve external references against.
if let Some((last_stmt, first_stmts)) = body.split_last() {
builder.visit_body(first_stmts);
let pre_return_state = matches!(last_stmt, ast::Stmt::Return(_))
.then(|| builder.flow_snapshot());
builder.visit_stmt(last_stmt);
if let Some(pre_return_state) = pre_return_state {
builder.flow_restore(pre_return_state);
}
}

builder.pop_scope()
},
);
Expand Down
66 changes: 46 additions & 20 deletions crates/red_knot_python_semantic/src/semantic_index/use_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::ScopedSymbolId;
use crate::semantic_index::use_def::symbol_state::DeclarationIdWithConstraint;
use crate::visibility_constraints::{VisibilityConstraint, VisibilityConstraints};
use crate::Db;
use ruff_index::IndexVec;
use rustc_hash::FxHashMap;

Expand Down Expand Up @@ -476,11 +477,12 @@ impl std::iter::FusedIterator for DeclarationsIterator<'_, '_> {}
pub(super) struct FlowSnapshot {
symbol_states: IndexVec<ScopedSymbolId, SymbolState>,
scope_start_visibility: ScopedVisibilityConstraintId,
reachable: bool,
always_reachable: bool,
}

#[derive(Debug)]
pub(super) struct UseDefMapBuilder<'db> {
db: &'db dyn Db,

/// Append-only array of [`Definition`].
all_definitions: IndexVec<ScopedDefinitionId, Option<Definition<'db>>>,

Expand All @@ -505,27 +507,27 @@ pub(super) struct UseDefMapBuilder<'db> {
/// Currently live bindings and declarations for each symbol.
symbol_states: IndexVec<ScopedSymbolId, SymbolState>,

reachable: bool,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At first I replaced this with reachability: ScopedVisibilityConstraintId, but I realized that scope_start_visibility is already what we need: the start of the scope is visible iff the flow is still reachable.

always_reachable: bool,
}

impl Default for UseDefMapBuilder<'_> {
fn default() -> Self {
impl<'db> UseDefMapBuilder<'db> {
pub(super) fn new(db: &'db dyn Db) -> Self {
Self {
db,
all_definitions: IndexVec::from_iter([None]),
all_constraints: IndexVec::new(),
visibility_constraints: VisibilityConstraints::default(),
scope_start_visibility: ScopedVisibilityConstraintId::ALWAYS_TRUE,
bindings_by_use: IndexVec::new(),
definitions_by_definition: FxHashMap::default(),
symbol_states: IndexVec::new(),
reachable: true,
always_reachable: true,
}
}
}

impl<'db> UseDefMapBuilder<'db> {
pub(super) fn mark_unreachable(&mut self) {
self.reachable = false;
self.record_visibility_constraint_id(ScopedVisibilityConstraintId::ALWAYS_FALSE);
self.always_reachable = false;
}

pub(super) fn add_symbol(&mut self, symbol: ScopedSymbolId) {
Expand Down Expand Up @@ -581,6 +583,15 @@ impl<'db> UseDefMapBuilder<'db> {
.add_and_constraint(self.scope_start_visibility, constraint);
}

/// Negates the visibility constraint identified by `constraint`, records the negation
/// via `record_visibility_constraint_id`, and returns the ID of the negated constraint.
///
/// The returned ID is whatever `VisibilityConstraints::add_not_constraint` produces for
/// `constraint`, so callers can keep it to refer to the negation later.
pub(super) fn record_negated_visibility_constraint_id(
    &mut self,
    constraint: ScopedVisibilityConstraintId,
) -> ScopedVisibilityConstraintId {
    // Build (or look up) the negated constraint first; its ID is both recorded and returned.
    let negated = self.visibility_constraints.add_not_constraint(constraint);
    self.record_visibility_constraint_id(negated);
    negated
}

pub(super) fn record_visibility_constraint(
&mut self,
constraint: VisibilityConstraint<'db>,
Expand Down Expand Up @@ -611,6 +622,10 @@ impl<'db> UseDefMapBuilder<'db> {
pub(super) fn simplify_visibility_constraints(&mut self, snapshot: FlowSnapshot) {
debug_assert!(self.symbol_states.len() >= snapshot.symbol_states.len());

if !self.always_reachable {
return;
}

self.scope_start_visibility = snapshot.scope_start_visibility;

// Note that this loop terminates when we reach a symbol not present in the snapshot.
Expand Down Expand Up @@ -664,7 +679,7 @@ impl<'db> UseDefMapBuilder<'db> {
FlowSnapshot {
symbol_states: self.symbol_states.clone(),
scope_start_visibility: self.scope_start_visibility,
reachable: self.reachable,
always_reachable: self.always_reachable,
}
}

Expand All @@ -679,6 +694,7 @@ impl<'db> UseDefMapBuilder<'db> {
// Restore the current visible-definitions state to the given snapshot.
self.symbol_states = snapshot.symbol_states;
self.scope_start_visibility = snapshot.scope_start_visibility;
self.always_reachable = snapshot.always_reachable;

// If the snapshot we are restoring is missing some symbols we've recorded since, we need
// to fill them in so the symbol IDs continue to line up. Since they don't exist in the
Expand All @@ -687,22 +703,34 @@ impl<'db> UseDefMapBuilder<'db> {
num_symbols,
SymbolState::undefined(self.scope_start_visibility),
);

self.reachable = snapshot.reachable;
}

/// Merge the given snapshot into the current state, reflecting that we might have taken either
/// path to get here. The new state for each symbol should include definitions from both the
/// prior state and the snapshot.
pub(super) fn merge(&mut self, snapshot: FlowSnapshot) {
// Unreachable snapshots should not be merged: If the current snapshot is unreachable, it
// should be completely overwritten by the snapshot we're merging in. If the other snapshot
// is unreachable, we should return without merging.
if !snapshot.reachable {
// As an optimization, if we know statically that either of the snapshots is always
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commenting out these two `if` clauses is how we verify that this is truly an optimization: the tests should produce the same results with and without them.

// unreachable, we can leave it out of the merged result entirely. Note that we cannot
// perform any type inference at this point, so this is largely limited to unreachability
// via terminal statements. If a flow's reachability depends on an expression in the code,
// we will include the flow in the merged result; the visibility constraints of its
// bindings will include this reachability condition, so that later during type inference,
// we can determine whether any particular binding is non-visible due to unreachability.
if self
.visibility_constraints
.evaluate_without_inference(self.db, snapshot.scope_start_visibility)
.is_always_false()
{
self.always_reachable = false;
return;
}
if !self.reachable {
if self
.visibility_constraints
.evaluate_without_inference(self.db, self.scope_start_visibility)
.is_always_false()
{
self.restore(snapshot);
self.always_reachable = false;
return;
}

Expand All @@ -727,9 +755,7 @@ impl<'db> UseDefMapBuilder<'db> {
self.scope_start_visibility = self
.visibility_constraints
.add_or_constraint(self.scope_start_visibility, snapshot.scope_start_visibility);

// Both of the snapshots are reachable, so the merged result is too.
self.reachable = true;
self.always_reachable &= snapshot.always_reachable;
}

pub(super) fn finish(mut self) -> UseDefMap<'db> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ impl ScopedVisibilityConstraintId {
/// present at index 0.
pub(crate) const ALWAYS_TRUE: ScopedVisibilityConstraintId =
ScopedVisibilityConstraintId::from_u32(0);

/// A special ID that is used for an "always false" / "never visible" constraint.
/// When we create a new [`VisibilityConstraints`] object, this constraint is always
/// present at index 1.
pub(crate) const ALWAYS_FALSE: ScopedVisibilityConstraintId =
ScopedVisibilityConstraintId::from_u32(1);
}

const INLINE_VISIBILITY_CONSTRAINTS: usize = 4;
Expand Down
73 changes: 62 additions & 11 deletions crates/red_knot_python_semantic/src/visibility_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ const MAX_RECURSION_DEPTH: usize = 24;
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum VisibilityConstraint<'db> {
AlwaysTrue,
AlwaysFalse,
Ambiguous,
VisibleIf(Constraint<'db>),
VisibleIfNot(ScopedVisibilityConstraintId),
Expand All @@ -186,7 +187,10 @@ pub(crate) struct VisibilityConstraints<'db> {
impl Default for VisibilityConstraints<'_> {
fn default() -> Self {
Self {
constraints: IndexVec::from_iter([VisibilityConstraint::AlwaysTrue]),
constraints: IndexVec::from_iter([
VisibilityConstraint::AlwaysTrue,
VisibilityConstraint::AlwaysFalse,
]),
}
}
}
Expand All @@ -196,14 +200,40 @@ impl<'db> VisibilityConstraints<'db> {
&mut self,
constraint: VisibilityConstraint<'db>,
) -> ScopedVisibilityConstraintId {
self.constraints.push(constraint)
match constraint {
VisibilityConstraint::AlwaysTrue => ScopedVisibilityConstraintId::ALWAYS_TRUE,
VisibilityConstraint::AlwaysFalse => ScopedVisibilityConstraintId::ALWAYS_FALSE,
_ => self.constraints.push(constraint),
}
}

/// Returns the ID of a constraint that is the logical negation of `a`.
///
/// The two reserved IDs negate into each other without creating a new constraint:
/// `ALWAYS_TRUE` becomes `ALWAYS_FALSE` and vice versa. Any other ID is wrapped in a
/// `VisibilityConstraint::VisibleIfNot` and registered through `add`.
pub(crate) fn add_not_constraint(
    &mut self,
    a: ScopedVisibilityConstraintId,
) -> ScopedVisibilityConstraintId {
    // Constant cases: flip between the two reserved IDs.
    if a == ScopedVisibilityConstraintId::ALWAYS_TRUE {
        return ScopedVisibilityConstraintId::ALWAYS_FALSE;
    }
    if a == ScopedVisibilityConstraintId::ALWAYS_FALSE {
        return ScopedVisibilityConstraintId::ALWAYS_TRUE;
    }

    // General case: record an explicit negation of `a`.
    self.add(VisibilityConstraint::VisibleIfNot(a))
}

pub(crate) fn add_or_constraint(
&mut self,
a: ScopedVisibilityConstraintId,
b: ScopedVisibilityConstraintId,
) -> ScopedVisibilityConstraintId {
if a == ScopedVisibilityConstraintId::ALWAYS_FALSE {
return b;
} else if b == ScopedVisibilityConstraintId::ALWAYS_FALSE {
return a;
} else if a == ScopedVisibilityConstraintId::ALWAYS_TRUE
|| b == ScopedVisibilityConstraintId::ALWAYS_TRUE
{
return ScopedVisibilityConstraintId::ALWAYS_TRUE;
}
match (&self.constraints[a], &self.constraints[b]) {
(_, VisibilityConstraint::VisibleIfNot(id)) if a == *id => {
ScopedVisibilityConstraintId::ALWAYS_TRUE
Expand All @@ -224,17 +254,31 @@ impl<'db> VisibilityConstraints<'db> {
b
} else if b == ScopedVisibilityConstraintId::ALWAYS_TRUE {
a
} else if a == ScopedVisibilityConstraintId::ALWAYS_FALSE
|| b == ScopedVisibilityConstraintId::ALWAYS_FALSE
{
ScopedVisibilityConstraintId::ALWAYS_FALSE
} else {
self.add(VisibilityConstraint::KleeneAnd(a, b))
}
}

/// Analyze the statically known visibility for a given visibility constraint, without
/// performing any type inference.
///
/// This delegates to `evaluate_impl` with `INFERENCE_ALLOWED = false` and an initial
/// depth of `MAX_RECURSION_DEPTH`. In that mode a `VisibleIf` constraint is reported as
/// `Truthiness::Ambiguous` instead of being analyzed, while the constant, negation, and
/// Kleene and/or constraints are still evaluated structurally.
pub(crate) fn evaluate_without_inference(
&self,
db: &'db dyn Db,
id: ScopedVisibilityConstraintId,
) -> Truthiness {
// `false` selects the inference-free instantiation of `evaluate_impl`.
self.evaluate_impl::<false>(db, id, MAX_RECURSION_DEPTH)
}

/// Analyze the statically known visibility for a given visibility constraint.
pub(crate) fn evaluate(&self, db: &'db dyn Db, id: ScopedVisibilityConstraintId) -> Truthiness {
self.evaluate_impl(db, id, MAX_RECURSION_DEPTH)
self.evaluate_impl::<true>(db, id, MAX_RECURSION_DEPTH)
}

fn evaluate_impl(
fn evaluate_impl<const INFERENCE_ALLOWED: bool>(
&self,
db: &'db dyn Db,
id: ScopedVisibilityConstraintId,
Expand All @@ -247,19 +291,26 @@ impl<'db> VisibilityConstraints<'db> {
let visibility_constraint = &self.constraints[id];
match visibility_constraint {
VisibilityConstraint::AlwaysTrue => Truthiness::AlwaysTrue,
VisibilityConstraint::AlwaysFalse => Truthiness::AlwaysFalse,
VisibilityConstraint::Ambiguous => Truthiness::Ambiguous,
VisibilityConstraint::VisibleIf(constraint) => Self::analyze_single(db, constraint),
VisibilityConstraint::VisibleIfNot(negated) => {
self.evaluate_impl(db, *negated, max_depth - 1).negate()
VisibilityConstraint::VisibleIf(constraint) => {
if INFERENCE_ALLOWED {
Self::analyze_single(db, constraint)
} else {
Truthiness::Ambiguous
}
}
VisibilityConstraint::VisibleIfNot(negated) => self
.evaluate_impl::<INFERENCE_ALLOWED>(db, *negated, max_depth - 1)
.negate(),
VisibilityConstraint::KleeneAnd(lhs, rhs) => {
let lhs = self.evaluate_impl(db, *lhs, max_depth - 1);
let lhs = self.evaluate_impl::<INFERENCE_ALLOWED>(db, *lhs, max_depth - 1);

if lhs == Truthiness::AlwaysFalse {
return Truthiness::AlwaysFalse;
}

let rhs = self.evaluate_impl(db, *rhs, max_depth - 1);
let rhs = self.evaluate_impl::<INFERENCE_ALLOWED>(db, *rhs, max_depth - 1);

if rhs == Truthiness::AlwaysFalse {
Truthiness::AlwaysFalse
Expand All @@ -270,13 +321,13 @@ impl<'db> VisibilityConstraints<'db> {
}
}
VisibilityConstraint::KleeneOr(lhs_id, rhs_id) => {
let lhs = self.evaluate_impl(db, *lhs_id, max_depth - 1);
let lhs = self.evaluate_impl::<INFERENCE_ALLOWED>(db, *lhs_id, max_depth - 1);

if lhs == Truthiness::AlwaysTrue {
return Truthiness::AlwaysTrue;
}

let rhs = self.evaluate_impl(db, *rhs_id, max_depth - 1);
let rhs = self.evaluate_impl::<INFERENCE_ALLOWED>(db, *rhs_id, max_depth - 1);

if rhs == Truthiness::AlwaysTrue {
Truthiness::AlwaysTrue
Expand Down
Loading