Commit 1f791b8

first commit

unixpickle committed Dec 20, 2021 · 0 parents
Showing 30 changed files with 4,227 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
__pycache__/
*.egg-info/
.DS_Store
19 changes: 19 additions & 0 deletions README.md
@@ -0,0 +1,19 @@
# GLIDE

This is the official codebase for running the small, filtered-data GLIDE model from [GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/abs/2112.10741).

For details on the pre-trained models in this repository, see the [Model Card](model-card.md).

# Usage

To install this package, clone this repository and then run:

```
pip install -e .
```

For detailed usage examples, see the [notebooks](notebooks) directory.

* The [text2im](notebooks/text2im.ipynb) notebook shows how to use GLIDE (filtered) with classifier-free guidance to produce images conditioned on text prompts (a minimal loading sketch follows this list).
* The [inpaint](notebooks/inpaint.ipynb) notebook shows how to use GLIDE (filtered) to fill in a masked region of an image, conditioned on a text prompt.
* The [clip_guided](notebooks/clip_guided.ipynb) notebook shows how to use GLIDE (filtered) + a filtered noise-aware CLIP model to produce images conditioned on text prompts.
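
For orientation, here is a minimal sketch of the model-loading step that the text2im notebook walks through, using the `model_and_diffusion_defaults`, `create_model_and_diffusion`, and `load_checkpoint` helpers added in this commit. Treat the exact option values as illustrative and defer to the notebook for the full sampling loop.

```python
import torch as th

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
)

# Build the base 64x64 text-conditional model with its default options.
options = model_and_diffusion_defaults()
options["timestep_respacing"] = "100"  # fewer diffusion steps for faster sampling
model, diffusion = create_model_and_diffusion(**options)
model.eval()

# Download and load the released base checkpoint.
device = th.device("cpu")
model.load_state_dict(load_checkpoint("base", device))
```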
3 changes: 3 additions & 0 deletions glide_text2im/__init__.py
@@ -0,0 +1,3 @@
"""
A codebase for performing model inference with a text-conditional diffusion model.
"""
Empty file added glide_text2im/clip/__init__.py
179 changes: 179 additions & 0 deletions glide_text2im/clip/attention.py
@@ -0,0 +1,179 @@
import math
from abc import ABC, abstractmethod
from itertools import product
from typing import Any, Optional

import attr
import numpy as np
import torch


def _at_least_one(instance: Any, attribute: Any, value: int) -> None:
    # attrs ignores a validator's return value, so a bare `lambda i, a, x: x >= 1`
    # never rejects anything; a validator has to raise to signal failure.
    if value < 1:
        raise ValueError(f"{attribute.name} must be >= 1, got {value}")


@attr.s
class AttentionMask(ABC):
    query_context_size: int = attr.ib(validator=_at_least_one)
    key_context_size: int = attr.ib(validator=_at_least_one)
    block_size: int = attr.ib(validator=_at_least_one)
    n_head: int = attr.ib(validator=_at_least_one)
    is_head_specific: bool = attr.ib(default=False)
    n_query_pad: int = attr.ib(default=0)
    n_key_pad: int = attr.ib(default=0)

    def __attrs_post_init__(self) -> None:
        if self.query_context_size % self.block_size != 0:
            raise ValueError("query_context_size must be divisible by block_size")
        if self.key_context_size % self.block_size != 0:
            raise ValueError("key_context_size must be divisible by block_size")
        if self.n_query_pad >= self.query_context_size:
            raise ValueError("n_query_pad must be smaller than query_context_size")
        if self.n_key_pad >= self.key_context_size:
            raise ValueError("n_key_pad must be smaller than key_context_size")

self.n_query_block = self.query_context_size // self.block_size
self.n_key_block = self.key_context_size // self.block_size
self.first_pad_query_block_idx = self.n_query_block - int(
math.ceil(self.n_query_pad / self.block_size)
)
self.first_pad_key_block_idx = self.n_key_block - int(
math.ceil(self.n_key_pad / self.block_size)
)

def _make_global_layout(self) -> None:
if not self.is_head_specific:
            m = np.ones([self.n_query_block, self.n_key_block], dtype=bool)
r = product(*[range(n) for n in m.shape])

for qb, kb in r:
m[qb, kb] = np.any(self.block_layout(None, 0, qb, kb, 0))
else:
            m = np.ones([self.n_head, self.n_query_block, self.n_key_block], dtype=bool)
r = product(*[range(n) for n in m.shape])

for h, qb, kb in r:
m[h, qb, kb] = np.any(self.block_layout(None, h, qb, kb, 0))

self.global_layout = m

@abstractmethod
def _block_layout(
self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
) -> np.ndarray:
raise NotImplementedError()

def block_layout(
self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
) -> np.ndarray:
"""
`query_idx`, `key_idx` are block-level, zero-based indices.
"""

        m = np.ones([self.block_size, self.block_size], dtype=bool)

if query_idx >= self.first_pad_query_block_idx:
n_pad = min(
self.block_size,
(query_idx + 1) * self.block_size - (self.query_context_size - self.n_query_pad),
)
assert n_pad > 0
m[self.block_size - n_pad :] = False
if key_idx >= self.first_pad_key_block_idx:
n_pad = min(
self.block_size,
(key_idx + 1) * self.block_size - (self.key_context_size - self.n_key_pad),
)
assert n_pad > 0
m[:, self.block_size - n_pad :] = False

return m & self._block_layout(blk_shape, head_idx, query_idx, key_idx, blk_idx)


@attr.s
class DenseAttentionMask(AttentionMask):
def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()

        self.global_layout = np.ones([self.n_query_block, self.n_key_block], dtype=bool)
n_zero_query_blocks = self.n_query_pad // self.block_size
n_zero_key_blocks = self.n_key_pad // self.block_size
self.global_layout[self.n_query_block - n_zero_query_blocks :] = False
self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False

def _block_layout(
self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
) -> np.ndarray:
        return np.ones([self.block_size, self.block_size], dtype=bool)


@attr.s
class DenseCausalAttentionMask(AttentionMask):
def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()

        self.global_layout = np.tril(np.ones([self.n_query_block, self.n_key_block], dtype=bool))
n_zero_query_blocks = self.n_query_pad // self.block_size
n_zero_key_blocks = self.n_key_pad // self.block_size
self.global_layout[self.n_query_block - n_zero_query_blocks :] = False
self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False

def _block_layout(
self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
) -> np.ndarray:
if query_idx > key_idx:
            return np.ones(2 * [self.block_size], dtype=bool)
        elif query_idx < key_idx:
            return np.zeros(2 * [self.block_size], dtype=bool)
        else:
            return np.tril(np.ones(2 * [self.block_size], dtype=bool))


@attr.s(eq=False, repr=False)
class AttentionInfo:
n_heads: int = attr.ib()
ctx_blks_q: int = attr.ib()
ctx_blks_k: int = attr.ib()
block_size: int = attr.ib()
pytorch_attn_bias: Optional[torch.Tensor] = attr.ib()


def to_attention_info(d: AttentionMask) -> AttentionInfo:
return AttentionInfo(
n_heads=d.n_head,
ctx_blks_q=d.n_query_block,
ctx_blks_k=d.n_key_block,
block_size=d.block_size,
pytorch_attn_bias=None,
)


def make_full_layout(d: AttentionMask) -> np.ndarray:
"""
    Returns the `context_size x context_size` layout matrix described by `d`. If the layout depends on the index of
    the attention head, an `attention_head x context_size x context_size` layout matrix is returned instead.
"""

if not d.is_head_specific:
u = np.reshape(d.global_layout, [d.n_query_block, d.n_key_block, 1, 1])
r = product(range(d.n_query_block), range(d.n_key_block))
v = np.array([d.block_layout(None, 0, i, j, 0) for i, j in r])
v = np.reshape(v, [d.n_query_block, d.n_key_block, d.block_size, d.block_size])

w = u * v
w = np.transpose(w, [0, 2, 1, 3])
w = np.reshape(w, [d.query_context_size, d.key_context_size])
return w
else:
if len(d.global_layout.shape) == 2:
u = np.reshape(d.global_layout, [1, d.n_query_block, d.n_key_block, 1, 1])
u = np.tile(u, [d.n_head, 1, 1, 1, 1])
elif len(d.global_layout.shape) == 3:
u = np.reshape(d.global_layout, [d.n_head, d.n_query_block, d.n_key_block, 1, 1])
else:
            raise RuntimeError(f"unexpected global_layout shape: {d.global_layout.shape}")

s = product(range(d.n_head), range(d.n_query_block), range(d.n_key_block))
v = np.array([d.block_layout(None, i, j, k, 0) for i, j, k in s])
v = np.reshape(v, [d.n_head, d.n_query_block, d.n_key_block, d.block_size, d.block_size])

w = u * v
w = np.transpose(w, [0, 1, 3, 2, 4])
w = np.reshape(w, [d.n_head, d.query_context_size, d.key_context_size])
return w
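
To make the mask machinery above concrete, here is a small illustrative example (not part of the original file) that builds a dense causal mask over a four-token context split into 2x2 blocks and materializes the full boolean layout:

```python
# Illustrative only: a causal mask over a 4-token context in 2x2 blocks.
mask = DenseCausalAttentionMask(
    query_context_size=4, key_context_size=4, block_size=2, n_head=1
)

layout = make_full_layout(mask)
# `layout` is a 4x4 lower-triangular boolean matrix: each query position
# may attend only to itself and earlier key positions.

info = to_attention_info(mask)
# `info` records the head count and block geometry; the attention bias
# tensor is left unset (None) at this stage.
```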
18 changes: 18 additions & 0 deletions glide_text2im/clip/config.yaml
@@ -0,0 +1,18 @@
logit_scale: 100.0

# Diffusion settings
beta_schedule: "squaredcos_cap_v2"
n_timesteps: 1000

# Architecture settings
image_size: 64
patch_size: 4
n_vocab: 65536
max_text_len: 77
n_embd: 512
n_head_state_text: 64
n_head_text: 8
n_xf_blocks_text: 12
n_head_state_image: 64
n_head_image: 12
n_xf_blocks_image: 12
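
As a hedged sketch (assuming PyYAML is installed; the repository's own CLIP loader may read this file differently), these hyperparameters can be loaded into a plain dict:

```python
import yaml

# Illustrative only: read the CLIP hyperparameters above into a dict.
with open("glide_text2im/clip/config.yaml") as f:
    config = yaml.safe_load(f)

print(config["n_embd"], config["max_text_len"])  # 512 77
```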
(Diffs for the remaining files are not shown.)