From ddc9bf1aa995a50c671ea41858044b61037267a3 Mon Sep 17 00:00:00 2001 From: Pierre Marcenac Date: Wed, 29 May 2024 16:08:17 +0000 Subject: [PATCH] No public description PiperOrigin-RevId: 638306209 Change-Id: I29149616d8d244df0e1b28241221518aa8833dab --- nanodo/data.py | 5 +++-- pyproject.toml | 3 ++- tests/train_test.py | 1 - 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/nanodo/data.py b/nanodo/data.py index b68f22c..bd558c6 100644 --- a/nanodo/data.py +++ b/nanodo/data.py @@ -13,6 +13,7 @@ # limitations under the License. """Data pipeline.""" +from collections.abc import Mapping, Sequence import dataclasses import enum import functools @@ -112,11 +113,11 @@ def get_tokenizer(self) -> spm.SentencePieceProcessor: def _py_tokenize( - features: dict[str, str], + features: Mapping[str, str], spt: _SPTokenizer, pad_len: int | None = None, pad_id: int = PAD_ID, -) -> list[int]: +) -> Sequence[int]: """Tokenizes text into ids, optionally pads or truncates to pad_len.""" text = features['text'] tokenizer = spt.get_tokenizer() diff --git a/pyproject.toml b/pyproject.toml index 2aa1991..18dee1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,8 @@ dependencies = [ "optax>=0.2.2", "orbax>=0.1.7", "sentencepiece>=0.2.0", - "tensorflow-datasets>=4.9.4", + # TODO: Temporarily depend on tfds-nightly until the next stable tensorflow-datasets (TFDS) release; switch back to a pinned stable version then. + "tfds-nightly", "tensorflow>=2.16.1", ] diff --git a/tests/train_test.py b/tests/train_test.py index 0ed42cd..5f92033 100644 --- a/tests/train_test.py +++ b/tests/train_test.py @@ -166,7 +166,6 @@ def test_train_and_evaluate(self, preprocessing): c = _get_config(self) c.checkpoint = True - c.pygrain_worker_count = 0 cfg = model.DoConfig(**c.model, V=c.V) m = model.TransformerDo(cfg)