Skip to content

Commit

Permalink
Fix no_dereferencing context manager which wasn't turning off auto-de…
Browse files Browse the repository at this point in the history
…referencing correctly in some cases + fix tests
  • Loading branch information
bagerard committed Dec 19, 2023
1 parent bfc42d0 commit 9b92cc7
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 35 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ Development
- Fix validate() not being called when inheritance is used in EmbeddedDocument and validate is overriden #2784
- Add support for readPreferenceTags in connection parameters #2644
- Use estimated_documents_count OR documents_count when count is called, based on the query #2529
- Fix no_dereferencing context manager which wasn't turning off auto-dereferencing correctly in some cases
- BREAKING CHANGE: no_dereferencing context manager no longer returns the class in __enter__
as it was useless and making it look like it was returning a different class

Changes in 0.27.0
=================
Expand Down
2 changes: 1 addition & 1 deletion docs/guide/querying.rst
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ data. To turn off dereferencing of the results of a query use
You can also turn off all dereferencing for a fixed period by using the
:class:`~mongoengine.context_managers.no_dereference` context manager::

with no_dereference(Post) as Post:
with no_dereference(Post):
post = Post.objects.first()
assert(isinstance(post.author, DBRef))

Expand Down
28 changes: 25 additions & 3 deletions mongoengine/context_managers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
from contextlib import contextmanager

from pymongo.read_concern import ReadConcern
Expand All @@ -18,6 +19,25 @@
)


thread_locals = threading.local()
thread_locals.no_dereferencing_class = {}


def no_dereferencing_active_for_class(cls):
return cls in thread_locals.no_dereferencing_class


def _register_no_dereferencing_for_class(cls):
thread_locals.no_dereferencing_class.setdefault(cls, 0)
thread_locals.no_dereferencing_class[cls] += 1


def _unregister_no_dereferencing_for_class(cls):
thread_locals.no_dereferencing_class[cls] -= 1
if thread_locals.no_dereferencing_class[cls] == 0:
thread_locals.no_dereferencing_class.pop(cls)


class switch_db:
"""switch_db alias context manager.
Expand Down Expand Up @@ -107,7 +127,7 @@ class no_dereference:
Turns off all dereferencing in Documents for the duration of the context
manager::
with no_dereference(Group) as Group:
with no_dereference(Group):
Group.objects.find()
"""

Expand All @@ -130,15 +150,17 @@ def __init__(self, cls):

def __enter__(self):
"""Change the objects default and _auto_dereference values."""
_register_no_dereferencing_for_class(self.cls)

for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = False
return self.cls

def __exit__(self, t, value, traceback):
"""Reset the default and _auto_dereference values."""
_unregister_no_dereferencing_for_class(self.cls)

for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = True
return self.cls


class no_sub_classes:
Expand Down
19 changes: 13 additions & 6 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mongoengine.common import _import_class
from mongoengine.connection import get_db
from mongoengine.context_managers import (
no_dereferencing_active_for_class,
set_read_write_concern,
set_write_concern,
switch_db,
Expand Down Expand Up @@ -51,9 +52,6 @@ class BaseQuerySet:
providing :class:`~mongoengine.Document` objects as the results.
"""

__dereference = False
_auto_dereference = True

def __init__(self, document, collection):
self._document = document
self._collection_obj = collection
Expand All @@ -74,6 +72,9 @@ def __init__(self, document, collection):
self._as_pymongo = False
self._search_text = None

self.__dereference = False
self.__auto_dereference = True

# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
if document._meta.get("allow_inheritance") is True:
Expand Down Expand Up @@ -795,7 +796,7 @@ def clone(self):
return self._clone_into(self.__class__(self._document, self._collection_obj))

def _clone_into(self, new_qs):
"""Copy all of the relevant properties of this queryset to
"""Copy all the relevant properties of this queryset to
a new queryset (which has to be an instance of
:class:`~mongoengine.queryset.base.BaseQuerySet`).
"""
Expand Down Expand Up @@ -825,7 +826,6 @@ def _clone_into(self, new_qs):
"_empty",
"_hint",
"_collation",
"_auto_dereference",
"_search_text",
"_max_time_ms",
"_comment",
Expand All @@ -836,6 +836,8 @@ def _clone_into(self, new_qs):
val = getattr(self, prop)
setattr(new_qs, prop, copy.copy(val))

new_qs.__auto_dereference = self._BaseQuerySet__auto_dereference

if self._cursor_obj:
new_qs._cursor_obj = self._cursor_obj.clone()

Expand Down Expand Up @@ -1741,10 +1743,15 @@ def _dereference(self):
self.__dereference = _import_class("DeReference")()
return self.__dereference

@property
def _auto_dereference(self):
should_deref = not no_dereferencing_active_for_class(self._document)
return should_deref and self.__auto_dereference

def no_dereference(self):
"""Turn off any dereferencing for the results of this queryset."""
queryset = self.clone()
queryset._auto_dereference = False
queryset.__auto_dereference = False
return queryset

# Helper Functions
Expand Down
75 changes: 50 additions & 25 deletions tests/test_context_managers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

import pytest
from bson import DBRef

from mongoengine import *
from mongoengine.connection import get_db
Expand All @@ -19,8 +20,6 @@

class TestContextManagers(MongoDBTestCase):
def test_set_write_concern(self):
connect("mongoenginetest")

class User(Document):
name = StringField()

Expand All @@ -39,8 +38,6 @@ class User(Document):
assert original_write_concern.document == collection.write_concern.document

def test_set_read_write_concern(self):
connect("mongoenginetest")

class User(Document):
name = StringField()

Expand All @@ -65,7 +62,6 @@ class User(Document):
assert original_write_concern.document == collection.write_concern.document

def test_switch_db_context_manager(self):
connect("mongoenginetest")
register_connection("testdb-1", "mongoenginetest2")

class Group(Document):
Expand All @@ -89,7 +85,6 @@ class Group(Document):
assert 1 == Group.objects.count()

def test_switch_collection_context_manager(self):
connect("mongoenginetest")
register_connection(alias="testdb-1", db="mongoenginetest2")

class Group(Document):
Expand Down Expand Up @@ -117,7 +112,6 @@ class Group(Document):

def test_no_dereference_context_manager_object_id(self):
"""Ensure that DBRef items in ListFields aren't dereferenced."""
connect("mongoenginetest")

class User(Document):
name = StringField()
Expand All @@ -136,25 +130,57 @@ class Group(Document):
user = User.objects.first()
Group(ref=user, members=User.objects, generic=user).save()

with no_dereference(Group) as NoDeRefGroup:
assert Group._fields["members"]._auto_dereference
assert not NoDeRefGroup._fields["members"]._auto_dereference
with no_dereference(Group):
assert not Group._fields["members"]._auto_dereference

with no_dereference(Group) as Group:
with no_dereference(Group):
group = Group.objects.first()
for m in group.members:
assert not isinstance(m, User)
assert not isinstance(group.ref, User)
assert not isinstance(group.generic, User)
assert isinstance(m, DBRef)
assert isinstance(group.ref, DBRef)
assert isinstance(group.generic, dict)

group = Group.objects.first()
for m in group.members:
assert isinstance(m, User)
assert isinstance(group.ref, User)
assert isinstance(group.generic, User)

def test_no_dereference_context_manager_dbref(self):
def test_no_dereference_context_manager_nested(self):
"""Ensure that DBRef items in ListFields aren't dereferenced."""
connect("mongoenginetest")

class User(Document):
name = StringField()

class Group(Document):
ref = ReferenceField(User, dbref=False)

User.drop_collection()
Group.drop_collection()

for i in range(1, 51):
User(name="user %s" % i).save()

user = User.objects.first()
Group(ref=user).save()

with no_dereference(Group):
group = Group.objects.first()
assert isinstance(group.ref, DBRef)

with no_dereference(Group):
group = Group.objects.first()
assert isinstance(group.ref, DBRef)

# make sure its still off here
group = Group.objects.first()
assert isinstance(group.ref, DBRef)

group = Group.objects.first()
assert isinstance(group.ref, User)

def test_no_dereference_context_manager_dbref(self):
"""Ensure that DBRef items in ListFields aren't dereferenced"""

class User(Document):
name = StringField()
Expand All @@ -173,16 +199,19 @@ class Group(Document):
user = User.objects.first()
Group(ref=user, members=User.objects, generic=user).save()

with no_dereference(Group) as NoDeRefGroup:
assert Group._fields["members"]._auto_dereference
assert not NoDeRefGroup._fields["members"]._auto_dereference
with no_dereference(Group):
assert not Group._fields["members"]._auto_dereference

with no_dereference(Group) as Group:
group = Group.objects.first()
with no_dereference(Group):
qs = Group.objects
assert qs._auto_dereference is False
group = qs.first()
assert not group._fields["members"]._auto_dereference
assert all(not isinstance(m, User) for m in group.members)
assert not isinstance(group.ref, User)
assert not isinstance(group.generic, User)

group = Group.objects.first()
assert all(isinstance(m, User) for m in group.members)
assert isinstance(group.ref, User)
assert isinstance(group.generic, User)
Expand Down Expand Up @@ -265,7 +294,6 @@ def test_query_counter_does_not_swallow_exception(self):
raise TypeError()

def test_query_counter_temporarily_modifies_profiling_level(self):
connect("mongoenginetest")
db = get_db()

def _current_profiling_level():
Expand All @@ -290,7 +318,6 @@ def _set_profiling_level(lvl):
raise

def test_query_counter(self):
connect("mongoenginetest")
db = get_db()

collection = db.query_counter
Expand Down Expand Up @@ -380,7 +407,6 @@ class B(Document):
assert q == 3

def test_query_counter_counts_getmore_queries(self):
connect("mongoenginetest")
db = get_db()

collection = db.query_counter
Expand All @@ -397,7 +423,6 @@ def test_query_counter_counts_getmore_queries(self):
assert q == 2 # 1st select + 1 getmore

def test_query_counter_ignores_particular_queries(self):
connect("mongoenginetest")
db = get_db()

collection = db.query_counter
Expand Down

0 comments on commit 9b92cc7

Please sign in to comment.