Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python wrapper entry cache adjustment #354

Merged
merged 3 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion wrappers/python/aries_askar/bindings/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
c_void_p,
)

from .lib import ByteBuffer, Lib, StrBuffer, finalize_struct
from .lib import ByteBuffer, Lib, StrBuffer, entry_cache, finalize_struct


LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -124,6 +124,7 @@ class EntryListHandle(ArcHandle):

_dtor_ = "askar_entry_list_free"

@entry_cache
def get_category(self, index: int) -> str:
"""Get the entry category."""
cat = StrBuffer()
Expand All @@ -136,6 +137,7 @@ def get_category(self, index: int) -> str:
)
return str(cat)

@entry_cache
def get_name(self, index: int) -> str:
"""Get the entry name."""
name = StrBuffer()
Expand All @@ -148,6 +150,7 @@ def get_name(self, index: int) -> str:
)
return str(name)

@entry_cache
def get_value(self, index: int) -> memoryview:
"""Get the entry value."""
val = ByteBuffer()
Expand All @@ -160,6 +163,7 @@ def get_value(self, index: int) -> memoryview:
)
return val.view

@entry_cache
def get_tags(self, index: int) -> dict:
"""Get the entry tags."""
tags = StrBuffer()
Expand All @@ -185,6 +189,7 @@ class KeyEntryListHandle(ArcHandle):

_dtor_ = "askar_key_entry_list_free"

@entry_cache
def get_algorithm(self, index: int) -> str:
"""Get the key algorithm."""
name = StrBuffer()
Expand All @@ -197,6 +202,7 @@ def get_algorithm(self, index: int) -> str:
)
return str(name)

@entry_cache
def get_name(self, index: int) -> str:
"""Get the key name."""
name = StrBuffer()
Expand All @@ -209,6 +215,7 @@ def get_name(self, index: int) -> str:
)
return str(name)

@entry_cache
def get_metadata(self, index: int) -> str:
"""Get for the key metadata."""
metadata = StrBuffer()
Expand All @@ -221,6 +228,7 @@ def get_metadata(self, index: int) -> str:
)
return str(metadata)

@entry_cache
def get_tags(self, index: int) -> dict:
"""Get the key tags."""
tags = StrBuffer()
Expand Down
24 changes: 24 additions & 0 deletions wrappers/python/aries_askar/bindings/lib.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Library instance and allocated buffer handling."""

import asyncio
import functools
import itertools
import logging
import os
import sys
import threading
import time
from copy import deepcopy

try:
import orjson as json
Expand Down Expand Up @@ -48,6 +50,28 @@
}


def entry_cache(fn):
"""Cache results for properties of individual entries."""

@functools.wraps(fn)
def wrapper(self, index: int):
if not hasattr(self, "_ecache"):
setattr(self, "_ecache", {})
cache = self._ecache
ckey = (fn, index)
if ckey in cache:
res = cache[ckey]
else:
res = fn(self, index)
cache[ckey] = res
if isinstance(res, dict):
# make sure the cached copy is not mutated
res = deepcopy(res)
return res

return wrapper


def _convert_log_level(level: Union[str, int, None]):
if level is None or level == "-1":
return -1
Expand Down
13 changes: 1 addition & 12 deletions wrappers/python/aries_askar/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
except ImportError:
import json

from functools import lru_cache
from typing import Optional, Sequence, Union

from . import bindings
Expand All @@ -32,25 +31,21 @@ def __init__(self, lst: EntryListHandle, pos: int):
self._pos = pos

@property
@lru_cache(maxsize=None)
def category(self) -> str:
"""Accessor for the entry category."""
return self._list.get_category(self._pos)

@property
@lru_cache(maxsize=None)
def name(self) -> str:
"""Accessor for the entry name."""
return self._list.get_name(self._pos)

@property
@lru_cache(maxsize=None)
def value(self) -> bytes:
"""Accessor for the entry value."""
return bytes(self.raw_value)

@property
@lru_cache(maxsize=None)
def raw_value(self) -> memoryview:
"""Accessor for the entry raw value."""
return self._list.get_value(self._pos)
Expand All @@ -61,7 +56,6 @@ def value_json(self) -> dict:
return json.loads(self.value)

@property
@lru_cache(maxsize=None)
def tags(self) -> dict:
"""Accessor for the entry tags."""
return self._list.get_tags(self._pos)
Expand Down Expand Up @@ -152,31 +146,26 @@ def __init__(self, lst: KeyEntryListHandle, pos: int):
self._pos = pos

@property
@lru_cache(maxsize=None)
def algorithm(self) -> str:
"""Accessor for the key entry algorithm."""
return self._list.get_algorithm(self._pos)

@property
@lru_cache(maxsize=None)
def name(self) -> str:
"""Accessor for the key entry name."""
return self._list.get_name(self._pos)

@property
@lru_cache(maxsize=None)
def metadata(self) -> str:
"""Accessor for the key entry metadata."""
return self._list.get_metadata(self._pos)

@property
@lru_cache(maxsize=None)
def key(self) -> Key:
"""Accessor for the entry metadata."""
"""Accessor for the key instance."""
return Key(self._list.load_key(self._pos))

@property
@lru_cache(maxsize=None)
def tags(self) -> dict:
"""Accessor for the entry tags."""
return self._list.get_tags(self._pos)
Expand Down
3 changes: 2 additions & 1 deletion wrappers/python/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ per_file_ignores = */__init__.py:D104
minversion = 5.0
testpaths =
tests
asyncio_mode=strict
asyncio_mode=auto
asyncio_default_fixture_loop_scope=session
65 changes: 54 additions & 11 deletions wrappers/python/tests/test_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import gc
import os
from weakref import WeakKeyDictionary

from pytest import mark, raises
import pytest_asyncio
Expand All @@ -10,6 +12,7 @@
Key,
Store,
)
from aries_askar.bindings.lib import entry_cache


TEST_STORE_URI = os.getenv("TEST_STORE_URI", "sqlite://:memory:")
Expand All @@ -26,15 +29,13 @@ def raw_key() -> str:


@pytest_asyncio.fixture
@mark.asyncio
async def store() -> Store:
key = raw_key()
store = await Store.provision(TEST_STORE_URI, "raw", key, recreate=True)
yield store
await store.close(remove=True)


@mark.asyncio
async def test_insert_update(store: Store):
async with store as session:
# Insert a new entry
Expand Down Expand Up @@ -81,7 +82,6 @@ async def test_insert_update(store: Store):
assert found is None


@mark.asyncio
async def test_remove_all(store: Store):
async with store as session:
# Insert a new entry
Expand All @@ -104,7 +104,6 @@ async def test_remove_all(store: Store):
assert found is None


@mark.asyncio
async def test_scan(store: Store):
async with store as session:
await session.insert(
Expand Down Expand Up @@ -133,7 +132,6 @@ async def test_scan(store: Store):
assert len(rows) == 1 and dict(rows[0]) == TEST_ENTRY


@mark.asyncio
async def test_txn_basic(store: Store):
async with store.transaction() as txn:
# Insert a new entry
Expand Down Expand Up @@ -167,7 +165,6 @@ async def test_txn_basic(store: Store):
assert dict(found) == TEST_ENTRY


@mark.asyncio
async def test_txn_autocommit(store: Store):
with raises(Exception):
async with store.transaction(autocommit=True) as txn:
Expand Down Expand Up @@ -203,7 +200,6 @@ async def test_txn_autocommit(store: Store):
assert dict(found) == TEST_ENTRY


@mark.asyncio
async def test_txn_contention(store: Store):
async with store.transaction() as txn:
await txn.insert(
Expand Down Expand Up @@ -240,7 +236,6 @@ async def inc():
assert int(result.value) == INC_COUNT * TASKS


@mark.asyncio
async def test_key_store_ed25519(store: Store):
# test key operations in a new session
async with store as session:
Expand Down Expand Up @@ -279,7 +274,6 @@ async def test_key_store_ed25519(store: Store):
assert await session.fetch_key(key_name) is None


@mark.asyncio
@mark.parametrize(
"key_alg",
[KeyAlg.A128CBC_HS256, KeyAlg.XC20P],
Expand Down Expand Up @@ -318,7 +312,6 @@ async def test_key_store_symmetric(store: Store, key_alg: KeyAlg):
assert await session.fetch_key(key_name) is None


@mark.asyncio
async def test_profile(store: Store):
# New session in the default profile
async with store as session:
Expand Down Expand Up @@ -409,7 +402,6 @@ async def test_profile(store: Store):
assert (await store.get_default_profile()) == profile


@mark.asyncio
async def test_copy(store: Store):
async with store as session:
# Insert a new entry
Expand All @@ -429,3 +421,54 @@ async def test_copy(store: Store):
entries = await session.fetch_all(TEST_ENTRY["category"])
assert len(entries) == 1
assert entries[0].name == TEST_ENTRY["name"]


def test_entry_cache():
instances = WeakKeyDictionary()

class MockList:
def __init__(self, name: str, value: dict):
self._name = name
self._value = value
self._calls = []
instances[self] = True

@entry_cache
def get_name(self, index: int) -> str:
self._calls.append(index)
return self._name + str(index)

@entry_cache
def get_value(self, index: int) -> dict:
self._calls.append(index)
return self._value

NAME = "testname"
VALUE = {"a": "b"}
lst = MockList(NAME, VALUE)
# check instance is registered
assert instances

# check first call goes to method
assert lst.get_name(99) == NAME + "99"
assert lst._calls == [99]
assert lst.get_name(45) == NAME + "45"
assert lst._calls == [99, 45]
val = lst.get_value(11)
assert val == VALUE
assert lst._calls == [99, 45, 11]

# check dict value is copied
val["a"] = "c"
assert val != VALUE

# check second call goes to cache
assert lst.get_name(99) == NAME + "99"
assert lst.get_name(45) == NAME + "45"
assert lst.get_value(11) == VALUE
assert lst._calls == [99, 45, 11]

# ensure no extra references are keeping the instance around
del lst
gc.collect()
assert not instances
Loading