MNIST and Fashion MNIST!
Pre-release
Summary
Reproducible configs for MNIST and Fashion MNIST, together with their trained checkpoints. The checkpoint files (model weights) are attached as release assets.
Usage example
First, download the checkpoint files.
from PIL import Image
import jax
from ml_collections import FrozenConfigDict
from jax_diffusion.configs.mnist import get_config as mnist_config
from jax_diffusion.configs.fashion_mnist import get_config as fmnist_config
from jax_diffusion.train.trainer import Trainer
from jax_diffusion.utils.image import image_grid
# Uncomment exactly one of the following configs:
# config = mnist_config()
# config = fmnist_config()
config.restore = "<path-to-checkpoint-file-or-directory>"
config = FrozenConfigDict(config)  # freeze the config so it is hashable
trainer = Trainer(jax.random.PRNGKey(0), **config.experiment_kwargs)
trainer.restore_checkpoint(config.restore)
img = trainer.sample(jax.random.PRNGKey(1), num=25, steps=100)  # sample a batch of 25 images
img = image_grid(img)[:, :, 0]  # tile the samples into a grid and drop the channel axis
img = Image.fromarray(img)
img.save("sample.png", dpi=(144, 144))
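One practical note: the dumped configs below hard-code data_dir: /home/andylo/tensorflow_datasets. If constructing the Trainer also builds the dataset pipeline (an assumption here, not confirmed by this release), you would want to point that field at your own tensorflow_datasets directory before freezing the config:

# Hypothetical override; only needed if Trainer touches the dataset when built for sampling.
config.experiment_kwargs.config.dataset_kwargs.data_dir = "/path/to/your/tensorflow_datasets"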
MNIST
Config / hyperparameters
>>> from jax_diffusion.configs.mnist import get_config
>>> get_config().copy_and_resolve_references()
ckpt_dir: gs://jax-diffusion-checkpoints
ckpt_interval: 180
dry_run: false
effective_steps: 100000
eval_interval: 180
experiment_kwargs:
  config:
    dataset_kwargs:
      data_dir: /home/andylo/tensorflow_datasets
      map_calls: auto
      name: mnist
      prefetch: auto
      resize_dim: 32
      seed: 42
    diffusion:
      T: 1000
      beta_1: 0.0001
      beta_T: 0.02
    eval:
      dataset_kwargs:
        batch_size: 128
        subset: 40%
      sample_kwargs:
        num: 6
        steps: 100
    model:
      unet_kwargs:
        attention_num_heads: 4
        attention_resolutions: !!python/tuple
        - 16
        dim_init: 48
        dim_mults: !!python/tuple
        - 1
        - 2
        - 2
        - 2
        dropout: 0.1
        dtype: fp32
        kernel_size: 3
        num_groups: 4
        num_res_blocks: 2
        sinusoidal_embed_dim: 48
        time_embed_dim: 192
    seed: 42
    train:
      dataset_kwargs:
        augment: false
        batch_size: 128
        buffer_size: 1000
        repeat: true
        shuffle: true
        subset: 100%
      ema_step_size: 0.0004999999999999449
      lr_schedule:
        kwargs:
          value: 0.0001
        schedule_type: constant
      optimizer:
        kwargs:
          grac_acc_steps: 1
          max_grad_norm: 1.0
        optimizer_type: adam
log_interval: 1
log_level: 0
project_name: jax-diffusion
restore: ''
seed: 42
steps: 100000
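The diffusion block above (T: 1000, beta_1: 0.0001, beta_T: 0.02) is the standard DDPM-style linear noise schedule. As a minimal sketch of what these three numbers define, assuming the usual linear schedule (the actual implementation in jax_diffusion may differ):

import jax.numpy as jnp

T, beta_1, beta_T = 1000, 1e-4, 0.02

betas = jnp.linspace(beta_1, beta_T, T)  # beta_t, linearly spaced over T steps
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)         # cumulative products used by the forward process

# Forward process (DDPM): x_t = sqrt(alpha_bars[t]) * x_0 + sqrt(1 - alpha_bars[t]) * noise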
Fashion MNIST
Training was terminated after roughly 90K steps, even though the config specifies 100K steps.
Config / hyperparameters
>>> from jax_diffusion.configs.fashion_mnist import get_config
>>> get_config().copy_and_resolve_references()
ckpt_dir: gs://jax-diffusion-checkpoints
ckpt_interval: 180
dry_run: false
effective_steps: 100000
eval_interval: 180
experiment_kwargs:
  config:
    dataset_kwargs:
      data_dir: /home/andylo/tensorflow_datasets
      map_calls: auto
      name: fashion_mnist
      prefetch: auto
      resize_dim: 32
      seed: 42
    diffusion:
      T: 1000
      beta_1: 0.0001
      beta_T: 0.02
    eval:
      dataset_kwargs:
        batch_size: 128
        subset: 40%
      sample_kwargs:
        num: 6
        steps: 100
    model:
      unet_kwargs:
        attention_num_heads: 4
        attention_resolutions: !!python/tuple
        - 16
        dim_init: 64
        dim_mults: !!python/tuple
        - 1
        - 2
        - 2
        - 2
        dropout: 0.1
        dtype: fp32
        kernel_size: 3
        num_groups: 4
        num_res_blocks: 2
        sinusoidal_embed_dim: 64
        time_embed_dim: 256
    seed: 42
    train:
      dataset_kwargs:
        augment: false
        batch_size: 128
        buffer_size: 1000
        repeat: true
        shuffle: true
        subset: 100%
      ema_step_size: 0.0004999999999999449
      lr_schedule:
        kwargs:
          value: 0.0001
        schedule_type: constant
      optimizer:
        kwargs:
          grac_acc_steps: 1
          max_grad_norm: 1.0
        optimizer_type: adam
log_interval: 1
log_level: 0
project_name: jax-diffusion
restore: ''
seed: 42
steps: 100000
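For reference, the train block (Adam, constant learning rate 0.0001, max_grad_norm: 1.0, ema_step_size of roughly 5e-4, grac_acc_steps: 1 meaning no gradient accumulation) corresponds to an optimizer of roughly the following shape. This is a sketch in optax under those assumptions, not the code used in jax_diffusion:

import optax

max_grad_norm = 1.0    # train.optimizer.kwargs.max_grad_norm
learning_rate = 1e-4   # train.lr_schedule.kwargs.value (constant schedule)
ema_step_size = 5e-4   # train.ema_step_size, up to float rounding

tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),  # clip gradients by their global norm
    optax.adam(learning_rate=learning_rate),   # Adam with a constant learning rate
)

# Parameter EMA, applied once per optimizer step:
# ema_params = optax.incremental_update(params, ema_params, step_size=ema_step_size)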