Consider the following VAE example with MNIST data:
```python
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import inferpy as inf

# number of components
k = 2
# size of the hidden layer in the NN
d0 = 100
# dimensionality of the data
dx = 28 * 28
# number of observations (dataset size)
N = 1000
# batch size
M = 100
# digits considered
DIG = [0, 1, 2]
# minimum scale
scale_epsilon = 0.01
# inference parameters
num_epochs = 1000
learning_rate = 0.01

# reset tensorflow
tf.reset_default_graph()
tf.set_random_seed(1234)

from inferpy.data import mnist

# load the data
(x_train, y_train), _ = mnist.load_data(num_instances=N, digits=DIG)
mnist.plot_digits(x_train, grid=[5, 5])

############## Inferpy ##############

@inf.probmodel
def vae(k, d0, dx, decoder):
    with inf.datamodel():
        z = inf.Normal(tf.ones(k) * 0.5, 1, name="z")  # shape = [N,k]
        output = decoder(z, d0, dx)
        x_loc = output[:, :dx]
        x_scale = tf.nn.softmax(output[:, dx:]) + scale_epsilon
        x = inf.Normal(x_loc, x_scale, name="x")  # shape = [N,d]

# Neural networks for decoding and encoding
def decoder(z, d0, dx):  # k -> d0 -> 2*dx
    h0 = tf.layers.dense(z, d0, tf.nn.relu)
    return tf.layers.dense(h0, 2 * dx)

def encoder(x, d0, k):  # dx -> d0 -> 2*k
    h0 = tf.layers.dense(x, d0, tf.nn.relu)
    return tf.layers.dense(h0, 2 * k)

# Q model for making inference
@inf.probmodel
def qmodel(k, d0, dx, encoder):
    with inf.datamodel():
        x = inf.Normal(tf.ones(dx) * 0.5, 1, name="x")
        output = encoder(x, d0, k)
        qz_loc = output[:, :k]
        qz_scale = tf.nn.softmax(output[:, k:]) + scale_epsilon
        qz = inf.Normal(qz_loc, qz_scale, name="z")

# Inference
############################

m = vae(k, d0, dx, decoder)
q = qmodel(k, d0, dx, encoder)
```
The code for the posterior queries differs depending on whether VI or SVI is used:
```python
# with VI
VI = inf.inference.VI(q, epochs=100)
m.fit({"x": x_train}, VI)
postz = m.posterior("z", data={"x": x_train}).sample()

# with SVI
SVI = inf.inference.SVI(q, epochs=100, batch_size=M)
m.fit({"x": x_train}, SVI)
postz = np.concatenate([
    m.posterior("z", data={"x": x_train[i:i+M, :]}).sample()
    for i in range(0, N, M)])
```
Note that in the SVI case, the observed data passed to the posterior query must have the size of a batch rather than the size of the whole dataset. This forces the user to know the batch size, which they may not, since `batch_size` is an optional parameter.
We could instead split the data internally (as is already done for training) and concatenate the results, so that the user does not need to know the batch size. We would then also need to decide what happens with data smaller than the batch.
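A minimal sketch of that idea, written against plain NumPy so it runs without InferPy. `batched_posterior_sample` and `fake_posterior_sample` are hypothetical names, not part of the InferPy API; `query_fn` stands in for something like `lambda xb: m.posterior("z", data={"x": xb}).sample()` and is assumed to accept exactly `batch_size` rows and return one output row per input row. One possible answer to the "smaller than the batch" question is to pad the last chunk up to the batch size and discard the outputs for the padded rows.

```python
import numpy as np


def batched_posterior_sample(query_fn, data, batch_size):
    """Apply a fixed-batch-size query to data of any length (hypothetical helper).

    Assumes query_fn accepts exactly `batch_size` rows and returns one
    output row per input row, independently of the other rows.
    """
    data = np.asarray(data)
    n = data.shape[0]
    if n == 0:
        raise ValueError("data must contain at least one row")
    results = []
    for start in range(0, n, batch_size):
        chunk = data[start:start + batch_size]
        n_valid = chunk.shape[0]
        if n_valid < batch_size:
            # Pad by repeating the last row so the chunk matches the batch shape
            pad = np.repeat(chunk[-1:], batch_size - n_valid, axis=0)
            chunk = np.concatenate([chunk, pad], axis=0)
        out = np.asarray(query_fn(chunk))
        results.append(out[:n_valid])  # keep only outputs for real rows
    return np.concatenate(results, axis=0)


BATCH = 100


def fake_posterior_sample(xb):
    # Hypothetical stand-in for a model built with a fixed batch dimension:
    # it rejects any input that does not have exactly BATCH rows.
    if xb.shape[0] != BATCH:
        raise ValueError(f"expected {BATCH} rows, got {xb.shape[0]}")
    return xb[:, :2] * 2.0


x = np.random.rand(250, 4)  # 250 is not a multiple of BATCH
z = batched_posterior_sample(fake_posterior_sample, x, BATCH)
print(z.shape)  # (250, 2)
```

Padding is only safe if each output row depends solely on its own input row, as with an amortized encoder under `inf.datamodel()`; this is an assumption about the model, and a query that mixes information across rows would need a different strategy.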