MNIST and Fashion MNIST!

Pre-release
@andylolu2 released this 18 Aug 19:46
· 27 commits to main since this release

Summary

This release provides reproducible configs for MNIST and Fashion MNIST, together with their checkpoints (model weights). The checkpoints are attached as release files.

Usage example

First, download the checkpoint files from this release. Then restore a checkpoint and draw samples:

from PIL import Image

import jax
from ml_collections import FrozenConfigDict

from jax_diffusion.configs.mnist import get_config as mnist_config
from jax_diffusion.configs.fashion_mnist import get_config as fmnist_config
from jax_diffusion.train.trainer import Trainer
from jax_diffusion.utils.image import image_grid

# Pick exactly one config; `config` must be defined before it is used below.
config = mnist_config()  # MNIST
# config = fmnist_config()  # Fashion MNIST: use this line instead
config.restore = "<path-to-checkpoint-file / directory>"
config = FrozenConfigDict(config)  # needed to be hashable

trainer = Trainer(jax.random.PRNGKey(0), **config.experiment_kwargs)
trainer.restore_checkpoint(config.restore)

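# Arguments are passed positionally because a positional argument cannot follow
# keyword arguments. The order (num, steps, PRNG key) is assumed from the original call.
img = trainer.sample(25, 100, jax.random.PRNGKey(1))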
img = image_grid(img)[:, :, 0]
img = Image.fromarray(img)
img.save("sample.png", dpi=(144, 144))

MNIST

Config / hyperparameters
  >>> from jax_diffusion.configs.mnist import get_config
  >>> get_config().copy_and_resolve_references()
  ckpt_dir: gs://jax-diffusion-checkpoints
  ckpt_interval: 180
  dry_run: false
  effective_steps: 100000
  eval_interval: 180
  experiment_kwargs:
    config:
      dataset_kwargs:
        data_dir: /home/andylo/tensorflow_datasets
        map_calls: auto
        name: mnist
        prefetch: auto
        resize_dim: 32
        seed: 42
      diffusion:
        T: 1000
        beta_1: 0.0001
        beta_T: 0.02
      eval:
        dataset_kwargs:
          batch_size: 128
          subset: 40%
        sample_kwargs:
          num: 6
          steps: 100
      model:
        unet_kwargs:
          attention_num_heads: 4
          attention_resolutions: !!python/tuple
          - 16
          dim_init: 48
          dim_mults: !!python/tuple
          - 1
          - 2
          - 2
          - 2
          dropout: 0.1
          dtype: fp32
          kernel_size: 3
          num_groups: 4
          num_res_blocks: 2
          sinusoidal_embed_dim: 48
          time_embed_dim: 192
      seed: 42
      train:
        dataset_kwargs:
          augment: false
          batch_size: 128
          buffer_size: 1000
          repeat: true
          shuffle: true
          subset: 100%
        ema_step_size: 0.0004999999999999449
        lr_schedule:
          kwargs:
            value: 0.0001
          schedule_type: constant
        optimizer:
          kwargs:
            grac_acc_steps: 1
            max_grad_norm: 1.0
          optimizer_type: adam
  log_interval: 1
  log_level: 0
  project_name: jax-diffusion
  restore: ''
  seed: 42
  steps: 100000 
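
The `diffusion` settings above (`T: 1000`, `beta_1: 0.0001`, `beta_T: 0.02`) give only the endpoints of the noise schedule. As a minimal sketch, assuming the betas are spaced linearly between those endpoints as in the original DDPM setup (the config itself does not state the schedule shape), the schedule and the cumulative signal fraction could be computed as follows. The function names are hypothetical and are not part of `jax_diffusion`.

```python
def linear_beta_schedule(T: int, beta_1: float, beta_T: float) -> list[float]:
    """Return T betas spaced linearly from beta_1 to beta_T.

    The linear shape is an assumption; the config only fixes the endpoints.
    """
    if T == 1:
        return [beta_1]
    step = (beta_T - beta_1) / (T - 1)
    return [beta_1 + i * step for i in range(T)]


def cumulative_alphas(betas: list[float]) -> list[float]:
    """Return alpha_bar_t = prod_{s <= t} (1 - beta_s) for every step t."""
    alpha_bars = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


# Values taken from the `diffusion` block of the config.
betas = linear_beta_schedule(T=1000, beta_1=1e-4, beta_T=0.02)
alpha_bars = cumulative_alphas(betas)

print(f"alpha_bar at the first step: {alpha_bars[0]:.6f}")
print(f"alpha_bar at the last step:  {alpha_bars[-1]:.2e}")
```

Under this assumption the cumulative product at the last step is close to zero, so a fully noised image is almost pure Gaussian noise. That is what allows sampling to start from random noise.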

Fashion MNIST

Training was terminated after roughly 90K steps, even though the config sets it to 100K steps.

Config / hyperparameters
>>> from jax_diffusion.configs.fashion_mnist import get_config
>>> get_config().copy_and_resolve_references()
ckpt_dir: gs://jax-diffusion-checkpoints
ckpt_interval: 180
dry_run: false
effective_steps: 100000
eval_interval: 180
experiment_kwargs:
  config:
    dataset_kwargs:
      data_dir: /home/andylo/tensorflow_datasets
      map_calls: auto
      name: fashion_mnist
      prefetch: auto
      resize_dim: 32
      seed: 42
    diffusion:
      T: 1000
      beta_1: 0.0001
      beta_T: 0.02
    eval:
      dataset_kwargs:
        batch_size: 128
        subset: 40%
      sample_kwargs:
        num: 6
        steps: 100
    model:
      unet_kwargs:
        attention_num_heads: 4
        attention_resolutions: !!python/tuple
        - 16
        dim_init: 64
        dim_mults: !!python/tuple
        - 1
        - 2
        - 2
        - 2
        dropout: 0.1
        dtype: fp32
        kernel_size: 3
        num_groups: 4
        num_res_blocks: 2
        sinusoidal_embed_dim: 64
        time_embed_dim: 256
    seed: 42
    train:
      dataset_kwargs:
        augment: false
        batch_size: 128
        buffer_size: 1000
        repeat: true
        shuffle: true
        subset: 100%
      ema_step_size: 0.0004999999999999449
      lr_schedule:
        kwargs:
          value: 0.0001
        schedule_type: constant
      optimizer:
        kwargs:
          grac_acc_steps: 1
          max_grad_norm: 1.0
        optimizer_type: adam
log_interval: 1
log_level: 0
project_name: jax-diffusion
restore: ''
seed: 42
steps: 100000
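
Both configs set `ema_step_size` to `0.0004999999999999449`, which equals `1 - 0.9995` up to floating-point rounding. As a rough sketch, assuming this value is the interpolation weight of an exponential moving average over the model parameters (the release notes do not say how the trainer applies it), one update step could look like this. `ema_update` and the parameter dictionaries are hypothetical stand-ins, not the `jax_diffusion` API.

```python
def ema_update(ema_params: dict, params: dict, step_size: float) -> dict:
    """Move each averaged parameter a fraction `step_size` toward its current value.

    Equivalent to: ema = (1 - step_size) * ema + step_size * params.
    """
    return {
        name: ema_params[name] + step_size * (params[name] - ema_params[name])
        for name in params
    }


# Value taken from `train.ema_step_size` in the config.
EMA_STEP_SIZE = 0.0004999999999999449

# Toy scalar "parameters" for illustration only.
ema_params = {"w": 0.0, "b": 1.0}
params = {"w": 1.0, "b": 1.0}

ema_params = ema_update(ema_params, params, EMA_STEP_SIZE)
print(ema_params)

# A step size s corresponds to averaging over roughly 1 / s recent steps.
effective_horizon = 1.0 / EMA_STEP_SIZE
print(f"approximate averaging horizon: {effective_horizon:.0f} steps")
```

With this interpretation the averaged weights track the training weights over a window of about 2000 steps, which is short relative to the 100K steps in the configs.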