-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
88e4f99
commit ea44f83
Showing
20 changed files
with
1,210 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
/data | ||
/results | ||
/venv | ||
/models | ||
/models | ||
/data/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import os.path | ||
import os.path | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from joblib import Parallel, delayed | ||
from sklearn.metrics import calinski_harabasz_score, silhouette_score | ||
from sklearn.mixture import GaussianMixture | ||
from sklearn.cluster import MiniBatchKMeans | ||
from tqdm import tqdm | ||
from lib.utils import get_index | ||
|
||
import lib.globals | ||
import lib.globals | ||
|
||
|
||
def _plot_kmeans_scores(X, min, max, step): | ||
""" | ||
Calculates scores for multiple values of kmeans | ||
Args: | ||
X (np.ndarray) | ||
min (int) | ||
max (int) | ||
step (int) | ||
""" | ||
rng = list(range(min, max, step)) | ||
|
||
def process(n): | ||
clf = GaussianMixture(n_components = n, random_state = 42) | ||
# clf = MiniBatchKMeans(n_clusters=n, random_state=42) | ||
labels = clf.fit_predict(X) | ||
|
||
s = silhouette_score(X, labels) | ||
c = calinski_harabasz_score(X, labels) | ||
b = clf.bic(X) | ||
|
||
return s, c, b | ||
|
||
n_jobs = len(rng) | ||
results = Parallel(n_jobs=n_jobs)(delayed(process)(i) for i in tqdm(rng)) | ||
results = np.column_stack(results).T | ||
|
||
fig, ax = plt.subplots(nrows=3) | ||
ax[0].plot(rng, results[:, 0], "o-", color="blue", label="Silhouette score") | ||
ax[1].plot(rng, results[:, 1], "o-", color="orange", label="CH score") | ||
ax[2].plot(rng, results[:, 2], "o-", color="red", label="BIC") | ||
|
||
for a in ax: | ||
a.legend(loc="upper right") | ||
|
||
plt.tight_layout() | ||
plt.savefig("plots/best_k.pdf") | ||
plt.show() | ||
|
||
|
||
def main(encodings_name): | ||
f = np.load( | ||
os.path.join(lib.globals.encodings_dir, encodings_name), | ||
allow_pickle=True, | ||
) | ||
|
||
X, encodings = f["X_true"], f["features"] | ||
|
||
arr_lens = np.array([len(xi) for xi in X]) | ||
(len_above_idx,) = np.where(arr_lens >= 30) | ||
X, encodings, = get_index((X, encodings), index=len_above_idx) | ||
|
||
_plot_kmeans_scores(encodings, min=2, max=100, step=3) | ||
|
||
|
||
if __name__ == "__main__": | ||
NAME = "20200124-0206_lstm_vae_bidir_data=combined_filt5_var.npz_dim=128_act=None_bat=4_eps=0.1_zdim=8_anneal=20___pred__combined_filt20_var.npz" | ||
main(NAME) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
import keras | ||
from keras.layers import Activation, Dense, Input | ||
from keras.layers import Conv2D, Flatten | ||
from keras.layers import Reshape, Conv2DTranspose | ||
from keras.models import Model | ||
from keras import backend as K | ||
from keras.datasets import mnist | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from PIL import Image | ||
|
||
np.random.seed(1337) | ||
|
||
# MNIST dataset | ||
(x_train, _), (x_test, _) = mnist.load_data() | ||
|
||
image_size = x_train.shape[1] | ||
x_train = np.reshape(x_train, [-1, image_size, image_size, 1]) | ||
x_test = np.reshape(x_test, [-1, image_size, image_size, 1]) | ||
x_train = x_train.astype('float32') / 255 | ||
x_test = x_test.astype('float32') / 255 | ||
|
||
# Generate corrupted MNIST images by adding noise with normal dist | ||
# centered at 0.5 and std=0.5 | ||
noise = np.random.normal(loc=0.5, scale=0.5, size=x_train.shape) | ||
x_train_noisy = x_train + noise | ||
noise = np.random.normal(loc=0.5, scale=0.5, size=x_test.shape) | ||
x_test_noisy = x_test + noise | ||
|
||
x_train_noisy = np.clip(x_train_noisy, 0., 1.) | ||
x_test_noisy = np.clip(x_test_noisy, 0., 1.) | ||
|
||
# Network parameters | ||
input_shape = (image_size, image_size, 1) | ||
batch_size = 128 | ||
kernel_size = 3 | ||
latent_dim = 16 | ||
# Encoder/Decoder number of CNN layers and filters per layer | ||
layer_filters = [32, 64] | ||
|
||
# Build the Autoencoder Model | ||
# First build the Encoder Model | ||
inputs = Input(shape=input_shape, name='encoder_input') | ||
x = inputs | ||
# Stack of Conv2D blocks | ||
# Notes: | ||
# 1) Use Batch Normalization before ReLU on deep networks | ||
# 2) Use MaxPooling2D as alternative to strides>1 | ||
# - faster but not as good as strides>1 | ||
for filters in layer_filters: | ||
x = Conv2D(filters=filters, | ||
kernel_size=kernel_size, | ||
strides=2, | ||
activation='relu', | ||
padding='same')(x) | ||
|
||
# Shape info needed to build Decoder Model | ||
shape = K.int_shape(x) | ||
|
||
# Generate the latent vector | ||
x = Flatten()(x) | ||
latent = Dense(latent_dim, name='latent_vector')(x) | ||
|
||
# Instantiate Encoder Model | ||
encoder = Model(inputs, latent, name='encoder') | ||
encoder.summary() | ||
|
||
# Build the Decoder Model | ||
latent_inputs = Input(shape=(latent_dim,), name='decoder_input') | ||
x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs) | ||
x = Reshape((shape[1], shape[2], shape[3]))(x) | ||
|
||
# Stack of Transposed Conv2D blocks | ||
# Notes: | ||
# 1) Use Batch Normalization before ReLU on deep networks | ||
# 2) Use UpSampling2D as alternative to strides>1 | ||
# - faster but not as good as strides>1 | ||
for filters in layer_filters[::-1]: | ||
x = Conv2DTranspose(filters=filters, | ||
kernel_size=kernel_size, | ||
strides=2, | ||
activation='relu', | ||
padding='same')(x) | ||
|
||
x = Conv2DTranspose(filters=1, | ||
kernel_size=kernel_size, | ||
padding='same')(x) | ||
|
||
outputs = Activation('sigmoid', name='decoder_output')(x) | ||
|
||
# Instantiate Decoder Model | ||
decoder = Model(latent_inputs, outputs, name='decoder') | ||
decoder.summary() | ||
|
||
# Autoencoder = Encoder + Decoder | ||
# Instantiate Autoencoder Model | ||
autoencoder = Model(inputs, decoder(encoder(inputs)), name='autoencoder') | ||
autoencoder.summary() | ||
|
||
autoencoder.compile(loss='mse', optimizer='adam') | ||
|
||
# Train the autoencoder | ||
autoencoder.fit(x_train_noisy, | ||
x_train, | ||
validation_data=(x_test_noisy, x_test), | ||
epochs=30, | ||
batch_size=batch_size) | ||
|
||
# Predict the Autoencoder output from corrupted test images | ||
x_decoded = autoencoder.predict(x_test_noisy) | ||
|
||
# Display the 1st 8 corrupted and denoised images | ||
rows, cols = 10, 30 | ||
num = rows * cols | ||
imgs = np.concatenate([x_test[:num], x_test_noisy[:num], x_decoded[:num]]) | ||
imgs = imgs.reshape((rows * 3, cols, image_size, image_size)) | ||
imgs = np.vstack(np.split(imgs, rows, axis=1)) | ||
imgs = imgs.reshape((rows * 3, -1, image_size, image_size)) | ||
imgs = np.vstack([np.hstack(i) for i in imgs]) | ||
imgs = (imgs * 255).astype(np.uint8) | ||
plt.figure() | ||
plt.axis('off') | ||
plt.title('Original images: top rows, ' | ||
'Corrupted Input: middle rows, ' | ||
'Denoised Input: third rows') | ||
plt.imshow(imgs, interpolation='none', cmap='gray') | ||
Image.fromarray(imgs).save('corrupted_and_denoised.png') | ||
plt.show() |
Oops, something went wrong.