Skip to content

Commit

Permalink
add clip: max model length
Browse files: browse the repository at this point in the history
  • Loading branch information
michaelfeil committed Jun 10, 2024
1 parent af98e71 commit 0977109
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 206 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, *, engine_args: EngineArgs):
self.model = AutoModel.from_pretrained(
engine_args.model_name_or_path,
revision=engine_args.revision,
trust_remote_code=engine_args.trust_remote_code,
)
if torch.cuda.is_available():
self.model = self.model.cuda()
Expand All @@ -36,6 +37,7 @@ def __init__(self, *, engine_args: EngineArgs):
self.processor = AutoProcessor.from_pretrained(
engine_args.model_name_or_path,
revision=engine_args.revision,
trust_remote_code=engine_args.trust_remote_code,
)
self.engine_args = engine_args

Expand All @@ -51,6 +53,15 @@ def __init__(self, *, engine_args: EngineArgs):
assert hasattr(
self.model, "get_image_features"
), f"AutoModel of {engine_args.model_name_or_path} does not have get_image_features method"
self.max_length = None
if hasattr(self.model.config, "max_length"):
self.max_length = self.model.config.max_length
elif hasattr(self.model.config, "max_position_embeddings"):
self.max_length = self.model.config.max_position_embeddings
elif hasattr(self.model.config, "text_config") and hasattr(
self.model.config.text_config, "max_length"
):
self.max_length = self.model.config.text_config.max_length

def encode_pre(self, sentences_or_images: list[Union[str, "ImageClass"]]):
# return input_tuples
Expand All @@ -72,6 +83,7 @@ def encode_pre(self, sentences_or_images: list[Union[str, "ImageClass"]]):
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length,
)
preprocessed = {k: v.to(self.model.device) for k, v in preprocessed.items()}

Expand Down Expand Up @@ -118,5 +130,7 @@ def encode_post(self, out_features) -> list[float]:
return embeddings

def tokenize_lengths(self, text_list: list[str]) -> list[int]:
    """Return the number of token ids produced for each text in *text_list*.

    The texts are passed to ``self.processor`` exactly once, with
    truncation enabled and ``max_length=self.max_length``, so the reported
    lengths are computed with the same truncation settings as this call.

    Args:
        text_list: The texts to tokenize.

    Returns:
        One integer per input text: the length of its ``input_ids`` entry.
    """
    # A single processor call is enough; the earlier call without
    # ``max_length`` was immediately overwritten and only wasted work.
    preprocessed = self.processor(
        text=text_list, truncation=True, max_length=self.max_length
    )
    return [len(ids) for ids in preprocessed["input_ids"]]
Loading

0 comments on commit 0977109

Please sign in to comment.