fix: adding random forest classifier test which breaks #301

Merged 18 commits on Apr 12, 2024
25 changes: 23 additions & 2 deletions .circleci/config.yml
@@ -59,20 +59,41 @@ jobs:
pip install -U numpy
pip install GDAL==$GDAL_VERSION
# Install GeoWombat
- pip install .[perf,ml]
+ pip install .[perf]
- save_cache:
key: deps1-{{ .Branch }}-{{ checksum "setup.cfg" }}
paths:
- gwenv
- restore_cache:
key: deps1-{{ .Branch }}-{{ checksum "setup.cfg" }}

- run:
name: Run tests
working_directory: ~/project/tests/
command: |
. /home/circleci/project/gwenv/bin/activate
pip install testfixtures
- python -m unittest
+ python -m unittest discover -s ./tests -p 'test_*.py'
# - run:
# name: Install GeoWombat with ml addon
# command: |
# pip install .[ml]
# export GEOWOMBAT_ADDON=ml
- run:
name: Run ml tests
working_directory: ~/project/tests/
command: |
. /home/circleci/project/gwenv/bin/activate
pip install testfixtures
pip install .[ml]
# Run tests conditionally based on the geowombat-ml addon
# if [[ "$GEOWOMBAT_ADDON" == "ml" ]]; then
python -m unittest ml_test.py
# fi




# new_release:
# docker:
# - image: cimg/python:3.8.12
91 changes: 46 additions & 45 deletions .github/workflows/ci.yml
@@ -9,6 +9,52 @@ on:
branches: [ main ]

jobs:
Tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v3
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v2
with:
path: |
~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml', '**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install GDAL binaries
run: |
sudo apt update && sudo apt install -y software-properties-common
sudo add-apt-repository ppa:ubuntugis/ppa -y
sudo apt update && sudo apt install -y gdal-bin libgdal-dev libgl1 libspatialindex-dev g++ libmysqlclient-dev
echo "CPLUS_INCLUDE_PATH=/usr/include/gdal" >> $GITHUB_ENV
echo "C_INCLUDE_PATH=/usr/include/gdal" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/usr/local/lib" >> $GITHUB_ENV
- name: Install Python packages
run: |
pip install -U pip setuptools wheel
pip install numpy
GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2}')
pip install GDAL==$GDAL_VERSION --no-cache-dir
pip install arosics
- name: Install GeoWombat
run: |
pip install ".[stac,web,coreg,perf,tests]"
- name: Run Unittests
run: |
pip install testfixtures
python -m unittest discover -p 'test_*.py'
- name: Run ml Unittests
run: |
pip install testfixtures
pip install ".[ml]"
python -m unittest discover -p 'ml_*.py'
# Quality:
# runs-on: ubuntu-latest
# steps:
@@ -35,51 +81,6 @@ jobs:
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# python -m poetry run flake8 . --exclude .venv --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics

Tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v3
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: pip # caching pip dependencies based on changes to pyproject.toml
- name: Install GDAL binaries
run: |
# Temporary? dpkg fix: https://askubuntu.com/questions/1276111/error-upgrading-grub-efi-amd64-signed-special-device-old-ssd-does-not-exist
sudo rm /var/cache/debconf/config.dat
sudo dpkg --configure -a
# Install GDAL
sudo apt update && sudo apt upgrade -y && sudo apt install -y
sudo apt install software-properties-common -y
sudo add-apt-repository ppa:ubuntugis/ppa
sudo apt update -y && sudo apt install -y
sudo apt install libmysqlclient-dev default-libmysqlclient-dev -y
sudo apt install gdal-bin libgdal-dev libgl1 libspatialindex-dev g++ -y
export CPLUS_INCLUDE_PATH=/usr/include/gdal
export C_INCLUDE_PATH=/usr/include/gdal
export LD_LIBRARY_PATH=/usr/local/lib
- name: Install Python packages
run: |
# Install Python GDAL
pip install -U pip setuptools>=59.5.0 wheel
pip install numpy>=1.19.0
GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2"."$3}')
pip install GDAL==$GDAL_VERSION --no-cache-dir
pip install arosics --no-deps
- name: Install GeoWombat
run: |
# Remove compiled Cython files
rm -rf src/geowombat/moving/*.so
pip install .[stac,web,ml,coreg,stac,perf,tests]
if: steps.cache-gwenv.outputs.cache-hit != 'true'
- name: Unittests
run: |
python -m unittest

# Version:
# needs: Tests
# if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):')
1 change: 1 addition & 0 deletions setup.cfg
@@ -83,6 +83,7 @@ coreg = earthpy
zarr = zarr
numcodecs
ml = dask-ml>=2022.5.27
scikit-learn>=0.23.0,<=1.2.0
lightgbm
sklearn-xarray@git+https://github.com/jgrss/sklearn-xarray.git
perf = rtree
46 changes: 39 additions & 7 deletions tests/test_ml.py → tests/ml_test.py
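The rename matters because the CI steps above select test files by filename pattern (`test_*.py` for the core run, `ml_*.py` for the ml run), and `unittest` discovery matches file names with shell-style patterns. A minimal sketch of that split, assuming a hypothetical file listing (only `ml_test.py` is named in this diff; the other file names are placeholders):

```python
from fnmatch import fnmatch

# Hypothetical test directory contents; only ml_test.py comes from this PR.
test_files = ["test_open.py", "test_config.py", "ml_test.py"]


def select(pattern, names):
    """Return the file names that match a shell-style pattern."""
    return [name for name in names if fnmatch(name, pattern)]


# Core run: the renamed ml file no longer matches 'test_*.py'.
core_tests = select("test_*.py", test_files)

# ml run: only files starting with 'ml_' are picked up.
ml_tests = select("ml_*.py", test_files)

print("core:", core_tests)
print("ml:  ", ml_tests)
```

Under the old name, `test_ml.py` would have matched `test_*.py` and been collected in the core run, where the `ml` extras are not installed.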
@@ -5,6 +5,7 @@
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier

# from sklearn.model_selection import GridSearchCV, KFold
from sklearn.naive_bayes import GaussianNB
@@ -46,6 +47,13 @@
]
)

tree_pipeline = Pipeline(
[
("scaler", StandardScaler()),
("clf", RandomForestClassifier(random_state=0)),
]
)

cl_wo_feat = Pipeline(
[
("scaler", StandardScaler()),
@@ -62,7 +70,7 @@ def test_output_values_missing(self):
with gw.open(l8_224078_20200518, nodata=0) as src:
with warnings.catch_warnings():
warnings.simplefilter(
- 'ignore',
+ "ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
X, Xy, clf = fit(src, pl_wo_feat, aoi_poly, col="lc")
@@ -86,6 +94,30 @@ def test_output_values_missing(self):
)
)

def test_tree_predict(self):
with gw.config.update(
ref_res=300,
):
with gw.open(l8_224078_20200518, nodata=0) as src:
with warnings.catch_warnings():
warnings.simplefilter(
"ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
X, Xy, clf = fit(src, tree_pipeline, aoi_poly, col="lc")
y1 = predict(src, X, clf)
y2 = fit_predict(src, tree_pipeline, aoi_poly, col="lc")

self.assertTrue(np.all(np.isnan(y1.values[0, 0:5, 0])))
self.assertTrue(np.all(np.isnan(y2.values[0, 0:5, 0])))
self.assertTrue(
np.allclose(
y1.values[0, -5:-1, 0],
Review comment (Owner):
Curious what's going on here with the indexing?

Reply (Collaborator Author):
Just grabbing some portion of the values to compare. I guess I really don't need to do that, but oh well.
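
To make the indexing concrete, here is a small sketch of what the two slices in `test_tree_predict` select. It uses a hypothetical 3-D NumPy array as a stand-in; the real shape and dimension order of the prediction are not shown in this diff, so treat the array layout as an assumption.

```python
import numpy as np

# Hypothetical stand-in for `y1.values`: 1 x 10 x 3, with NaN in the first rows.
values = np.full((1, 10, 3), np.nan)
values[0, 5:, :] = np.arange(15, dtype=float).reshape(5, 3)

# First five entries along the second axis, at index 0 of the other two axes.
head = values[0, 0:5, 0]

# Entries -5 through -2 along the second axis: four values, because the
# slice stop (-1) is exclusive and the final entry is left out.
tail = values[0, -5:-1, 0]

print(head)  # all NaN in this stand-in
print(tail)

# NaN never compares equal to itself, so comparing NaN-containing arrays
# only succeeds when equal_nan=True is passed.
print(np.allclose(head, head))
print(np.allclose(head, head, equal_nan=True))
```

Note that `-5:-1` yields four elements rather than five; `-5:` would include the last one as well.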

y2.values[0, -5:-1, 0],
equal_nan=True,
)
)

def test_output_type_attri(self):

with gw.config.update(
@@ -94,7 +126,7 @@ def test_output_type_attri(self):
with gw.open(l8_224078_20200518, nodata=0) as src:
with warnings.catch_warnings():
warnings.simplefilter(
- 'ignore',
+ "ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
X, Xy, clf = fit(src, pl_wo_feat, aoi_poly, col="lc")
@@ -115,7 +147,7 @@ def test_fitpredict_eq_fit_predict_point(self):
with gw.open(l8_224078_20200518, nodata=0) as src:
with warnings.catch_warnings():
warnings.simplefilter(
- 'ignore',
+ "ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
X, Xy, clf = fit(src, pl_wo_feat, aoi_point, col="lc")
@@ -134,7 +166,7 @@ def test_fitpredict_time_point(self):
) as src:
with warnings.catch_warnings():
warnings.simplefilter(
- 'ignore',
+ "ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
y1 = fit_predict(
@@ -155,7 +187,7 @@ def test_fitpredict_eq_fit_predict_cluster(self):
with gw.open(l8_224078_20200518, nodata=0) as src:
with warnings.catch_warnings():
warnings.simplefilter(
- 'ignore',
+ "ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
X, Xy, clf = fit(data=src, clf=cl_wo_feat)
@@ -171,7 +203,7 @@ def test_classes_match_prediction_a(self):
with gw.open(l8_224078_20200518) as src:
with warnings.catch_warnings():
warnings.simplefilter(
- 'ignore',
+ "ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
X, Xy, clf = fit(src, pl_wo_feat, aoi_point, col="lc")
@@ -196,7 +228,7 @@ def test_classes_match_prediction_b(self):
with gw.open(l8_224078_20200518) as src:
with warnings.catch_warnings():
warnings.simplefilter(
- 'ignore',
+ "ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
X, Xy, clf = fit(src, pl_wo_feat, aoi_point, col="lc")