diff --git a/orator/connections/connection.py b/orator/connections/connection.py index 469b5989..709aeef6 100644 --- a/orator/connections/connection.py +++ b/orator/connections/connection.py @@ -18,6 +18,21 @@ connection_logger = logging.getLogger("orator.connection") +def recoverable(wrapped): + @wraps(wrapped) + def _recoverable(self, *args, **kwargs): + self._reconnect_if_missing_connection() + try: + result = wrapped(self, *args, **kwargs) + except Exception as e: + result = self._recover_if_caused_by_lost_connection( + e, wrapped, *args, **kwargs + ) + return result + + return _recoverable + + def run(wrapped): """ Special decorator encapsulating query method. @@ -356,6 +371,14 @@ def _try_again_if_caused_by_lost_connection( raise QueryException(query, bindings, e) + def _recover_if_caused_by_lost_connection(self, e, callback, *args, **kwargs): + if self._caused_by_lost_connection(e): + self.reconnect() + + return callback(self, *args, **kwargs) + + raise e + def _caused_by_lost_connection(self, e): message = str(e).lower() diff --git a/orator/connections/mysql_connection.py b/orator/connections/mysql_connection.py index e37d4b06..ab3a2534 100644 --- a/orator/connections/mysql_connection.py +++ b/orator/connections/mysql_connection.py @@ -2,7 +2,7 @@ from ..utils import decode from ..utils import PY2 -from .connection import Connection +from .connection import Connection, recoverable from ..query.grammars.mysql_grammar import MySQLQueryGrammar from ..query.processors.mysql_processor import MySQLQueryProcessor from ..schema.grammars import MySQLSchemaGrammar @@ -37,6 +37,7 @@ def get_default_schema_grammar(self): def get_schema_manager(self): return MySQLSchemaManager(self) + @recoverable def begin_transaction(self): self._connection.autocommit(False) diff --git a/tests/connections/test_mysql_connection.py b/tests/connections/test_mysql_connection.py index 2c9a2f9f..f00c42ed 100644 --- a/tests/connections/test_mysql_connection.py +++ b/tests/connections/test_mysql_connection.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- +from flexmock import flexmock + from .. import OratorTestCase +from .. import mock from orator.connections.mysql_connection import MySQLConnection @@ -20,3 +23,13 @@ def test_marker_use_qmark_false(self): connection = MySQLConnection(None, "database", "", {"use_qmark": False}) self.assertIsNone(connection.get_marker()) + + def test_recover_if_caused_by_lost_connection_is_called(self): + connection = flexmock(MySQLConnection(None, "database")) + connection._connection = mock.Mock() + connection._connection.autocommit.side_effect = Exception("lost connection") + + connection.should_receive("_recover_if_caused_by_lost_connection").once() + connection.should_receive("reconnect") + + connection.begin_transaction()