From 8fd56f3207f19b9f785deafa6a5657143f5d8ac2 Mon Sep 17 00:00:00 2001
From: Alex Alemi
Date: Thu, 23 May 2024 22:24:35 +0000
Subject: [PATCH] No public description

PiperOrigin-RevId: 636694738
Change-Id: Ie3df52cce8bbed4f1e4874a7adad0a414682d604
---
 nanodo/data.py     | 32 +++++++++++++++++++-------------
 nanodo/evaluate.py | 12 +++++++++---
 2 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/nanodo/data.py b/nanodo/data.py
index 5cdfa50..b68f22c 100644
--- a/nanodo/data.py
+++ b/nanodo/data.py
@@ -26,11 +26,7 @@
 import sentencepiece as spm
 
 
-
 PAD_ID = 0
-EOS_ID = 1
-BOS_ID = 2
-
 
 ### pure python helpers for use with grain ###
 
@@ -69,7 +65,9 @@ def py_batched_tfds(
   pygrain_ops = [
       grain.MapOperation(
           map_function=functools.partial(
-              _py_tokenize, spt=spt, pad_id=PAD_ID, pad_len=pad_len
+              _py_tokenize,
+              spt=spt,
+              pad_len=pad_len,
           )
       )
   ]
@@ -93,9 +91,9 @@ def py_batched_tfds(
 def get_py_tokenizer(path: str) -> spm.SentencePieceProcessor:
   sp = spm.SentencePieceProcessor()
   sp.Load(path)
-  assert sp.bos_id() == BOS_ID
-  assert sp.eos_id() == EOS_ID
   assert sp.pad_id() == PAD_ID
+  assert sp.eos_id() != -1
+  assert sp.bos_id() != -1
   return sp
 
 
@@ -121,9 +119,13 @@ def _py_tokenize(
 ) -> list[int]:
   """Tokenizes text into ids, optionally pads or truncates to pad_len."""
   text = features['text']
-  ids = spt.get_tokenizer().EncodeAsIds(text)
-  ids.insert(0, BOS_ID)
-  ids.append(EOS_ID)
+  tokenizer = spt.get_tokenizer()
+  bos_id = tokenizer.bos_id()
+  eos_id = tokenizer.eos_id()
+  ids = tokenizer.EncodeAsIds(text)
+
+  ids.insert(0, bos_id)
+  ids.append(eos_id)
   if pad_len is not None:
     if len(ids) < pad_len:
       ids.extend([pad_id] * (pad_len - len(ids)))
@@ -165,15 +167,19 @@ def __call__(
 
 
 # pylint: disable=invalid-name
-def get_in_out(in_BxL: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
+def get_in_out(
+    in_BxL: jax.Array,
+    pad_id: int = PAD_ID,
+) -> tuple[jax.Array, jax.Array, jax.Array]:
+  """Returns input, output, and weights for a batch of examples."""
   # Assumes input of the form <BOS> ... <EOS> for eval.
   x_BxL = in_BxL
   y_BxL = jnp.pad(
       in_BxL[:, 1:],
       ((0, 0), (0, 1)),
       mode='constant',
-      constant_values=PAD_ID,
+      constant_values=pad_id,
   )
-  weights_BxL = jnp.where(y_BxL != PAD_ID, 1, 0).astype(jnp.float32)
+  weights_BxL = jnp.where(y_BxL != pad_id, 1, 0).astype(jnp.float32)
   return x_BxL, y_BxL, weights_BxL
 
diff --git a/nanodo/evaluate.py b/nanodo/evaluate.py
index bdadca9..66db8d7 100644
--- a/nanodo/evaluate.py
+++ b/nanodo/evaluate.py
@@ -47,15 +47,21 @@
     (
         "lm1b:1.1.0",
         "cc_all.32000.100extra.bos.model",
-    ): _BPN * (10_449_751 / 41_715_169.0),  # 0.36139860649310773
+    ): _BPN * (
+        10_449_751 / 41_715_169.0
+    ),  # 0.36139860649310773
     (
         "c4:3.1.0",
         "cc_all.32000.100extra.bos.model",
-    ): _BPN * (183_808_378 / 789_615_977.0),  # 0.3358334217374176
+    ): _BPN * (
+        183_808_378 / 789_615_977.0
+    ),  # 0.3358334217374176
     (
         "huggingface:cerebras__slimpajama_627b",  # validation
         "cc_all.32000.100extra.bos.model",
-    ): _BPN * (560_013_105 / 2_174_889_064.0),  # 0.3714801562937696
+    ): _BPN * (
+        560_013_105 / 2_174_889_064.0
+    ),  # 0.3714801562937696
 }
 
 
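
Below is a minimal, standalone sketch (not part of the patch) of what the reworked get_in_out computes after this change: next-token targets and a loss mask keyed off a caller-supplied pad_id instead of the old module constant. The pad_id=0 default and the token ids in the example batch are assumed for illustration only; the function name get_in_out_sketch is hypothetical and re-implements the logic inline rather than importing nanodo.data.

import jax.numpy as jnp

def get_in_out_sketch(in_BxL, pad_id=0):
  """Returns inputs, shifted next-token targets, and a pad-aware loss mask."""
  x_BxL = in_BxL            # inputs: the batch unchanged
  y_BxL = jnp.pad(          # targets: shift left by one, pad the last slot
      in_BxL[:, 1:],
      ((0, 0), (0, 1)),
      mode='constant',
      constant_values=pad_id,
  )
  # Zero weight wherever the target is the pad id.
  weights_BxL = jnp.where(y_BxL != pad_id, 1, 0).astype(jnp.float32)
  return x_BxL, y_BxL, weights_BxL

# Illustrative batch: BOS, three token ids, EOS, then padding (all ids assumed).
batch = jnp.array([[2, 11, 12, 13, 1, 0, 0]])
x, y, w = get_in_out_sketch(batch)
# y -> [[11, 12, 13, 1, 0, 0, 0]],  w -> [[1., 1., 1., 1., 0., 0., 0.]]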