Skip to content

Commit

Permalink
Make it possible to pass additional arguments when retrieving and imp…
Browse files Browse the repository at this point in the history
…licitly creating counters.

PiperOrigin-RevId: 601481160
  • Loading branch information
mjanusz authored and copybara-github committed Jan 25, 2024
1 parent c902001 commit 9fe3b0a
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions ffn/inference/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def next(self):
class Counters:
"""Container for counters."""

def __init__(self, parent=None):
def __init__(self, parent: Counters | None):
self._lock = threading.Lock() # for self._counters
self.reset()
self.parent = parent
Expand All @@ -146,16 +146,31 @@ def reset(self):
self._counters = {}
self._last_update = 0

def __getitem__(self, name):
def __getitem__(self, name: str) -> StatCounter:
return self.get(name)

def get(self, name: str, **kwargs) -> StatCounter:
"""Retrieves the counter associated with the provided name.
If the counter does not exist, it will be created.
Args:
name: counter name
**kwargs: forwarded to _make_counter
Returns:
a counter corresponding to the specified name
"""
with self._lock:
if name not in self._counters:
self._counters[name] = self._make_counter(name)
self._counters[name] = self._make_counter(name, **kwargs)
return self._counters[name]

def __iter__(self):
return iter(self._counters.items())

def _make_counter(self, name):
def _make_counter(self, name: str, **kwargs) -> StatCounter:
del kwargs
return StatCounter(self.update_status, name)

def update_status(self):
Expand All @@ -164,17 +179,17 @@ def update_status(self):
def get_sub_counters(self):
return Counters(self)

def dump(self, filename):
def dump(self, filename: str):
with storage.atomic_file(filename, 'w') as fd:
for name, counter in sorted(self._counters.items()):
fd.write('%s: %d\n' % (name, counter.value))

def dumps(self):
def dumps(self) -> str:
state = {name: counter.value for name, counter in
self._counters.items()}
return json.dumps(state)

def loads(self, encoded_state):
def loads(self, encoded_state: str):
state = json.loads(encoded_state)
for name, value in state.items():
# Do not set the exported counters. Otherwise after computing
Expand Down

0 comments on commit 9fe3b0a

Please sign in to comment.