Skip to content

Commit

Permalink
refactor: switch to movie dataset
Browse files: browse the repository at this point in the history
  • Loading branch information
FlorianWoelki committed Nov 1, 2024
1 parent 09732c3 commit bf021e2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 55 deletions.
10 changes: 5 additions & 5 deletions scripts/splade.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def clean_text(text):
for example in tqdm(training_dataset, desc="Processing"):
text = clean_text(example['Overview'])
texts.append(text)
if len(texts) == 25000: # Temporary due to resource constraints.
break
# if len(texts) == 25000: # Temporary due to resource constraints.
# break

def sparse_vector_to_dict(vector):
indices = []
Expand Down Expand Up @@ -84,13 +84,13 @@ def compute_similarity(vec1, vec2):
top_k_indices = [idx for idx, _ in sorted(similarities, key=lambda x: x[1], reverse=True)[:k]]
groundtruth.append(top_k_indices)

with open('data.msgpack', 'wb') as f:
with open('data-50k.msgpack', 'wb') as f:
msgpack.dump(sparse_vectors, f)

with open('queries.msgpack', 'wb') as f:
with open('queries-50k.msgpack', 'wb') as f:
msgpack.dump(query_vectors, f)

with open('groundtruth.msgpack', 'wb') as f:
with open('groundtruth-50k.msgpack', 'wb') as f:
msgpack.dump(groundtruth, f)

print("\n=== Summary ===")
Expand Down
63 changes: 13 additions & 50 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use index::{
pq::PQIndex,
DistanceMetric, IndexType, SparseIndex,
};
use ordered_float::OrderedFloat;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};

mod benchmark;
Expand All @@ -62,55 +61,19 @@ struct Args {
}

#[allow(dead_code)]
async fn plot_msmarco_dataset() {
let (groundtruth, vectors, query_vectors) = data::ms_marco::load_msmarco_dataset().unwrap();
let mut query_sparse_vectors = vec![];
let mut groundtruth_sparse_vectors = vec![];

// TODO: Parallelize this in the future.
for (indices, values) in groundtruth.0.iter().zip(groundtruth.1.iter()) {
let vector = SparseVector {
indices: indices.iter().map(|i| *i as usize).collect(),
values: values.iter().map(|v| OrderedFloat(*v)).collect(),
};
groundtruth_sparse_vectors.push(vector);
}
for (indices, values) in query_vectors.iter() {
let sparse_vector = SparseVector {
indices: indices.clone(),
values: values
.iter()
.map(|&v| ordered_float::OrderedFloat(v))
.collect(),
};
query_sparse_vectors.push(sparse_vector);
}
async fn plot_movie_dataset() {
let vectors = read_sparse_vectors("./scripts/data-50k.msgpack").unwrap();
let query_vectors = read_sparse_vectors("./scripts/queries-50k.msgpack").unwrap();
let groundtruth = read_groundtruth("./scripts/groundtruth-50k.msgpack").unwrap();

let vectors_sparse_vectors = vectors
.par_iter()
.map(|(indices, values)| {
let sparse_vector = SparseVector {
indices: indices.clone(),
values: values
.iter()
.map(|&v| ordered_float::OrderedFloat(v))
.collect(),
};
sparse_vector
})
let groundtruth_flat = groundtruth
.iter()
.map(|nn| vectors[nn[0]].clone())
.collect::<Vec<SparseVector>>();

plot_sparsity_distribution(
&vectors_sparse_vectors,
format!("amount: {}", vectors_sparse_vectors.len()).as_str(),
)
.show();
plot_nearest_neighbor_distances(
&query_sparse_vectors,
&groundtruth_sparse_vectors,
&DistanceMetric::Cosine,
)
.show();
plot_sparsity_distribution(&vectors, format!("amount: {}", vectors.len()).as_str()).show();
plot_nearest_neighbor_distances(&query_vectors, &groundtruth_flat, &DistanceMetric::Jaccard)
.show();
}

#[allow(dead_code)]
Expand Down Expand Up @@ -250,9 +213,9 @@ async fn main() {
fs::create_dir_all(&dir_path).expect("Failed to create directory");

if dataset_type == "real" {
let vectors = read_sparse_vectors("./scripts/data.msgpack").unwrap();
let query_vectors = read_sparse_vectors("./scripts/queries.msgpack").unwrap();
let groundtruth = read_groundtruth("./scripts/groundtruth.msgpack").unwrap();
let vectors = read_sparse_vectors("./scripts/data-50k.msgpack").unwrap();
let query_vectors = read_sparse_vectors("./scripts/queries-50k.msgpack").unwrap();
let groundtruth = read_groundtruth("./scripts/groundtruth-50k.msgpack").unwrap();

let seed = 42;

Expand Down

0 comments on commit bf021e2

Please sign in to comment.