From 7374fd86935ba501ad74ece53645714af5e9c0e5 Mon Sep 17 00:00:00 2001 From: Michael Feil Date: Fri, 19 May 2023 01:04:38 +0200 Subject: [PATCH] bump version 2.0.2 --- hf_hub_ctranslate2/__init__.py | 6 +++--- hf_hub_ctranslate2/translate.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/hf_hub_ctranslate2/__init__.py b/hf_hub_ctranslate2/__init__.py index d7422a2..4164126 100644 --- a/hf_hub_ctranslate2/__init__.py +++ b/hf_hub_ctranslate2/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """Compatability between Huggingface and Ctranslate2.""" -__all__ = ["__version__", "TranslatorCT2fromHfHub", "GeneratorCT2fromHfHub", "MultiLingualTranslatorCT2fromHfHub"] +__all__ = ["__version__", "TranslatorCT2fromHfHub", "GeneratorCT2fromHfHub", "MultiLingualTranslatorCT2fromHfHub", "_private"] from hf_hub_ctranslate2.translate import TranslatorCT2fromHfHub, GeneratorCT2fromHfHub, MultiLingualTranslatorCT2fromHfHub - -__version__ = "2.0.1" \ No newline at end of file +import hf_hub_ctranslate2._private as _private +__version__ = "2.0.2" \ No newline at end of file diff --git a/hf_hub_ctranslate2/translate.py b/hf_hub_ctranslate2/translate.py index 09f5de1..09b9a14 100644 --- a/hf_hub_ctranslate2/translate.py +++ b/hf_hub_ctranslate2/translate.py @@ -9,7 +9,7 @@ from typing import Any, Union, List import os -import hf_hub_ctranslate2._private.utils as utils +import hf_hub_ctranslate2._private.utils as _utils class CTranslate2ModelfromHuggingfaceHub: @@ -29,10 +29,10 @@ def __init__( model_path = model_name_or_path else: try: - model_path = utils.download_model(model_name_or_path, hub_kwargs=hub_kwargs) + model_path = _utils.download_model(model_name_or_path, hub_kwargs=hub_kwargs) except: hub_kwargs["local_files_only"] = True - model_path = utils.download_model(model_name_or_path, hub_kwargs=hub_kwargs) + model_path = _utils.download_model(model_name_or_path, hub_kwargs=hub_kwargs) self.model = self.ctranslate_class( model_path, device=device, @@ -40,10 +40,14 @@ def __init__( compute_type=compute_type, ) - if tokenizer is None: - self.tokenizer = AutoTokenizer.from_pretrained(model_path, fast=True) - else: + if tokenizer is not None: self.tokenizer = tokenizer + else: + if "tokenizer.json" in os.listdir(model_path): + self.tokenizer = AutoTokenizer.from_pretrained(model_path, fast=True) + if "tokenizer.json" in os.listdir(model_path): + self.tokenizer = AutoTokenizer.from_pretrained(model_path, fast=True) + def _forward(self, *args: Any, **kwds: Any) -> Any: raise NotImplementedError