diff --git a/py_src/yolov4/tflite/__init__.py b/py_src/yolov4/tflite/__init__.py index c917d3da..96af2464 100644 --- a/py_src/yolov4/tflite/__init__.py +++ b/py_src/yolov4/tflite/__init__.py @@ -21,6 +21,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import platform + import numpy as np try: @@ -33,6 +35,12 @@ from ..common.base_class import BaseClass +EDGETPU_SHARED_LIB = { + "Linux": "libedgetpu.so.1", + "Darwin": "libedgetpu.1.dylib", + "Windows": "edgetpu.dll", +}[platform.system()] + class YOLOv4(BaseClass): def __init__(self, tiny: bool = False, tpu: bool = False): @@ -46,11 +54,13 @@ def __init__(self, tiny: bool = False, tpu: bool = False): self.output_index = None self.output_size = None - def load_tflite(self, tflite_path: str) -> None: + def load_tflite( + self, tflite_path: str, edgetpu_lib: str = EDGETPU_SHARED_LIB + ) -> None: if self.tpu: self.interpreter = tflite.Interpreter( model_path=tflite_path, - experimental_delegates=[load_delegate("libedgetpu.so.1")], + experimental_delegates=[load_delegate(edgetpu_lib)], ) else: self.interpreter = tflite.Interpreter(model_path=tflite_path)