Skip to content

Commit

Permalink
Labels flowing through build and scan
Browse files Browse the repository at this point in the history
  • Loading branch information
tjgreen42 committed Jan 16, 2025
1 parent 30f87b9 commit b8168e0
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 161 deletions.
82 changes: 26 additions & 56 deletions pgvectorscale/src/access_method/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,12 @@ unsafe fn aminsert_internal(
let heap_relation = PgRelation::from_pg(heaprel);
let mut meta_page = MetaPage::fetch(&index_relation);

let labvec = LabeledVector::from_datums(values, isnull, &meta_page);
if labvec.is_none() {
let vec = LabeledVector::from_datums(values, isnull, &meta_page);
if vec.is_none() {
// TODO: check this handling of nulls
return false;
}
let labvec = labvec.unwrap();
let vec = vec.unwrap();

let heap_pointer = ItemPointer::with_item_pointer_data(*heap_tid);
let mut storage = meta_page.get_storage_type();
Expand All @@ -195,7 +195,7 @@ unsafe fn aminsert_internal(
insert_storage(
&plain,
&index_relation,
labvec,
vec,
heap_pointer,
&mut meta_page,
&mut stats,
Expand All @@ -211,7 +211,7 @@ unsafe fn aminsert_internal(
insert_storage(
&bq,
&index_relation,
labvec,
vec,
heap_pointer,
&mut meta_page,
&mut stats,
Expand All @@ -224,23 +224,23 @@ unsafe fn aminsert_internal(
unsafe fn insert_storage<S: Storage>(
storage: &S,
index_relation: &PgRelation,
labvec: LabeledVector,
vec: LabeledVector,
heap_pointer: ItemPointer,
meta_page: &mut MetaPage,
stats: &mut InsertStats,
) {
let mut tape = Tape::resume(index_relation, S::page_type());
let index_pointer = storage.create_node(labvec, heap_pointer, meta_page, &mut tape, stats);
let index_pointer = storage.create_node(
vec.vec().to_index_slice(),
vec.labels().cloned(),
heap_pointer,
meta_page,
&mut tape,
stats,
);

let mut graph = Graph::new(GraphNeighborStore::Disk, meta_page);
graph.insert(
index_relation,
index_pointer,
labvec.vec(),
labvec.into_labels(),
storage,
stats,
)
graph.insert(index_relation, index_pointer, vec, storage, stats)
}

#[pg_guard]
Expand Down Expand Up @@ -442,44 +442,19 @@ unsafe extern "C" fn build_callback(
state: *mut std::os::raw::c_void,
) {
let index_relation = unsafe { PgRelation::from_pg(index) };
let heap_pointer = ItemPointer::with_item_pointer_data(*ctid);
let state = (state as *mut StorageBuildState).as_mut().unwrap();
match state {
StorageBuildState::SbqSpeedup(bq, state) => {
let vec = PgVector::from_pg_parts(values, isnull, 0, &state.meta_page, true, false);
let vec = LabeledVector::from_datums(values, isnull, &state.meta_page);
if let Some(vec) = vec {
let heap_pointer = ItemPointer::with_item_pointer_data(*ctid);
let labels: Option<Array<i32>> = if state.meta_page.has_labels() {
Array::<i32>::from_datum(*values.add(1), *isnull.add(1))
} else {
None
};
build_callback_memory_wrapper(
index_relation,
heap_pointer,
vec,
labels,
state,
*bq,
);
build_callback_memory_wrapper(index_relation, heap_pointer, vec, state, *bq);
}
}
StorageBuildState::Plain(plain, state) => {
let vec = PgVector::from_pg_parts(values, isnull, 0, &state.meta_page, true, false);
let vec = LabeledVector::from_datums(values, isnull, &state.meta_page);
if let Some(vec) = vec {
let heap_pointer = ItemPointer::with_item_pointer_data(*ctid);
let labels: Option<Array<i32>> = if state.meta_page.has_labels() {
Array::<i32>::from_datum(*values.add(1), *isnull.add(1))
} else {
None
};
build_callback_memory_wrapper(
index_relation,
heap_pointer,
vec,
labels,
state,
*plain,
);
build_callback_memory_wrapper(index_relation, heap_pointer, vec, state, *plain);
}
}
}
Expand All @@ -489,14 +464,13 @@ unsafe extern "C" fn build_callback(
unsafe fn build_callback_memory_wrapper<S: Storage>(
index: PgRelation,
heap_pointer: ItemPointer,
vector: PgVector,
labels: Option<Array<i32>>,
vector: LabeledVector,
state: &mut BuildState,
storage: &mut S,
) {
let mut old_context = state.memcxt.set_as_current();

build_callback_internal(index, heap_pointer, vector, labels, state, storage);
build_callback_internal(index, heap_pointer, vector, state, storage);

old_context.set_as_current();
state.memcxt.reset();
Expand All @@ -513,18 +487,14 @@ unsafe fn build_callback_memory_wrapper<S: Storage>(
fn build_callback_internal<S: Storage>(
index: PgRelation,
heap_pointer: ItemPointer,
vector: PgVector,
labels: Option<Array<i32>>,
vec: LabeledVector,
state: &mut BuildState,
storage: &mut S,
) {
check_for_interrupts!();

state.ntuples += 1;

let labels: Option<Vec<u16>> =
labels.map(|labels| labels.into_iter().flatten().map(|x| x as u16).collect());

if state.ntuples % 1000 == 0 {
debug1!(
"Processed {} tuples in {}s which is {}s/tuple. Dist/tuple: Prune: {} search: {}. Stats: {:?}",
Expand All @@ -538,17 +508,17 @@ fn build_callback_internal<S: Storage>(
}

let index_pointer = storage.create_node(
vector.to_index_slice(),
vec.vec().to_index_slice(),
vec.labels().cloned(),
heap_pointer,
&state.meta_page,
&mut state.tape,
&mut state.stats,
labels,
);

state
.graph
.insert(&index, index_pointer, vector, storage, &mut state.stats);
.insert(&index, index_pointer, vec, storage, &mut state.stats);
}

const BUILD_PHASE_TRAINING: i64 = 0;
Expand Down
15 changes: 6 additions & 9 deletions pgvectorscale/src/access_method/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::util::{HeapPointer, IndexPointer, ItemPointer};

use super::graph_neighbor_store::GraphNeighborStore;

use super::labels::LabeledVector;
use super::neighbor_with_distance::{Distance, DistanceWithTieBreak};
use super::pg_vector::PgVector;
use super::stats::{GreedySearchStats, InsertStats, PruneNeighborStats, StatsNodeVisit};
Expand Down Expand Up @@ -268,8 +269,7 @@ impl<'a> Graph<'a> {
fn greedy_search_for_build<S: Storage>(
&self,
index_pointer: IndexPointer,
query: PgVector,
labels: Option<Vec<u16>>,
query: LabeledVector,
meta_page: &MetaPage,
storage: &S,
stats: &mut GreedySearchStats,
Expand All @@ -279,7 +279,7 @@ impl<'a> Graph<'a> {
//no nodes in the graph
return HashSet::with_capacity(0);
}
let dm = storage.get_query_distance_measure(query, labels);
let dm = storage.get_query_distance_measure(query);
let search_list_size = meta_page.get_search_list_size_for_build() as usize;

let mut l = ListSearchResult::new(
Expand All @@ -301,8 +301,7 @@ impl<'a> Graph<'a> {
/// the next elements.
pub fn greedy_search_streaming_init<S: Storage>(
&self,
query: PgVector,
labels: Option<Vec<u16>>,
query: LabeledVector,
search_list_size: usize,
storage: &S,
) -> ListSearchResult<S::QueryDistanceMeasure, S::LSNPrivateData> {
Expand All @@ -311,7 +310,7 @@ impl<'a> Graph<'a> {
//no nodes in the graph
return ListSearchResult::empty();
}
let dm = storage.get_query_distance_measure(query, labels);
let dm = storage.get_query_distance_measure(query);

ListSearchResult::new(
init_ids.unwrap(),
Expand Down Expand Up @@ -443,8 +442,7 @@ impl<'a> Graph<'a> {
&mut self,
index: &PgRelation,
index_pointer: IndexPointer,
vec: PgVector,
labels: Option<Vec<u16>>,
vec: LabeledVector,
storage: &S,
stats: &mut InsertStats,
) {
Expand All @@ -471,7 +469,6 @@ impl<'a> Graph<'a> {
let v = self.greedy_search_for_build(
index_pointer,
vec,
labels,
meta_page,
storage,
&mut stats.greedy_search_stats,
Expand Down
37 changes: 27 additions & 10 deletions pgvectorscale/src/access_method/labels.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::{meta_page::MetaPage, pg_vector::PgVector};
use pgrx::{pg_sys::Datum, Array, FromDatum};
use pgrx::{
pg_sys::{Datum, ScanKeyData},
Array, FromDatum,
};

pub type Labels = Vec<u16>;

Expand All @@ -18,11 +21,7 @@ impl LabeledVector {
isnull: *mut bool,
meta_page: &MetaPage,
) -> Option<Self> {
let vec = PgVector::from_pg_parts(values, isnull, 0, &meta_page, true, false);
if vec.is_none() {
return None;
}
let vec = vec.unwrap();
let vec = PgVector::from_pg_parts(values, isnull, 0, meta_page, true, false)?;

let labels = if meta_page.has_labels() {
let arr = Array::<i32>::from_datum(*values.add(1), *isnull.add(1));
Expand All @@ -34,15 +33,33 @@ impl LabeledVector {
Some(Self::new(vec, labels))
}

pub unsafe fn from_scan_key_data(
keys: &[ScanKeyData],
orderbys: &[ScanKeyData],
meta_page: &MetaPage,
) -> Self {
let query = unsafe {
PgVector::from_datum(
orderbys[0].sk_argument,
meta_page,
true, /* needed for search */
true, /* needed for resort */
)
};

let labels: Option<Vec<u16>> = (!keys.is_empty()).then(|| {
let arr = unsafe { Array::<i32>::from_datum(keys[0].sk_argument, false).unwrap() };
arr.into_iter().flatten().map(|i| i as u16).collect()
});

Self::new(query, labels)
}

pub fn vec(&self) -> &PgVector {
&self.vec
}

pub fn labels(&self) -> Option<&Vec<u16>> {
self.labels.as_ref()
}

pub fn into_labels(self) -> Option<Vec<u16>> {
self.labels
}
}
13 changes: 7 additions & 6 deletions pgvectorscale/src/access_method/plain_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use pgvectorscale_derive::{Readable, Writeable};
use rkyv::vec::ArchivedVec;
use rkyv::{Archive, Deserialize, Serialize};

use super::labels::LabeledVector;
use super::neighbor_with_distance::NeighborWithDistance;
use super::storage::ArchivedData;
use crate::util::{ArchivedItemPointer, HeapPointer, ItemPointer, ReadableBuffer, WritableBuffer};
Expand All @@ -25,32 +24,34 @@ pub struct Node {

impl Node {
fn new_internal(
labvec: LabeledVector,
vector: Vec<f32>,
labels: Option<Vec<u16>>,
pq_vector: Vec<u8>,
heap_item_pointer: ItemPointer,
meta_page: &MetaPage,
) -> Self {
let num_neighbors = meta_page.get_num_neighbors();
Self {
vector: labvec.vec().to_index_slice().to_vec(),
vector,
// always use vectors of num_clusters on length because we never want the serialized size of a Node to change
pq_vector,
// always use vectors of num_neighbors on length because we never want the serialized size of a Node to change
neighbor_index_pointers: (0..num_neighbors)
.map(|_| ItemPointer::new(InvalidBlockNumber, InvalidOffsetNumber))
.collect(),
heap_item_pointer,
labels: labvec.into_labels(),
labels,
}
}

pub fn new_for_full_vector(
labvec: LabeledVector,
vector: Vec<f32>,
labels: Option<Vec<u16>>,
heap_item_pointer: ItemPointer,
meta_page: &MetaPage,
) -> Self {
let pq_vector = Vec::with_capacity(0);
Self::new_internal(labvec, pq_vector, heap_item_pointer, meta_page)
Self::new_internal(vector, labels, pq_vector, heap_item_pointer, meta_page)
}
}

Expand Down
Loading

0 comments on commit b8168e0

Please sign in to comment.