-
Notifications
You must be signed in to change notification settings - Fork 862
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add sampler custom logits processor (#2396)
Signed-off-by: Hongpeng Guo <[email protected]>
- Loading branch information
1 parent
3bcf5ec
commit e403d23
Showing
12 changed files
with
302 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import json | ||
from abc import ABC, abstractmethod | ||
from functools import lru_cache | ||
from typing import Any, Dict, List, Optional | ||
|
||
import dill | ||
import torch | ||
|
||
|
||
@lru_cache(maxsize=None) | ||
def _cache_from_str(json_str: str): | ||
"""Deserialize a json string to a Callable object. | ||
This function is cached to avoid redundant deserialization. | ||
""" | ||
data = json.loads(json_str) | ||
return dill.loads(bytes.fromhex(data["callable"])) | ||
|
||
|
||
class CustomLogitProcessor(ABC): | ||
"""Abstract base class for callable functions.""" | ||
|
||
@abstractmethod | ||
def __call__( | ||
self, | ||
logits: torch.Tensor, | ||
custom_param_list: Optional[List[Dict[str, Any]]] = None, | ||
) -> torch.Tensor: | ||
"""Define the callable behavior.""" | ||
raise NotImplementedError | ||
|
||
def to_str(self) -> str: | ||
"""Serialize the callable function to a JSON-compatible string.""" | ||
return json.dumps({"callable": dill.dumps(self).hex()}) | ||
|
||
@classmethod | ||
def from_str(cls, json_str: str): | ||
"""Deserialize a callable function from a JSON string.""" | ||
return _cache_from_str(json_str) |
Oops, something went wrong.