Skip to content

Commit

Permalink
Update translate.py (#5)
Browse files Browse the repository at this point in the history
* Update translate.py

* Update test_translate.py

* release 1.03 for m2m

---------

Co-authored-by: Michael Feil <[email protected]>
  • Loading branch information
michaelfeil and michaelfeil authored May 13, 2023
1 parent 1406b9a commit 9d92920
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 16 deletions.
6 changes: 3 additions & 3 deletions hf_hub_ctranslate2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""Compatability between Huggingface and Ctranslate2."""
__all__ = ["__version__", "TranslatorCT2fromHfHub", "GeneratorCT2fromHfHub"]
from hf_hub_ctranslate2.translate import TranslatorCT2fromHfHub, GeneratorCT2fromHfHub
__all__ = ["__version__", "TranslatorCT2fromHfHub", "GeneratorCT2fromHfHub", "MultiLingualTranslatorCT2fromHfHub"]
from hf_hub_ctranslate2.translate import TranslatorCT2fromHfHub, GeneratorCT2fromHfHub, MultiLingualTranslatorCT2fromHfHub

__version__ = "1.0.2"
__version__ = "1.0.3"
131 changes: 120 additions & 11 deletions hf_hub_ctranslate2/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ def __init__(
if os.path.isdir(model_name_or_path):
model_path = model_name_or_path
else:
model_path = download_model(model_name_or_path, hub_kwargs=hub_kwargs)

try:
model_path = download_model(model_name_or_path, hub_kwargs=hub_kwargs)
except:
hub_kwargs["local_files_only"] = True
model_path = download_model(model_name_or_path, hub_kwargs=hub_kwargs)
self.model = self.ctranslate_class(
model_path,
device=device,
Expand All @@ -44,17 +47,22 @@ def __init__(

def _forward(self, *args: Any, **kwds: Any) -> Any:
raise NotImplementedError

def tokenize_encode(self, text, *args, **kwargs):
return [
self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(p)) for p in text
]
def tokenize_decode(self, tokens_out, *args, **kwargs):
raise NotImplementedError

def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: Any):
def generate(self, text: Union[str, List[str]], encode_kwargs={}, *forward_args, **forward_kwds: Any):
orig_type = list
if isinstance(text, str):
orig_type = str
text = [text]
token_list = [
self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(p)) for p in text
]
texts_out = self._forward(token_list, *forward_args, **forward_kwds)

token_list = self.tokenize_encode(text, **encode_kwargs)
tokens_out = self._forward(token_list, *forward_args, **forward_kwds)
texts_out = self.tokenize_decode(tokens_out)
if orig_type == str:
return texts_out[0]
else:
Expand All @@ -71,7 +79,7 @@ def __init__(
tokenizer: Union[AutoTokenizer, None] = None,
hub_kwargs={},
):
"""for ctranslate2.Translator models
"""for ctranslate2.Translator models, in particular m2m-100
Args:
model_name_or_path (str): _description_
Expand All @@ -92,7 +100,9 @@ def __init__(
)

def _forward(self, *args, **kwds):
tokens_out = self.model.translate_batch(*args, **kwds)
return self.model.translate_batch(*args, **kwds)

def tokenize_decode(self, tokens_out, *args):
return [
self.tokenizer.decode(
self.tokenizer.convert_tokens_to_ids(tokens_out[i].hypotheses[0])
Expand Down Expand Up @@ -140,6 +150,101 @@ def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: A
"""
return super().generate(text, *forward_args, **forward_kwds)

class MultiLingualTranslatorCT2fromHfHub(CTranslate2ModelfromHuggingfaceHub):
    def __init__(
        self,
        model_name_or_path: str,
        device: Literal["cpu", "cuda"] = "cuda",
        device_index=0,
        compute_type: Literal["int8_float16", "int8"] = "int8_float16",
        tokenizer: Union[AutoTokenizer, None] = None,
        hub_kwargs=None,
    ):
        """Multilingual translation via ctranslate2.Translator models, e.g. m2m-100.

        Args:
            model_name_or_path (str): CTranslate2 model directory or HF Hub repo id.
            device (Literal["cpu", "cuda"], optional): Device to run on.
                Defaults to "cuda".
            device_index (int, optional): Index of the device. Defaults to 0.
            compute_type (Literal["int8_float16", "int8"], optional): Quantization /
                compute precision. Defaults to "int8_float16".
            tokenizer (Union[AutoTokenizer, None], optional): Pre-loaded tokenizer;
                presumably loaded by the base class when None — TODO confirm.
                Defaults to None.
            hub_kwargs (dict, optional): Extra kwargs forwarded to the hub download.
                Defaults to None (treated as an empty dict).
        """
        # Avoid a shared mutable default: the base-class __init__ mutates
        # hub_kwargs (sets "local_files_only" on download failure), which would
        # leak state between instances through a module-level {} default.
        if hub_kwargs is None:
            hub_kwargs = {}
        self.ctranslate_class = ctranslate2.Translator
        super().__init__(
            model_name_or_path,
            device,
            device_index,
            compute_type,
            tokenizer,
            hub_kwargs,
        )

    def _forward(self, *args, **kwds):
        """Translate a batch, forcing each output to start with its target-language token."""
        # e.g. tgt_lang=["de", "fr"] -> target_prefix=[["__de__"], ["__fr__"]]
        target_prefix = [
            [self.tokenizer.lang_code_to_token[lang]]
            for lang in kwds.pop("tgt_lang")
        ]
        return self.model.translate_batch(*args, **kwds, target_prefix=target_prefix)

    def tokenize_encode(self, text, *args, **kwargs):
        """Tokenize each input text with its per-sample source language.

        Args:
            text: list of input strings.
            kwargs: must contain "src_lang", a list of language codes aligned
                with `text`.

        Returns:
            list of token lists, one per input text.
        """
        src_lang = kwargs.pop("src_lang")
        tokens = []
        for sample, language in zip(text, src_lang):
            # The tokenizer encodes language-dependently; src_lang must be set
            # before each encode call.
            self.tokenizer.src_lang = language
            tokens.append(
                self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(sample))
            )
        return tokens

    def tokenize_decode(self, tokens_out, *args):
        """Decode each best hypothesis, dropping the leading target-language token ([1:])."""
        return [
            self.tokenizer.decode(
                self.tokenizer.convert_tokens_to_ids(tokens_out[i].hypotheses[0][1:])
            )
            for i in range(len(tokens_out))
        ]

    def generate(self, text: Union[str, List[str]], src_lang: Union[str, List[str]], tgt_lang: Union[str, List[str]], *forward_args, **forward_kwds: Any):
        """Translate texts between arbitrary language pairs.

        Args:
            text (Union[str, List[str]]): Input texts.
            src_lang (Union[str, List[str]]): source language of the input texts,
                one code per text.
            tgt_lang (Union[str, List[str]]): target language for outputs,
                one code per text.
            *forward_args: positional args forwarded to translate_batch.
            **forward_kwds: keyword args forwarded to translate_batch
                (e.g. max_batch_size, beam_size, num_hypotheses,
                max_decoding_length, sampling_topk, ...).

        Raises:
            ValueError: if text, src_lang and tgt_lang do not all have equal length.

        Returns:
            Union[str, List[str]]: translated text(s); if a list, same length
            as the input.
        """
        if not len(text) == len(src_lang) == len(tgt_lang):
            raise ValueError(f"unequal len: text={len(text)} src_lang={len(src_lang)} tgt_lang={len(tgt_lang)}")
        # tgt_lang travels through the forward kwargs (consumed by _forward);
        # src_lang is consumed by tokenize_encode via encode_kwargs.
        forward_kwds["tgt_lang"] = tgt_lang
        return super().generate(text, *forward_args, **forward_kwds, encode_kwargs={"src_lang": src_lang})

class GeneratorCT2fromHfHub(CTranslate2ModelfromHuggingfaceHub):
def __init__(
Expand Down Expand Up @@ -172,12 +277,15 @@ def __init__(
)

def _forward(self, *args, **kwds):
tokens_out = self.model.generate_batch(*args, **kwds)
return self.model.generate_batch(*args, **kwds)

def tokenize_decode(self, tokens_out, *args):
return [
self.tokenizer.decode(tokens_out[i].sequences_ids[0])
for i in range(len(tokens_out))
]


def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: Any):
"""_summary_
Expand Down Expand Up @@ -211,3 +319,4 @@ def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: A
str | List[str]: text as output, if list, same len as input
"""
return super().generate(text, *forward_args, **forward_kwds)

3 changes: 2 additions & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ coverage~=6.2
flake8
pre-commit
myst-parser[sphinx]
mkdocs
mkdocs
sentencepiece
19 changes: 18 additions & 1 deletion tests/test_translate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from hf_hub_ctranslate2 import TranslatorCT2fromHfHub, GeneratorCT2fromHfHub
from hf_hub_ctranslate2 import TranslatorCT2fromHfHub, GeneratorCT2fromHfHub, MultiLingualTranslatorCT2fromHfHub
from hf_hub_ctranslate2._private.utils import download_model
from transformers import AutoTokenizer

Expand All @@ -17,6 +17,23 @@ def test_translator(model_name="michaelfeil/ct2fast-flan-alpaca-base"):
for o in outputs:
assert isinstance(o, str)

def test_multilingualtranslator(model_name="michaelfeil/ct2fast-m2m100_418M"):
    """Smoke-test multilingual translation (en->de and de->fr) on CPU."""
    # The matching HF tokenizer repo id is derived from the ct2 model name suffix.
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{model_name.split('-')[-1]}")
    model = MultiLingualTranslatorCT2fromHfHub(
        model_name_or_path=model_name,
        device="cpu",
        compute_type="int8",
        tokenizer=tokenizer,
    )

    texts = ["How do you call a fast Flamingo?", "Wie geht es dir?"]
    outputs = model.generate(texts, src_lang=["en", "de"], tgt_lang=["de", "fr"])

    assert len(outputs) == 2
    # Different target languages should yield visibly different outputs.
    assert len(outputs[0]) != len(outputs[1])
    assert "nennt" in outputs[0].lower()
    assert "comment" in outputs[1].lower()
    assert all(isinstance(translation, str) for translation in outputs)

def test_generator(model_name="michaelfeil/ct2fast-pythia-160m"):
model = GeneratorCT2fromHfHub(
Expand Down

0 comments on commit 9d92920

Please sign in to comment.