Skip to content

Commit

Permalink
bump version 2.0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil committed May 18, 2023
1 parent c36d715 commit 7374fd8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
6 changes: 3 additions & 3 deletions hf_hub_ctranslate2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""Compatability between Huggingface and Ctranslate2."""
__all__ = ["__version__", "TranslatorCT2fromHfHub", "GeneratorCT2fromHfHub", "MultiLingualTranslatorCT2fromHfHub"]
__all__ = ["__version__", "TranslatorCT2fromHfHub", "GeneratorCT2fromHfHub", "MultiLingualTranslatorCT2fromHfHub", "_private"]
from hf_hub_ctranslate2.translate import TranslatorCT2fromHfHub, GeneratorCT2fromHfHub, MultiLingualTranslatorCT2fromHfHub

__version__ = "2.0.1"
import hf_hub_ctranslate2._private as _private
__version__ = "2.0.2"
16 changes: 10 additions & 6 deletions hf_hub_ctranslate2/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any, Union, List
import os

import hf_hub_ctranslate2._private.utils as utils
import hf_hub_ctranslate2._private.utils as _utils


class CTranslate2ModelfromHuggingfaceHub:
Expand All @@ -29,21 +29,25 @@ def __init__(
model_path = model_name_or_path
else:
try:
model_path = utils.download_model(model_name_or_path, hub_kwargs=hub_kwargs)
model_path = _utils.download_model(model_name_or_path, hub_kwargs=hub_kwargs)
except:
hub_kwargs["local_files_only"] = True
model_path = utils.download_model(model_name_or_path, hub_kwargs=hub_kwargs)
model_path = _utils.download_model(model_name_or_path, hub_kwargs=hub_kwargs)
self.model = self.ctranslate_class(
model_path,
device=device,
device_index=device_index,
compute_type=compute_type,
)

if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(model_path, fast=True)
else:
if tokenizer is not None:
self.tokenizer = tokenizer
else:
if "tokenizer.json" in os.listdir(model_path):
self.tokenizer = AutoTokenizer.from_pretrained(model_path, fast=True)
if "tokenizer.json" in os.listdir(model_path):
self.tokenizer = AutoTokenizer.from_pretrained(model_path, fast=True)


def _forward(self, *args: Any, **kwds: Any) -> Any:
raise NotImplementedError
Expand Down

0 comments on commit 7374fd8

Please sign in to comment.