Skip to content

Commit

Permalink
feat: add count parameter to sparse generator
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianWoelki committed Nov 22, 2024
1 parent f2e7f5b commit fcb0b91
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 48 deletions.
68 changes: 59 additions & 9 deletions src/data/generator_sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct SparseDataGenerator {
metric: DistanceMetric,
system: sysinfo::System,
seed: u64,
query_count: Option<usize>,

pub vectors: Vec<SparseVector>,
pub query_vectors: Vec<SparseVector>,
Expand All @@ -43,6 +44,7 @@ impl SparseDataGenerator {
/// * `range` - The range of values for non-zero elements.
/// * `sparsity` - The probability of an element being zero (sparsity factor). The higher the value, the sparser the data.
/// * `metric` - The DistanceMetric that will be used to fetch the groundtruth vectors.
/// * `query_count` - Optional number of query vectors to generate. Defaults to 10% of count if None.
///
/// # Returns
///
Expand All @@ -54,6 +56,7 @@ impl SparseDataGenerator {
sparsity: f32,
metric: DistanceMetric,
seed: u64,
query_count: Option<usize>,
) -> Self {
SparseDataGenerator {
dim,
Expand All @@ -63,6 +66,7 @@ impl SparseDataGenerator {
metric,
system: sysinfo::System::new(),
seed,
query_count,
vectors: vec![],
query_vectors: vec![],
groundtruth: vec![],
Expand Down Expand Up @@ -116,8 +120,8 @@ impl SparseDataGenerator {
results.append(&mut result);
}

let query_vectors =
self.generate_vectors(self.dim, self.count / 10, self.range, self.sparsity);
let query_count = self.query_count.unwrap_or(self.count / 10);
let query_vectors = self.generate_vectors(self.dim, query_count, self.range, self.sparsity);

let mut groundtruth_vectors = Vec::with_capacity(query_vectors.len());
for query_vector in &query_vectors {
Expand Down Expand Up @@ -221,8 +225,15 @@ mod tests {
let dim = 100;
let range = (0.0, 1.0);
let sparsity = 0.5;
let mut generator =
SparseDataGenerator::new(dim, count, range, sparsity, DistanceMetric::Euclidean, seed);
let mut generator = SparseDataGenerator::new(
dim,
count,
range,
sparsity,
DistanceMetric::Euclidean,
seed,
None,
);
generator.generate().await;

assert_eq!(generator.vectors.len(), count);
Expand Down Expand Up @@ -266,15 +277,47 @@ mod tests {
}
}

#[tokio::test]
async fn test_custom_query_count() {
let seed = 42;
let count = 100;
let dim = 100;
let range = (0.0, 1.0);
let sparsity = 0.5;
let custom_query_count = 5;

let mut generator = SparseDataGenerator::new(
dim,
count,
range,
sparsity,
DistanceMetric::Euclidean,
seed,
Some(custom_query_count),
);
generator.generate().await;

assert_eq!(generator.vectors.len(), count);
assert_eq!(generator.query_vectors.len(), custom_query_count);
assert_eq!(generator.groundtruth.len(), generator.query_vectors.len());
}

#[tokio::test]
async fn test_sparsity() {
let seed = 42;
let count = 10;
let dim = 100;
let range = (0.0, 1.0);
let sparsity = 0.8;
let mut generator =
SparseDataGenerator::new(dim, count, range, sparsity, DistanceMetric::Euclidean, seed);
let mut generator = SparseDataGenerator::new(
dim,
count,
range,
sparsity,
DistanceMetric::Euclidean,
seed,
None,
);
generator.generate().await;

assert_eq!(generator.vectors.len(), count);
Expand All @@ -301,8 +344,15 @@ mod tests {
let range = (0.0, 1.0);
let sparsity = 0.5;
let k = 10;
let mut generator =
SparseDataGenerator::new(dim, count, range, sparsity, DistanceMetric::Euclidean, seed);
let mut generator = SparseDataGenerator::new(
dim,
count,
range,
sparsity,
DistanceMetric::Euclidean,
seed,
None,
);
generator.generate().await;

assert_eq!(generator.query_vectors.len(), count / 10);
Expand Down Expand Up @@ -347,7 +397,7 @@ mod tests {
let expected_groundtruth = vec![0, 1, 2];

let generator =
SparseDataGenerator::new(0, 0, (0.0, 1.0), 0.0, DistanceMetric::Euclidean, seed);
SparseDataGenerator::new(0, 0, (0.0, 1.0), 0.0, DistanceMetric::Euclidean, seed, None);
let groundtruth_vectors = generator.find_nearest_neighbors(&vectors, &query_vector, 3);

assert_eq!(groundtruth_vectors, expected_groundtruth);
Expand Down
36 changes: 0 additions & 36 deletions src/data/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,42 +47,6 @@ impl SparseVector {
result
}

pub fn squared_distance(&self, other: &SparseVector) -> f32 {
let dot_product = self.dot(other);
let self_norm = self.dot(self);
let other_norm = other.dot(other);
self_norm + other_norm - 2.0 * dot_product
}

pub fn add_scaled(&mut self, other: &SparseVector, scale: f32) {
let mut i = 0;
let mut j = 0;
let mut new_indices = Vec::new();
let mut new_values = Vec::new();

while i < self.indices.len() || j < other.indices.len() {
if j == other.indices.len()
|| (i < self.indices.len() && self.indices[i] < other.indices[j])
{
new_indices.push(self.indices[i]);
new_values.push(self.values[i]);
i += 1;
} else if i == self.indices.len() || self.indices[i] > other.indices[j] {
new_indices.push(other.indices[j]);
new_values.push(OrderedFloat(other.values[j].0 * scale));
j += 1;
} else {
new_indices.push(self.indices[i]);
new_values.push(OrderedFloat(self.values[i].0 + other.values[j].0 * scale));
i += 1;
j += 1;
}
}

self.indices = new_indices;
self.values = new_values;
}

/// Jaccard distance is based on the Jaccard similarity coefficient, which
/// measures the overlap between two sets. The Jaccard distance is 1 minus
/// the Jaccard similarity.
Expand Down
15 changes: 12 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,15 @@ async fn plot_artificially_generated_data() {
// - High-dimensional sparse data: Manhattan distance
let amount = 1000;
let dim = 10000;
let mut data_generator =
SparseDataGenerator::new(dim, amount, (0.0, 1.0), 0.95, DistanceMetric::Cosine, 42);
let mut data_generator = SparseDataGenerator::new(
dim,
amount,
(0.0, 1.0),
0.95,
DistanceMetric::Cosine,
42,
None,
);
data_generator.generate().await;

// Get the first element of the groundtruth data
Expand Down Expand Up @@ -500,13 +507,14 @@ async fn main() {
0.96,
distance_metric,
seed,
None,
); // Dummy values
let (vectors, query_vectors, groundtruth) = data_generator
.load_data(dataset_path)
.expect("Failed to load dataset");
println!("...finished loading data");

let timeout = Duration::from_secs(5 * 60);
// let timeout = Duration::from_secs(5 * 60);
// let transformed_result = if let Some(reduction_technique) = &args.reduction_technique {
// match reduction_technique.as_str() {
// "pca" => execute_with_timeout(
Expand Down Expand Up @@ -787,6 +795,7 @@ async fn generate_data(
config.sparsity,
config.distance_metric,
seed,
None,
);
generator.generate().await;
generator
Expand Down

0 comments on commit fcb0b91

Please sign in to comment.