Skip to content

Commit

Permalink
No public description
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 636694738
Change-Id: Ie3df52cce8bbed4f1e4874a7adad0a414682d604
  • Loading branch information
Alex Alemi authored and peterjliu committed May 26, 2024
1 parent e759c22 commit 8fd56f3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
32 changes: 19 additions & 13 deletions nanodo/data.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -26,11 +26,7 @@

import sentencepiece as spm


PAD_ID = 0
EOS_ID = 1
BOS_ID = 2

### pure python helpers for use with grain ###


Expand Down Expand Up @@ -69,7 +65,9 @@ def py_batched_tfds(
pygrain_ops = [
grain.MapOperation(
map_function=functools.partial(
_py_tokenize, spt=spt, pad_id=PAD_ID, pad_len=pad_len
_py_tokenize,
spt=spt,
pad_len=pad_len,
)
)
]
Expand All @@ -93,9 +91,9 @@ def py_batched_tfds(
def get_py_tokenizer(path: str) -> spm.SentencePieceProcessor:
sp = spm.SentencePieceProcessor()
sp.Load(path)
assert sp.bos_id() == BOS_ID
assert sp.eos_id() == EOS_ID
assert sp.pad_id() == PAD_ID
assert sp.eos_id() != -1
assert sp.bos_id() != -1
return sp


Expand All @@ -121,9 +119,13 @@ def _py_tokenize(
) -> list[int]:
"""Tokenizes text into ids, optionally pads or truncates to pad_len."""
text = features['text']
ids = spt.get_tokenizer().EncodeAsIds(text)
ids.insert(0, BOS_ID)
ids.append(EOS_ID)
tokenizer = spt.get_tokenizer()
bos_id = tokenizer.bos_id()
eos_id = tokenizer.eos_id()
ids = tokenizer.EncodeAsIds(text)

ids.insert(0, bos_id)
ids.append(eos_id)
if pad_len is not None:
if len(ids) < pad_len:
ids.extend([pad_id] * (pad_len - len(ids)))
Expand Down Expand Up @@ -165,15 +167,19 @@ def __call__(
# pylint: disable=invalid-name


def get_in_out(in_BxL: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
def get_in_out(
in_BxL: jax.Array,
pad_id: int = PAD_ID,
) -> tuple[jax.Array, jax.Array, jax.Array]:
"""Returns input, output, and weights for a batch of examples."""
# Assumes input of the form <BOS> <IDs> <EOS> for eval.
x_BxL = in_BxL
y_BxL = jnp.pad(
in_BxL[:, 1:],
((0, 0), (0, 1)),
mode='constant',
constant_values=PAD_ID,
constant_values=pad_id,
)
weights_BxL = jnp.where(y_BxL != PAD_ID, 1, 0).astype(jnp.float32)
weights_BxL = jnp.where(y_BxL != pad_id, 1, 0).astype(jnp.float32)

return x_BxL, y_BxL, weights_BxL
12 changes: 9 additions & 3 deletions nanodo/evaluate.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -47,15 +47,21 @@
(
"lm1b:1.1.0",
"cc_all.32000.100extra.bos.model",
): _BPN * (10_449_751 / 41_715_169.0), # 0.36139860649310773
): _BPN * (
10_449_751 / 41_715_169.0
), # 0.36139860649310773
(
"c4:3.1.0",
"cc_all.32000.100extra.bos.model",
): _BPN * (183_808_378 / 789_615_977.0), # 0.3358334217374176
): _BPN * (
183_808_378 / 789_615_977.0
), # 0.3358334217374176
(
"huggingface:cerebras__slimpajama_627b", # validation
"cc_all.32000.100extra.bos.model",
): _BPN * (560_013_105 / 2_174_889_064.0), # 0.3714801562937696
): _BPN * (
560_013_105 / 2_174_889_064.0
), # 0.3714801562937696
}


Expand Down

0 comments on commit 8fd56f3

Please sign in to comment.