diff --git a/hf_hub_ctranslate2/__init__.py b/hf_hub_ctranslate2/__init__.py
index f22c469..1a64edc 100644
--- a/hf_hub_ctranslate2/__init__.py
+++ b/hf_hub_ctranslate2/__init__.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 """Compatability between Huggingface and Ctranslate2."""
-__all__ = ["__version__", "TranslatorCT2fromHfHub", "GeneratorCT2fromHfHub"]
-from hf_hub_ctranslate2.translate import TranslatorCT2fromHfHub, GeneratorCT2fromHfHub
+__all__ = ["__version__", "TranslatorCT2fromHfHub", "GeneratorCT2fromHfHub", "MultiLingualTranslatorCT2fromHfHub"]
+from hf_hub_ctranslate2.translate import TranslatorCT2fromHfHub, GeneratorCT2fromHfHub, MultiLingualTranslatorCT2fromHfHub
 
-__version__ = "1.0.2"
+__version__ = "1.0.3"
\ No newline at end of file
diff --git a/hf_hub_ctranslate2/translate.py b/hf_hub_ctranslate2/translate.py
index 3c9ac36..03882e4 100644
--- a/hf_hub_ctranslate2/translate.py
+++ b/hf_hub_ctranslate2/translate.py
@@ -28,8 +28,11 @@ def __init__(
         if os.path.isdir(model_name_or_path):
             model_path = model_name_or_path
         else:
-            model_path = download_model(model_name_or_path, hub_kwargs=hub_kwargs)
-
+            try:
+                model_path = download_model(model_name_or_path, hub_kwargs=hub_kwargs)
+            except Exception:  # Hub unreachable: retry from the local cache
+                hub_kwargs["local_files_only"] = True
+                model_path = download_model(model_name_or_path, hub_kwargs=hub_kwargs)
         self.model = self.ctranslate_class(
             model_path,
             device=device,
@@ -44,17 +47,22 @@ def __init__(
 
     def _forward(self, *args: Any, **kwds: Any) -> Any:
         raise NotImplementedError
+
+    def tokenize_encode(self, text, *args, **kwargs):
+        return [
+            self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(p)) for p in text
+        ]
+    def tokenize_decode(self, tokens_out, *args, **kwargs):
+        raise NotImplementedError
 
-    def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: Any):
+    def generate(self, text: Union[str, List[str]], encode_kwargs={}, *forward_args, **forward_kwds: Any):
         orig_type = list
         if isinstance(text, str):
             orig_type = str
             text = [text]
-        token_list = [
-            self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(p)) for p in text
-        ]
-        texts_out = self._forward(token_list, *forward_args, **forward_kwds)
-
+        token_list = self.tokenize_encode(text, **encode_kwargs)
+        tokens_out = self._forward(token_list, *forward_args, **forward_kwds)
+        texts_out = self.tokenize_decode(tokens_out)
         if orig_type == str:
             return texts_out[0]
         else:
@@ -71,7 +79,7 @@ def __init__(
         tokenizer: Union[AutoTokenizer, None] = None,
         hub_kwargs={},
     ):
-        """for ctranslate2.Translator models
+        """for ctranslate2.Translator models, in particular m2m-100
 
         Args:
             model_name_or_path (str): _description_
@@ -92,7 +100,9 @@ def __init__(
         )
 
     def _forward(self, *args, **kwds):
-        tokens_out = self.model.translate_batch(*args, **kwds)
+        return self.model.translate_batch(*args, **kwds)
+
+    def tokenize_decode(self, tokens_out, *args):
         return [
             self.tokenizer.decode(
                 self.tokenizer.convert_tokens_to_ids(tokens_out[i].hypotheses[0])
@@ -140,6 +150,101 @@ def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: A
         """
         return super().generate(text, *forward_args, **forward_kwds)
 
+class MultiLingualTranslatorCT2fromHfHub(CTranslate2ModelfromHuggingfaceHub):
+    def __init__(
+        self,
+        model_name_or_path: str,
+        device: Literal["cpu", "cuda"] = "cuda",
+        device_index=0,
+        compute_type: Literal["int8_float16", "int8"] = "int8_float16",
+        tokenizer: Union[AutoTokenizer, None] = None,
+        hub_kwargs={},
+    ):
+        """for multilingual ctranslate2.Translator models such as m2m-100
+
+        Args:
+            model_name_or_path (str): Hub repo id or local path of the converted CTranslate2 model
+            device (Literal["cpu", "cuda"], optional): device to load the model on. Defaults to "cuda".
+            device_index (int, optional): index of the device to use. Defaults to 0.
+            compute_type (Literal["int8_float16", "int8"], optional): quantized compute type of the weights. Defaults to "int8_float16".
+            tokenizer (Union[AutoTokenizer, None], optional): tokenizer to use; loaded from the Hub if None. Defaults to None.
+            hub_kwargs (dict, optional): additional keyword arguments for the Hub download. Defaults to {}.
+        """
+        self.ctranslate_class = ctranslate2.Translator
+        super().__init__(
+            model_name_or_path,
+            device,
+            device_index,
+            compute_type,
+            tokenizer,
+            hub_kwargs,
+        )
+
+    def _forward(self, *args, **kwds):
+        target_prefix = [[self.tokenizer.lang_code_to_token[l]] for l in kwds.pop("tgt_lang")]
+        # e.g. target_prefix=[['__de__'], ['__fr__']]
+        return self.model.translate_batch(*args, **kwds, target_prefix=target_prefix)
+
+    def tokenize_encode(self, text, *args, **kwargs):
+        tokens = []
+        src_lang = kwargs.pop("src_lang")
+        for t, src_language in zip(text, src_lang):
+            self.tokenizer.src_lang = src_language
+            tokens.append(self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(t)))
+        return tokens
+
+    def tokenize_decode(self, tokens_out, *args):
+        return [
+            self.tokenizer.decode(
+                self.tokenizer.convert_tokens_to_ids(tokens_out[i].hypotheses[0][1:])
+            )
+            for i in range(len(tokens_out))
+        ]
+
+    def generate(self, text: Union[str, List[str]], src_lang: Union[str, List[str]], tgt_lang: Union[str, List[str]], *forward_args, **forward_kwds: Any):
+        """Translate each input text from its src_lang into its tgt_lang.
+
+        Args:
+            text (Union[str, List[str]]): Input texts
+            src_lang (Union[str, List[str]]): source language of the input texts
+            tgt_lang (Union[str, List[str]]): target language for the outputs
+            max_batch_size (int, optional): Batch size. Defaults to 0.
+            batch_type (str, optional): _. Defaults to "examples".
+            asynchronous (bool, optional): Only False supported. Defaults to False.
+            beam_size (int, optional): _. Defaults to 2.
+            patience (float, optional): _. Defaults to 1.
+            num_hypotheses (int, optional): _. Defaults to 1.
+            length_penalty (float, optional): _. Defaults to 1.
+            coverage_penalty (float, optional): _. Defaults to 0.
+            repetition_penalty (float, optional): _. Defaults to 1.
+            no_repeat_ngram_size (int, optional): _. Defaults to 0.
+            disable_unk (bool, optional): _. Defaults to False.
+            suppress_sequences (Optional[List[List[str]]], optional): _.
+                Defaults to None.
+            end_token (Optional[Union[str, List[str], List[int]]], optional): _.
+                Defaults to None.
+            return_end_token (bool, optional): _. Defaults to False.
+            prefix_bias_beta (float, optional): _. Defaults to 0.
+            max_input_length (int, optional): _. Defaults to 1024.
+            max_decoding_length (int, optional): _. Defaults to 256.
+            min_decoding_length (int, optional): _. Defaults to 1.
+            use_vmap (bool, optional): _. Defaults to False.
+            return_scores (bool, optional): _. Defaults to False.
+            return_attention (bool, optional): _. Defaults to False.
+            return_alternatives (bool, optional): _. Defaults to False.
+            min_alternative_expansion_prob (float, optional): _. Defaults to 0.
+            sampling_topk (int, optional): _. Defaults to 1.
+            sampling_temperature (float, optional): _. Defaults to 1.
+            replace_unknowns (bool, optional): _. Defaults to False.
+            callback (_type_, optional): _. Defaults to None.
+
+        Returns:
+            Union[str, List[str]]: text as output, if list, same len as input
+        """
+        if not len(text) == len(src_lang) == len(tgt_lang):
+            raise ValueError(f"unequal len: text={len(text)} src_lang={len(src_lang)} tgt_lang={len(tgt_lang)}")
+        forward_kwds["tgt_lang"] = tgt_lang
+        return super().generate(text, *forward_args, **forward_kwds, encode_kwargs={"src_lang": src_lang})
 
 class GeneratorCT2fromHfHub(CTranslate2ModelfromHuggingfaceHub):
     def __init__(
@@ -172,12 +277,15 @@ def __init__(
         )
 
     def _forward(self, *args, **kwds):
-        tokens_out = self.model.generate_batch(*args, **kwds)
+        return self.model.generate_batch(*args, **kwds)
+
+    def tokenize_decode(self, tokens_out, *args):
         return [
             self.tokenizer.decode(tokens_out[i].sequences_ids[0])
             for i in range(len(tokens_out))
         ]
 
+
     def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: Any):
         """_summary_
 
@@ -211,3 +319,4 @@ def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: A
             str | List[str]: text as output, if list, same len as input
         """
         return super().generate(text, *forward_args, **forward_kwds)
+
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 75b535e..f11d408 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -7,4 +7,5 @@ coverage~=6.2
 flake8
 pre-commit
 myst-parser[sphinx]
-mkdocs
\ No newline at end of file
+mkdocs
+sentencepiece
\ No newline at end of file
diff --git a/tests/test_translate.py b/tests/test_translate.py
index 42181d8..048f183 100644
--- a/tests/test_translate.py
+++ b/tests/test_translate.py
@@ -1,4 +1,4 @@
-from hf_hub_ctranslate2 import TranslatorCT2fromHfHub, GeneratorCT2fromHfHub
+from hf_hub_ctranslate2 import TranslatorCT2fromHfHub, GeneratorCT2fromHfHub, MultiLingualTranslatorCT2fromHfHub
 from hf_hub_ctranslate2._private.utils import download_model
 from transformers import AutoTokenizer
 
@@ -17,6 +17,23 @@ def test_translator(model_name="michaelfeil/ct2fast-flan-alpaca-base"):
     for o in outputs:
         assert isinstance(o, str)
 
+def test_multilingualtranslator(model_name="michaelfeil/ct2fast-m2m100_418M"):
+    model = MultiLingualTranslatorCT2fromHfHub(
+        model_name_or_path=model_name, device="cpu", compute_type="int8",
+        tokenizer=AutoTokenizer.from_pretrained(f"facebook/{model_name.split('-')[-1]}")
+    )
+
+    outputs = model.generate(
+        ["How do you call a fast Flamingo?", "Wie geht es dir?"],
+        src_lang=["en", "de"],
+        tgt_lang=["de", "fr"]
+    )
+    assert len(outputs) == 2
+    assert len(outputs[0]) != len(outputs[1])
+    assert "nennt" in outputs[0].lower()
+    assert "comment" in outputs[1].lower()
+    for o in outputs:
+        assert isinstance(o, str)
 
 def test_generator(model_name="michaelfeil/ct2fast-pythia-160m"):
     model = GeneratorCT2fromHfHub(
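
A minimal usage sketch of the new MultiLingualTranslatorCT2fromHfHub API, mirroring test_multilingualtranslator above. The model and tokenizer ids are the ones used in the test; device="cpu" and compute_type="int8" are just one possible configuration.

from transformers import AutoTokenizer
from hf_hub_ctranslate2 import MultiLingualTranslatorCT2fromHfHub

# Converted CTranslate2 weights on the Hub plus the matching original tokenizer.
model = MultiLingualTranslatorCT2fromHfHub(
    model_name_or_path="michaelfeil/ct2fast-m2m100_418M",
    device="cpu",
    compute_type="int8",
    tokenizer=AutoTokenizer.from_pretrained("facebook/m2m100_418M"),
)

# One src_lang/tgt_lang entry per input sentence. Internally, tokenize_encode
# sets tokenizer.src_lang per sentence, and _forward maps each tgt_lang code
# to a decoder target_prefix (e.g. "de" -> "__de__").
outputs = model.generate(
    text=["How do you call a fast Flamingo?", "Wie geht es dir?"],
    src_lang=["en", "de"],
    tgt_lang=["de", "fr"],
)
# Per the test assertions: outputs[0] is German (contains "nennt"),
# outputs[1] is French (contains "comment").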