Skip to content

Commit

Permalink
fix: optimum classifier
Browse files (browse the repository at this point in the history)
  • Loading branch information
michaelfeil committed Dec 10, 2024
1 parent 932b2bc commit 103cb6f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, *, engine_args: EngineArgs):
prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()),
)

self.model = optimize_model(
model = optimize_model(
model_name_or_path=engine_args.model_name_or_path,
model_class=ORTModelForSequenceClassification,
revision=engine_args.revision,
Expand All @@ -48,7 +48,7 @@ def __init__(self, *, engine_args: EngineArgs):
file_name=onnx_file.as_posix(),
optimize_model=not os.environ.get("INFINITY_ONNX_DISABLE_OPTIMIZE", False),
)
self.model.use_io_binding = False
model.use_io_binding = False

self.tokenizer = AutoTokenizer.from_pretrained(
engine_args.model_name_or_path,
Expand All @@ -60,12 +60,11 @@ def __init__(self, *, engine_args: EngineArgs):

self._pipe = pipeline(
task="text-classification",
model=self.model,
model=model,
trust_remote_code=engine_args.trust_remote_code,
top_k=None,
revision=engine_args.revision,
tokenizer=self.tokenizer,
device=engine_args.device,
)

def encode_pre(self, sentences: list[str]):
Expand Down
Original file line number | Diff line number | Diff line change
@@ -1,23 +1,21 @@
import torch
from optimum.pipelines import pipeline # type: ignore
from transformers.pipelines import pipeline # type: ignore
from optimum.onnxruntime import ORTModelForSequenceClassification
from infinity_emb.args import EngineArgs

from infinity_emb.transformer.classifier.optimum import OptimumClassifier


def test_classifier(model_name: str = "SamLowe/roberta-base-go_emotions-onnx"):
model = OptimumClassifier(
engine_args=EngineArgs(
model_name_or_path=model_name,
device="cuda" if torch.cuda.is_available() else "cpu",
) # type: ignore
)

pipe = pipeline(
task="text-classification",
model=ORTModelForSequenceClassification.from_pretrained(
model_name, file_name="onnx/model_quantized.onnx"
),
model="SamLowe/roberta-base-go_emotions", # hoping that this is the same model as model_name
top_k=None,
)

Expand Down

0 comments on commit 103cb6f

Please sign in to comment.