diff --git a/flask_mongoengine/connection.py b/flask_mongoengine/connection.py index 325d4705..e94a44af 100644 --- a/flask_mongoengine/connection.py +++ b/flask_mongoengine/connection.py @@ -73,6 +73,11 @@ def _validate_settings(is_test, temp_db, preserved, conn_host): 'only when `TESTING` is set to true.' raise InvalidSettingsError(msg) +def __get_app_config(key): + return (_app_instance.get(key, False) + if isinstance(_app_instance, dict) + else _app_instance.config.get(key, False)) + def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): global _connections set_global_attributes() @@ -98,9 +103,9 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn_settings.pop('password', None) conn_settings.pop('authentication_source', None) - is_test = _app_instance.config.get('TESTING', False) - temp_db = _app_instance.config.get('TEMP_DB', False) - preserved = _app_instance.config.get('PRESERVE_TEMP_DB', False) + is_test = __get_app_config('TESTING') + temp_db = __get_app_config('TEMP_DB') + preserved = __get_app_config('PRESERVE_TEMP_DB') # Validation _validate_settings(is_test, temp_db, preserved, conn_host) @@ -351,7 +356,7 @@ def create_connection(config, app): @param app: instance of flask.Flask """ global _connection_settings, _app_instance - _app_instance = app + _app_instance = app if app else config if config is None or not isinstance(config, dict): raise Exception("Invalid application configuration"); @@ -363,6 +368,7 @@ def create_connection(config, app): connections = {} for conn_setting in conn_settings: alias = conn_setting['alias'] + _connection_settings[alias] = conn_setting connections[alias] = get_connection(alias) return connections else: diff --git a/tests/test_connection.py b/tests/test_connection.py index f5ff207e..aa0d7959 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -39,6 +39,10 @@ def test_live_connection(self): 'PASSWORD': None, 'DB': 'test' } + + self._do_persist(db) + + def _do_persist(self, db): class Todo(db.Document): title = db.StringField(max_length=60) text = db.StringField() @@ -57,6 +61,44 @@ class Todo(db.Document): f_to = Todo.objects().first() self.assertEqual(s_todo.title, f_to.title) + def test_multiple_connections(self): + db = MongoEngine() + self.app.config['TESTING'] = True + self.app.config['MONGODB_SETTINGS'] = [ + { + "ALIAS": "default", + "DB": 'my_db1', + "HOST": 'localhost', + "PORT": 27017 + }, + { + "ALIAS": "my_db2", + "DB": 'my_db2', + "HOST": 'localhost', + "PORT": 27017 + }, + ] + class Todo(db.Document): + title = db.StringField(max_length=60) + text = db.StringField() + done = db.BooleanField(default=False) + meta = {"db_alias": "my_db2"} + + db.init_app(self.app) + Todo.drop_collection() + + # Switch DB + from mongoengine.context_managers import switch_db + with switch_db(Todo, 'default') as Todo: + todo = Todo() + todo.text = "Sample" + todo.title = "Testing" + todo.done = True + s_todo = todo.save() + + f_to = Todo.objects().first() + self.assertEqual(s_todo.title, f_to.title) + def test_mongodb_temp_instance(self): # String value used instead of boolean self.app.config['TESTING'] = True