Skip to content

Commit

Permalink
Merge pull request #354 from andrewwhitehead/fix/entry-cache
Browse files Browse the repository at this point in the history
Python wrapper entry cache adjustment
  • Loading branch information
swcurran authored Jan 29, 2025
2 parents f948424 + aa554a1 commit dba65c0
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 25 deletions.
10 changes: 9 additions & 1 deletion wrappers/python/aries_askar/bindings/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
c_void_p,
)

from .lib import ByteBuffer, Lib, StrBuffer, finalize_struct
from .lib import ByteBuffer, Lib, StrBuffer, entry_cache, finalize_struct


LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -124,6 +124,7 @@ class EntryListHandle(ArcHandle):

_dtor_ = "askar_entry_list_free"

@entry_cache
def get_category(self, index: int) -> str:
"""Get the entry category."""
cat = StrBuffer()
Expand All @@ -136,6 +137,7 @@ def get_category(self, index: int) -> str:
)
return str(cat)

@entry_cache
def get_name(self, index: int) -> str:
"""Get the entry name."""
name = StrBuffer()
Expand All @@ -148,6 +150,7 @@ def get_name(self, index: int) -> str:
)
return str(name)

@entry_cache
def get_value(self, index: int) -> memoryview:
"""Get the entry value."""
val = ByteBuffer()
Expand All @@ -160,6 +163,7 @@ def get_value(self, index: int) -> memoryview:
)
return val.view

@entry_cache
def get_tags(self, index: int) -> dict:
"""Get the entry tags."""
tags = StrBuffer()
Expand All @@ -185,6 +189,7 @@ class KeyEntryListHandle(ArcHandle):

_dtor_ = "askar_key_entry_list_free"

@entry_cache
def get_algorithm(self, index: int) -> str:
"""Get the key algorithm."""
name = StrBuffer()
Expand All @@ -197,6 +202,7 @@ def get_algorithm(self, index: int) -> str:
)
return str(name)

@entry_cache
def get_name(self, index: int) -> str:
"""Get the key name."""
name = StrBuffer()
Expand All @@ -209,6 +215,7 @@ def get_name(self, index: int) -> str:
)
return str(name)

@entry_cache
def get_metadata(self, index: int) -> str:
"""Get for the key metadata."""
metadata = StrBuffer()
Expand All @@ -221,6 +228,7 @@ def get_metadata(self, index: int) -> str:
)
return str(metadata)

@entry_cache
def get_tags(self, index: int) -> dict:
"""Get the key tags."""
tags = StrBuffer()
Expand Down
24 changes: 24 additions & 0 deletions wrappers/python/aries_askar/bindings/lib.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Library instance and allocated buffer handling."""

import asyncio
import functools
import itertools
import logging
import os
import sys
import threading
import time
from copy import deepcopy

try:
import orjson as json
Expand Down Expand Up @@ -48,6 +50,28 @@
}


def entry_cache(fn):
"""Cache results for properties of individual entries."""

@functools.wraps(fn)
def wrapper(self, index: int):
if not hasattr(self, "_ecache"):
setattr(self, "_ecache", {})
cache = self._ecache
ckey = (fn, index)
if ckey in cache:
res = cache[ckey]
else:
res = fn(self, index)
cache[ckey] = res
if isinstance(res, dict):
# make sure the cached copy is not mutated
res = deepcopy(res)
return res

return wrapper


def _convert_log_level(level: Union[str, int, None]):
if level is None or level == "-1":
return -1
Expand Down
13 changes: 1 addition & 12 deletions wrappers/python/aries_askar/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
except ImportError:
import json

from functools import lru_cache
from typing import Optional, Sequence, Union

from . import bindings
Expand All @@ -32,25 +31,21 @@ def __init__(self, lst: EntryListHandle, pos: int):
self._pos = pos

@property
@lru_cache(maxsize=None)
def category(self) -> str:
"""Accessor for the entry category."""
return self._list.get_category(self._pos)

@property
@lru_cache(maxsize=None)
def name(self) -> str:
"""Accessor for the entry name."""
return self._list.get_name(self._pos)

@property
@lru_cache(maxsize=None)
def value(self) -> bytes:
"""Accessor for the entry value."""
return bytes(self.raw_value)

@property
@lru_cache(maxsize=None)
def raw_value(self) -> memoryview:
"""Accessor for the entry raw value."""
return self._list.get_value(self._pos)
Expand All @@ -61,7 +56,6 @@ def value_json(self) -> dict:
return json.loads(self.value)

@property
@lru_cache(maxsize=None)
def tags(self) -> dict:
"""Accessor for the entry tags."""
return self._list.get_tags(self._pos)
Expand Down Expand Up @@ -152,31 +146,26 @@ def __init__(self, lst: KeyEntryListHandle, pos: int):
self._pos = pos

@property
@lru_cache(maxsize=None)
def algorithm(self) -> str:
"""Accessor for the key entry algorithm."""
return self._list.get_algorithm(self._pos)

@property
@lru_cache(maxsize=None)
def name(self) -> str:
"""Accessor for the key entry name."""
return self._list.get_name(self._pos)

@property
@lru_cache(maxsize=None)
def metadata(self) -> str:
"""Accessor for the key entry metadata."""
return self._list.get_metadata(self._pos)

@property
@lru_cache(maxsize=None)
def key(self) -> Key:
"""Accessor for the entry metadata."""
"""Accessor for the key instance."""
return Key(self._list.load_key(self._pos))

@property
@lru_cache(maxsize=None)
def tags(self) -> dict:
"""Accessor for the entry tags."""
return self._list.get_tags(self._pos)
Expand Down
3 changes: 2 additions & 1 deletion wrappers/python/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ per_file_ignores = */__init__.py:D104
minversion = 5.0
testpaths =
tests
asyncio_mode=strict
asyncio_mode=auto
asyncio_default_fixture_loop_scope=session
65 changes: 54 additions & 11 deletions wrappers/python/tests/test_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import gc
import os
from weakref import WeakKeyDictionary

from pytest import mark, raises
import pytest_asyncio
Expand All @@ -10,6 +12,7 @@
Key,
Store,
)
from aries_askar.bindings.lib import entry_cache


TEST_STORE_URI = os.getenv("TEST_STORE_URI", "sqlite://:memory:")
Expand All @@ -26,15 +29,13 @@ def raw_key() -> str:


@pytest_asyncio.fixture
@mark.asyncio
async def store() -> Store:
key = raw_key()
store = await Store.provision(TEST_STORE_URI, "raw", key, recreate=True)
yield store
await store.close(remove=True)


@mark.asyncio
async def test_insert_update(store: Store):
async with store as session:
# Insert a new entry
Expand Down Expand Up @@ -81,7 +82,6 @@ async def test_insert_update(store: Store):
assert found is None


@mark.asyncio
async def test_remove_all(store: Store):
async with store as session:
# Insert a new entry
Expand All @@ -104,7 +104,6 @@ async def test_remove_all(store: Store):
assert found is None


@mark.asyncio
async def test_scan(store: Store):
async with store as session:
await session.insert(
Expand Down Expand Up @@ -133,7 +132,6 @@ async def test_scan(store: Store):
assert len(rows) == 1 and dict(rows[0]) == TEST_ENTRY


@mark.asyncio
async def test_txn_basic(store: Store):
async with store.transaction() as txn:
# Insert a new entry
Expand Down Expand Up @@ -167,7 +165,6 @@ async def test_txn_basic(store: Store):
assert dict(found) == TEST_ENTRY


@mark.asyncio
async def test_txn_autocommit(store: Store):
with raises(Exception):
async with store.transaction(autocommit=True) as txn:
Expand Down Expand Up @@ -203,7 +200,6 @@ async def test_txn_autocommit(store: Store):
assert dict(found) == TEST_ENTRY


@mark.asyncio
async def test_txn_contention(store: Store):
async with store.transaction() as txn:
await txn.insert(
Expand Down Expand Up @@ -240,7 +236,6 @@ async def inc():
assert int(result.value) == INC_COUNT * TASKS


@mark.asyncio
async def test_key_store_ed25519(store: Store):
# test key operations in a new session
async with store as session:
Expand Down Expand Up @@ -279,7 +274,6 @@ async def test_key_store_ed25519(store: Store):
assert await session.fetch_key(key_name) is None


@mark.asyncio
@mark.parametrize(
"key_alg",
[KeyAlg.A128CBC_HS256, KeyAlg.XC20P],
Expand Down Expand Up @@ -318,7 +312,6 @@ async def test_key_store_symmetric(store: Store, key_alg: KeyAlg):
assert await session.fetch_key(key_name) is None


@mark.asyncio
async def test_profile(store: Store):
# New session in the default profile
async with store as session:
Expand Down Expand Up @@ -409,7 +402,6 @@ async def test_profile(store: Store):
assert (await store.get_default_profile()) == profile


@mark.asyncio
async def test_copy(store: Store):
async with store as session:
# Insert a new entry
Expand All @@ -429,3 +421,54 @@ async def test_copy(store: Store):
entries = await session.fetch_all(TEST_ENTRY["category"])
assert len(entries) == 1
assert entries[0].name == TEST_ENTRY["name"]


def test_entry_cache():
instances = WeakKeyDictionary()

class MockList:
def __init__(self, name: str, value: dict):
self._name = name
self._value = value
self._calls = []
instances[self] = True

@entry_cache
def get_name(self, index: int) -> str:
self._calls.append(index)
return self._name + str(index)

@entry_cache
def get_value(self, index: int) -> dict:
self._calls.append(index)
return self._value

NAME = "testname"
VALUE = {"a": "b"}
lst = MockList(NAME, VALUE)
# check instance is registered
assert instances

# check first call goes to method
assert lst.get_name(99) == NAME + "99"
assert lst._calls == [99]
assert lst.get_name(45) == NAME + "45"
assert lst._calls == [99, 45]
val = lst.get_value(11)
assert val == VALUE
assert lst._calls == [99, 45, 11]

# check dict value is copied
val["a"] = "c"
assert val != VALUE

# check second call goes to cache
assert lst.get_name(99) == NAME + "99"
assert lst.get_name(45) == NAME + "45"
assert lst.get_value(11) == VALUE
assert lst._calls == [99, 45, 11]

# ensure no extra references are keeping the instance around
del lst
gc.collect()
assert not instances

0 comments on commit dba65c0

Please sign in to comment.