Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement #599 #748

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub trait Configuration: Any {

/// The input to the function
type Input<'db>: Send + Sync;
type MapKey<'db>: Send + Sync + std::hash::Hash + Eq + std::fmt::Debug + Any;

/// The value computed by the function.
type Output<'db>: fmt::Debug + Send + Sync;
Expand Down Expand Up @@ -155,17 +156,22 @@ where
}
}

pub fn database_key_index(&self, k: Id) -> DatabaseKeyIndex {
pub fn set_capacity(&mut self, capacity: usize) {
self.lru.set_capacity(capacity);
}
}

impl<C> IngredientImpl<C>
where
C: Configuration,
{
fn database_key_index(&self, k: Id) -> DatabaseKeyIndex {
DatabaseKeyIndex {
ingredient_index: self.index,
key_index: k,
}
}

pub fn set_capacity(&mut self, capacity: usize) {
self.lru.set_capacity(capacity);
}

/// Returns a reference to the memo value that lives as long as self.
/// This is UNSAFE: the caller is responsible for ensuring that the
/// memo will not be released so long as the `&self` is valid.
Expand All @@ -184,6 +190,7 @@ where
&'db self,
zalsa: &'db Zalsa,
id: Id,
map_key: C::MapKey<'db>,
memo: memo::Memo<C::Output<'db>>,
memo_ingredient_index: MemoIngredientIndex,
) -> &'db memo::Memo<C::Output<'db>> {
Expand All @@ -197,9 +204,9 @@ where

// Safety: We delay the drop of `old_value` until a new revision starts which ensures no
// references will exist for the memo contents.
if let Some(old_value) =
unsafe { self.insert_memo_into_table_for(zalsa, id, memo, memo_ingredient_index) }
{
if let Some(old_value) = unsafe {
self.insert_memo_into_table_for(zalsa, id, map_key, memo, memo_ingredient_index)
} {
// In case there is a reference to the old memo out there, we have to store it
// in the deleted entries. This will get cleared when a new revision starts.
//
Expand All @@ -212,7 +219,8 @@ where

#[inline]
fn memo_ingredient_index(&self, zalsa: &Zalsa, id: Id) -> MemoIngredientIndex {
self.memo_ingredient_indices.get_zalsa_id(zalsa, id)
self.memo_ingredient_indices
.get_id_with_table(zalsa.table(), id)
}
}

Expand All @@ -232,15 +240,17 @@ where
) -> MaybeChangedAfter {
// SAFETY: The `db` belongs to the ingredient as per caller invariant
let db = unsafe { self.view_caster.downcast_unchecked(db) };
self.maybe_changed_after(db, input, revision)
let map_key = todo!();
self.maybe_changed_after(db, input, map_key, revision)
}

fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy {
C::CYCLE_STRATEGY
}

fn origin(&self, db: &dyn Database, key: Id) -> Option<QueryOrigin> {
self.origin(db.zalsa(), key)
let map_key = todo!();
self.origin(db.zalsa(), key, map_key)
}

fn mark_validated_output(
Expand All @@ -249,7 +259,8 @@ where
executor: DatabaseKeyIndex,
output_key: crate::Id,
) {
self.validate_specified_value(db, executor, output_key);
let map_key = todo!();
self.validate_specified_value(db, executor, output_key, map_key);
}

fn remove_stale_output(
Expand All @@ -269,11 +280,7 @@ where

fn reset_for_new_revision(&mut self, table: &mut Table) {
self.lru.for_each_evicted(|evict| {
let ingredient_index = table.ingredient_index(evict);
Self::evict_value_from_memo_for(
table.memos_mut(evict),
self.memo_ingredient_indices.get(ingredient_index),
)
Self::evict_value_from_memo_for(table, &self.memo_ingredient_indices, evict)
});
std::mem::take(&mut self.deleted_entries);
}
Expand All @@ -292,7 +299,8 @@ where
key_index: Id,
) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) {
let db = self.view_caster.downcast(db);
self.accumulated_map(db, key_index)
let map_key = todo!();
self.accumulated_map(db, key_index, map_key)
}
}

Expand Down
12 changes: 9 additions & 3 deletions src/function/accumulated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ where
{
/// Helper used by `accumulate` functions. Computes the results accumulated by `database_key_index`
/// and its inputs.
pub fn accumulated_by<'db, A>(&self, db: &'db C::DbView, key: Id) -> Vec<&'db A>
pub fn accumulated_by<'db, A>(
&'db self,
db: &'db C::DbView,
key: Id,
map_key: C::MapKey<'db>,
) -> Vec<&'db A>
where
A: accumulator::Accumulator,
{
Expand All @@ -39,7 +44,7 @@ where
let mut output = vec![];

// First ensure the result is up to date
self.fetch(db, key);
self.fetch(db, key, map_key);

let db = db.as_dyn_database();
let db_key = self.database_key_index(key);
Expand Down Expand Up @@ -96,9 +101,10 @@ where
&'db self,
db: &'db C::DbView,
key: Id,
map_key: C::MapKey<'db>,
) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) {
// NEXT STEP: stash and refactor `fetch` to return an `&Memo` so we can make this work
let memo = self.refresh_memo(db, key);
let memo = self.refresh_memo(db.zalsa(), db, key, map_key);
(
memo.revisions.accumulated.as_deref(),
memo.revisions.accumulated_inputs.load(),
Expand Down
3 changes: 2 additions & 1 deletion src/function/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ where
&'db self,
db: &'db C::DbView,
active_query: ActiveQueryGuard<'_>,
map_key: C::MapKey<'db>,
opt_old_memo: Option<&Memo<C::Output<'_>>>,
) -> &'db Memo<C::Output<'db>> {
let zalsa = db.zalsa();
Expand All @@ -43,7 +44,6 @@ where

// Query was not previously executed, or value is potentially
// stale, or value is absent. Let's execute!
let database_key_index = active_query.database_key_index;
let id = database_key_index.key_index;
let value = match Cycle::catch(|| C::execute(db, C::id_to_input(db, id))) {
Ok(v) => v,
Expand Down Expand Up @@ -86,6 +86,7 @@ where
self.insert_memo(
zalsa,
id,
map_key,
Memo::new(Some(value), revision_now, revisions),
memo_ingredient_index,
)
Expand Down
48 changes: 30 additions & 18 deletions src/function/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use rayon::iter::Either;

use super::{memo::Memo, Configuration, IngredientImpl};
use crate::table::sync::ClaimGuard;
use crate::zalsa::MemoIngredientIndex;
use crate::{
accumulator::accumulated_map::InputAccumulatedValues,
Expand All @@ -11,11 +14,16 @@ impl<C> IngredientImpl<C>
where
C: Configuration,
{
pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &'db C::Output<'db> {
pub fn fetch<'db>(
&'db self,
db: &'db C::DbView,
id: Id,
map_key: C::MapKey<'db>,
) -> &'db C::Output<'db> {
let (zalsa, zalsa_local) = db.zalsas();
zalsa.unwind_if_revision_cancelled(db);

let memo = self.refresh_memo(db, id);
let memo = self.refresh_memo(zalsa, db, id, map_key);
// SAFETY: We just refreshed the memo so it is guaranteed to contain a value now.
let StampedValue {
value,
Expand All @@ -25,8 +33,6 @@ where
.revisions
.stamped_value(unsafe { memo.value.as_ref().unwrap_unchecked() });

self.lru.record_use(id);

zalsa_local.report_tracked_read(
self.database_key_index(id).into(),
durability,
Expand All @@ -43,18 +49,21 @@ where
#[inline]
pub(super) fn refresh_memo<'db>(
&'db self,
zalsa: &'db Zalsa,
db: &'db C::DbView,
id: Id,
mut map_key: C::MapKey<'db>,
) -> &'db Memo<C::Output<'db>> {
let zalsa = db.zalsa();
self.lru.record_use(id);
let memo_ingredient_index = self.memo_ingredient_index(zalsa, id);
loop {
if let Some(memo) = self
.fetch_hot(zalsa, db, id, memo_ingredient_index)
.or_else(|| self.fetch_cold(zalsa, db, id, memo_ingredient_index))
{
if let Some(memo) = self.fetch_hot(zalsa, db, id, &map_key, memo_ingredient_index) {
return memo;
}
match self.fetch_cold(zalsa, db, id, map_key, memo_ingredient_index) {
Either::Left(memo) => return memo,
Either::Right(key) => map_key = key,
}
}
}

Expand All @@ -64,9 +73,10 @@ where
zalsa: &'db Zalsa,
db: &'db C::DbView,
id: Id,
map_key: &C::MapKey<'db>,
memo_ingredient_index: MemoIngredientIndex,
) -> Option<&'db Memo<C::Output<'db>>> {
let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
let memo_guard = self.get_memo_from_table_for(zalsa, id, map_key, memo_ingredient_index);
if let Some(memo) = memo_guard {
if memo.value.is_some()
&& self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo)
Expand All @@ -84,30 +94,32 @@ where
zalsa: &'db Zalsa,
db: &'db C::DbView,
id: Id,
map_key: C::MapKey<'db>,
memo_ingredient_index: MemoIngredientIndex,
) -> Option<&'db Memo<C::Output<'db>>> {
) -> Either<&'db Memo<C::Output<'db>>, C::MapKey<'db>> {
let database_key_index = self.database_key_index(id);

// Try to claim this query: if someone else has claimed it already, go back and start again.
let _claim_guard =
zalsa
.sync_table_for(id)
.claim(db, zalsa, database_key_index, memo_ingredient_index)?;
let Some(_claim_guard) =
ClaimGuard::claim(db, zalsa, database_key_index, memo_ingredient_index)
else {
return Either::Right(map_key);
};

// Push the query on the stack.
let active_query = db.zalsa_local().push_query(database_key_index);

// Now that we've claimed the item, check again to see if there's a "hot" value.
let opt_old_memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
let opt_old_memo = self.get_memo_from_table_for(zalsa, id, &map_key, memo_ingredient_index);
if let Some(old_memo) = opt_old_memo {
if old_memo.value.is_some() && self.deep_verify_memo(db, zalsa, old_memo, &active_query)
{
// Unsafety invariant: memo is present in memo_map and we have verified that it is
// still valid for the current revision.
return unsafe { Some(self.extend_memo_lifetime(old_memo)) };
return unsafe { Either::Left(self.extend_memo_lifetime(old_memo)) };
}
}

Some(self.execute(db, active_query, opt_old_memo))
Either::Left(self.execute(db, active_query, map_key, opt_old_memo))
}
}
9 changes: 7 additions & 2 deletions src/function/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ impl<C> IngredientImpl<C>
where
C: Configuration,
{
pub(super) fn origin(&self, zalsa: &Zalsa, key: Id) -> Option<QueryOrigin> {
pub(super) fn origin<'db>(
&'db self,
zalsa: &'db Zalsa,
key: Id,
map_key: &C::MapKey<'db>,
) -> Option<QueryOrigin> {
let memo_ingredient_index = self.memo_ingredient_index(zalsa, key);
self.get_memo_from_table_for(zalsa, key, memo_ingredient_index)
self.get_memo_from_table_for(zalsa, key, map_key, memo_ingredient_index)
.map(|m| m.revisions.origin.clone())
}
}
Loading
Loading