Merge pull request #30 from mmann1123/dim_reduct
add PCA multiband
mmann1123 authored Nov 25, 2024
2 parents 232a9c3 + c9d40f2 commit 618f17d
Showing 5 changed files with 162 additions and 47 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -23,6 +23,13 @@ Example outputs:

![png](examples/output_8_0.png)


The package can also calculate principal components (PCA) from stacks of raster data and generate new rasters with the components as bands, as sketched below.

Example outputs of PCA components for African precipitation (ppt) data:

![Principal component outputs](examples/pcas.png)
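
A minimal sketch of that workflow, adapted from the `k_pca` docstring in `xr_fresh/dimension_reduction.py` (the raster paths and parameter values below are placeholders):

```
import geowombat as gw
import ray

# Importing the module registers the `gw_ext` accessor on DataArrays
from xr_fresh.dimension_reduction import ExtendedGeoWombatAccessor

ray.init(num_cpus=8)

# Stack a series of single-band rasters along the band dimension
with gw.open(
    ["ppt_2020.tif", "ppt_2021.tif", "ppt_2022.tif"],  # placeholder paths
    stack_dim="band",
    band_names=[0, 1, 2],
) as src:
    # Keep the first two kernel PCA components as bands
    components = src.gw_ext.k_pca(
        gamma=15, n_components=2, n_workers=8, chunk_size=256
    )
    components.plot.imshow(col="component", col_wrap=1)

ray.shutdown()
```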

# Install

To install xr_fresh, you can use pip. However, since xr_fresh includes a C++ extension module, it requires compilation during the installation process. Here are the steps to install xr_fresh:
@@ -47,6 +54,12 @@ pip install .
```
Note: If you run into problems related to `rle`, try running `python setup.py build_ext --inplace` from the `xr_fresh` directory.

To run PCA, you must also install `ray`.

```
conda install -c conda-forge "ray-default"
```
Note: `ray` is still in beta on Windows and will not be installed by default. Read more about installation [here](https://docs.ray.io/en/latest/ray-overview/installation.html).
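
A quick sanity check that `ray` is importable before running PCA (optional):

```
python -c "import ray; ray.init(); ray.shutdown()"
```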

## Example

Binary file added examples/pcas.png
2 changes: 1 addition & 1 deletion tests/test_backends.py
@@ -7,7 +7,7 @@
 class TestCluster(unittest.TestCase):
 
     def setUp(self):
-        # Common setup for each test
+        # Common setup for each test case
         self.cluster = Cluster()
 
     def tearDown(self):
104 changes: 104 additions & 0 deletions tests/test_dimension_reduction.py
@@ -0,0 +1,104 @@
import unittest
import geowombat as gw
from xr_fresh.dimension_reduction import ExtendedGeoWombatAccessor
import ray
import numpy as np


class TestDimensionReduction(unittest.TestCase):

    def setUp(self):
        # Initialize Ray
        ray.init(num_cpus=8)

    def tearDown(self):
        # Shutdown Ray
        ray.shutdown()

    def test_k_pca(self):
        # Example usage
        with gw.open(
            sorted(
                [
                    "./tests/data/RadT_tavg_202301.tif",
                    "./tests/data/RadT_tavg_202302.tif",
                    "./tests/data/RadT_tavg_202304.tif",
                    "./tests/data/RadT_tavg_202305.tif",
                ]
            ),
            stack_dim="band",
            band_names=[0, 1, 2, 3],
        ) as src:
            # get third k principal components - base zero counting
            transformed_dataarray = src.gw_ext.k_pca(
                gamma=15, n_components=3, n_workers=8, chunk_size=256
            )

            # Check the shape of the transformed data
            self.assertEqual(transformed_dataarray.shape, (src.y.size, src.x.size, 3))

            # Check the attributes of the transformed data
            self.assertEqual(transformed_dataarray.attrs["crs"], src.attrs["crs"])
            self.assertEqual(
                transformed_dataarray.attrs["transform"], src.attrs["transform"]
            )

    def test_k_pca_invalid_gamma(self):
        with gw.open(
            sorted(
                [
                    "./tests/data/RadT_tavg_202301.tif",
                    "./tests/data/RadT_tavg_202302.tif",
                    "./tests/data/RadT_tavg_202304.tif",
                    "./tests/data/RadT_tavg_202305.tif",
                ]
            ),
            stack_dim="band",
            band_names=[0, 1, 2, 3],
        ) as src:
            gamma = -1
            n_components = 3
            n_workers = 8
            chunk_size = 256
            with self.assertRaises(ValueError):
                src.gw_ext.k_pca(
                    gamma=gamma,
                    n_components=n_components,
                    n_workers=n_workers,
                    chunk_size=chunk_size,
                )

    def test_k_pca_no_equal_components(self):
        with gw.open(
            sorted(
                [
                    "./tests/data/RadT_tavg_202301.tif",
                    "./tests/data/RadT_tavg_202302.tif",
                    "./tests/data/RadT_tavg_202304.tif",
                    "./tests/data/RadT_tavg_202305.tif",
                ]
            ),
            stack_dim="band",
            band_names=[0, 1, 2, 3],
        ) as src:
            gamma = 15
            n_components = 3
            n_workers = 8
            chunk_size = 256
            transformed_dataarray = src.gw_ext.k_pca(
                gamma=gamma,
                n_components=n_components,
                n_workers=n_workers,
                chunk_size=chunk_size,
            )

            for comp in transformed_dataarray.component.values:
                component_data = transformed_dataarray.sel(component=comp).values
                unique_values = np.unique(component_data)
                self.assertGreater(
                    len(unique_values), 1, f"Component {comp} has all equal values"
                )


if __name__ == "__main__":
    unittest.main()
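
Because the module ends with a `unittest.main()` guard, the new tests can be run directly from the repository root (assuming the `tests/data` rasters are present):

```
python tests/test_dimension_reduction.py
```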
90 changes: 44 additions & 46 deletions xr_fresh/dimension_reduction.py
@@ -18,7 +18,7 @@ class ExtendedGeoWombatAccessor(GeoWombatAccessor):
     def k_pca(
         self,
         gamma: float,
-        n_component: int,
+        n_components: int,
         n_workers: int,
         chunk_size: int,
     ) -> xr.DataArray:
@@ -27,18 +27,18 @@ def k_pca(
         Args:
             gamma (float): The gamma parameter for the RBF kernel.
-            n_component (int): The number of component that will be kept
+            n_components (int): The number of components to keep.
             n_workers (int): The number of parallel jobs for KernelPCA and ParallelTask.
             chunk_size (int): The size of the chunks for processing.
 
         Returns:
             xr.DataArray: A DataArray with the Kernel PCA components as bands.
 
         Examples:
             # Initialize Ray
             with ray.init(num_cpus=8) as rays:
 
             # Example usage
             with gw.open(
                 sorted(
@@ -54,52 +54,63 @@ def k_pca(
             ) as src:
                 # get third k principal components - base zero counting
                 transformed_dataarray = src.gw_ext.k_pca(
-                    gamma=15, n_component=3, n_workers=8, chunk_size=256
+                    gamma=15, n_components=3, n_workers=8, chunk_size=256
                 )
-                transformed_dataarray.sel(component=3).plot()
+                transformed_dataarray.plot.imshow(col='component', col_wrap=1, figsize=(8, 12))
                 plt.show()
         """
 
-        # Transpose data to have shape (height, width, num_features)
-        data = self._obj.transpose("y", "x", "band").values
-        height, width, num_features = data.shape
+        # Transpose data to have shape (num_features, height, width)
+        data = self._obj.transpose("band", "y", "x").values
+        num_features, height, width = data.shape
 
-        # Reshape data to 2D array
-        transposed_data = data.reshape(-1, num_features)
+        # Reshape data to 2D array (pixels, features)
+        transposed_data = data.reshape(num_features, -1).T
 
         # Drop rows with NaNs
-        transposed_data = transposed_data[~np.isnan(transposed_data).any(axis=1)]
+        valid_indices = ~np.isnan(transposed_data).any(axis=1)
+        transposed_data_valid = transposed_data[valid_indices]
 
         # Sample data for fitting Kernel PCA
-        num_samples = 10000
+        num_samples = min(10000, transposed_data_valid.shape[0])
         np.random.seed(42)  # For reproducibility
-        sampled_features = transposed_data[
-            np.random.choice(transposed_data.shape[0], num_samples, replace=False)
-        ]
+        sampled_indices = np.random.choice(
+            transposed_data_valid.shape[0], num_samples, replace=False
+        )
+        sampled_features = transposed_data_valid[sampled_indices]
 
         # Fit Kernel PCA on the sampled features
         kpca = KernelPCA(
-            kernel="rbf", gamma=gamma, n_components=n_component + 1, n_jobs=n_workers
+            kernel="rbf", gamma=gamma, n_components=n_components, n_jobs=n_workers
         )
         kpca.fit(sampled_features)
 
         # Extract necessary attributes from kpca for transformation
         X_fit_ = kpca.X_fit_
-        eigenvectors = kpca.eigenvectors_[:, n_component - 1]
-        eigenvalues = kpca.eigenvalues_[n_component - 1]
+        eigenvectors = kpca.eigenvectors_
+        eigenvalues = kpca.eigenvalues_
 
         @numba.jit(nopython=True, parallel=True)
         def transform_entire_dataset_numba(
-            data, X_fit_, eigenvector, eigenvalue, gamma
+            data, X_fit_, eigenvectors, eigenvalues, gamma
         ):
-            height, width = data.shape[1], data.shape[2]
-            transformed_data = np.zeros((height, width))
+            num_features, height, width = data.shape
+            n_components = eigenvectors.shape[1]
+            transformed_data = np.zeros((height, width, n_components))
 
             for i in numba.prange(height):
                 for j in range(width):
                     feature_vector = data[:, i, j]
+                    if np.isnan(feature_vector).any():
+                        transformed_data[i, j, :] = np.nan
+                        continue
                     k = np.exp(-gamma * np.sum((feature_vector - X_fit_) ** 2, axis=1))
-                    transformed_feature = np.dot(k, eigenvector / np.sqrt(eigenvalue))
-                    transformed_data[i, j] = transformed_feature
+                    for c in range(n_components):
+                        transformed_feature = np.dot(
+                            k, eigenvectors[:, c] / np.sqrt(eigenvalues[c])
+                        )
+                        transformed_data[i, j, c] = transformed_feature
 
             return transformed_data
 
Expand All @@ -109,21 +120,18 @@ def process_window(
data_slice,
window_id,
X_fit_,
eigenvector,
eigenvalue,
eigenvectors,
eigenvalues,
gamma,
num_workers=n_workers,
):
data_chunk = data_block_id[
data_slice
].data.compute() # Convert Dask array to NumPy array
data_chunk = data_block_id[data_slice].data.compute()
return transform_entire_dataset_numba(
data_chunk, X_fit_, eigenvector, eigenvalue, gamma
data_chunk, X_fit_, eigenvectors, eigenvalues, gamma
)

# Perform transformation in parallel
pt = ParallelTask(
self._obj,
self._obj.transpose("band", "y", "x"),
row_chunks=chunk_size,
col_chunks=chunk_size,
scheduler="ray",
@@ -134,19 +142,13 @@
         futures = pt.map(process_window, X_fit_, eigenvectors, eigenvalues, gamma)
 
         # Combine the results
-        transformed_data = np.zeros((height, width, 1), dtype=np.float64)
-
-        # Combine the results
-        transformed_data = np.zeros((height, width))
-        for window_id, future in enumerate(ray.get(futures)):
+        transformed_data = np.zeros((height, width, n_components), dtype=np.float64)
+        results = ray.get(futures)
+        for window_id, future in enumerate(results):
             window = pt.windows[window_id]
             row_start, col_start = window.row_off, window.col_off
             row_end, col_end = row_start + window.height, col_start + window.width
-            transformed_data[row_start:row_end, col_start:col_end] = future
-
-        # extend dimension of transformed_data
-        if len(transformed_data.shape) == 2:
-            transformed_data = np.expand_dims(transformed_data, axis=2)
+            transformed_data[row_start:row_end, col_start:col_end, :] = future
 
         # Create a new DataArray with the transformed data
         transformed_dataarray = xr.DataArray(
@@ -155,9 +157,7 @@
             coords={
                 "y": self._obj.y,
                 "x": self._obj.x,
-                "component": [
-                    n_component
-                ],  # [f"component_{i+1}" for i in range(n_components)],
+                "component": [f"component_{i+1}" for i in range(n_components)],
             },
             attrs=self._obj.attrs,
         )
Expand All @@ -167,5 +167,3 @@ def process_window(

# Register the new accessor
xr.register_dataarray_accessor("gw_ext")(ExtendedGeoWombatAccessor)

# %%
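
For reference, the per-pixel projection implemented in `transform_entire_dataset_numba` is the standard kernel PCA out-of-sample formula with an RBF kernel (a sketch of the math as the code applies it; note that scikit-learn's `KernelPCA.transform` additionally centers the kernel vector against the training kernel matrix, a step this routine omits):

```
y_c(x) = \sum_{i=1}^{m} \frac{\alpha_{ic}}{\sqrt{\lambda_c}} \, k(x, x_i),
\qquad
k(x, x_i) = \exp\left( -\gamma \, \lVert x - x_i \rVert^2 \right)
```

where `x` is a pixel's band vector, the `x_i` are the `m` sampled training pixels stored in `kpca.X_fit_`, and `alpha_{ic}`, `lambda_c` are the c-th eigenvector and eigenvalue of the fitted kernel matrix (`kpca.eigenvectors_`, `kpca.eigenvalues_`).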
