From 24304f01f548367dd75262ca581a43ac88d78cce Mon Sep 17 00:00:00 2001
From: Daniel Ramos <46768340+danielcamposramos@users.noreply.github.com>
Date: Sun, 26 Jan 2025 22:00:44 -0300
Subject: [PATCH] Update fasterwhisper.py

Changed and added some parameters and corrected others.
---
 transcription-api/backends/fasterwhisper.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/transcription-api/backends/fasterwhisper.py b/transcription-api/backends/fasterwhisper.py
index f6db2c4..c32181d 100644
--- a/transcription-api/backends/fasterwhisper.py
+++ b/transcription-api/backends/fasterwhisper.py
@@ -6,11 +6,11 @@ from faster_whisper import WhisperModel, download_model, decode_audio
 
 
 class FasterWhisperBackend(Backend):
-    device: str = "cpu" # cpu, cuda
-    quantization: str = "int8" # int8, float16
+    device: str = "cpu"  # cpu, cuda (cuda currently not working)
+    quantization: str = "default"  # default, int8, float16, float32
     model: WhisperModel | None = None
 
-    def __init__(self, model_size, device: str = "cpu"):
+    def __init__(self, model_size, device: str | None = None):
         self.model_size = model_size
-        self.device = device
+        self.device = device or self.device  # fall back to the class default when None
         self.__post_init__()
@@ -26,8 +26,8 @@ def model_path(self) -> str:
         raise RuntimeError(f"model not found in {local_model_path}")
 
     def load(self) -> None:
-        # Get CPU threads env variable or default to 4
-        cpu_threads = int(os.environ.get("CPU_THREADS", 4))
+        # Get CPU threads env variable or default to 6
+        cpu_threads = int(os.environ.get("CPU_THREADS", 6))
         self.model = WhisperModel(
             self.model_path(), device=self.device, compute_type=self.quantization, cpu_threads=cpu_threads
         )
@@ -40,10 +40,13 @@ def get_model(self) -> None:
         if not os.path.exists(local_model_path):
             os.makedirs(local_model_path)
         try:
+            print("Model check start...")
             download_model(self.model_size, output_dir=local_model_path, local_files_only=True, cache_dir=local_model_cache)
-            print("Model already cached...")
+            print("Model check ended; model already cached.")
         except:
+            print("Model download start...")
             download_model(self.model_size, output_dir=local_model_path, local_files_only=False, cache_dir=local_model_cache)
+            print("Model download ended.")
 
     def transcribe(
         self, input: np.ndarray, silent: bool = False, language: str = None
@@ -96,4 +99,4 @@ def transcribe(
             "duration": info.duration,
             "segments": result,
         }
-        return transcription
\ No newline at end of file
+        return transcription
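
A minimal usage sketch of the backend after this patch. It is an assumption-laden illustration, not part of the change: the import path backends.fasterwhisper, the model size "base", and the 16 kHz mono float32 input format are hypothetical choices here; only the CPU_THREADS default, the "cpu" device fallback, the "default" compute type, and the "duration"/"segments" result keys come from the hunks above.

    import os
    import numpy as np
    from backends.fasterwhisper import FasterWhisperBackend  # assumed import path

    os.environ.setdefault("CPU_THREADS", "6")  # matches the patched default

    backend = FasterWhisperBackend("base")  # no device given: falls back to "cpu"
    backend.get_model()   # checks the local cache first, downloads on a miss
    backend.load()        # compute_type="default" lets CTranslate2 pick the type

    audio = np.zeros(16000, dtype=np.float32)  # one second of 16 kHz silence
    result = backend.transcribe(audio, language="en")
    print(result["duration"], len(result["segments"]))

With compute_type="default", quantization is deferred to CTranslate2's choice for the current device, which is why the class comment now lists default alongside int8, float16, and float32.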