diff --git a/lib/zookeeper.py b/lib/zookeeper.py index e886042389..95b3afd5e3 100644 --- a/lib/zookeeper.py +++ b/lib/zookeeper.py @@ -36,7 +36,7 @@ # SCION from lib.errors import SCIONBaseError from lib.thread import kill_self, thread_safety_net -from lib.util import timed +from lib.util import SCIONTime, timed class ZkBaseError(SCIONBaseError): @@ -277,9 +277,37 @@ def is_connected(self): def wait_connected(self, timeout=None): """ - Wait until there is a connection to Zookeeper. + Wait until there is a connection to Zookeeper. Log every 10s until a + connection is available. + + :param float timeout: + Number of seconds to wait for a ZK connection. If ``None``, wait + forever. + :returns: ``True`` if connected, otherwise ``False`` + :rtype: :class:`bool` """ - return self._connected.wait(timeout=timeout) + if self.is_connected(): + return True + logging.debug("Waiting for ZK connection") + start = SCIONTime.get_time() + total_time = 0.0 + if timeout is None: + next_timeout = 10.0 + while True: + if timeout is not None: + next_timeout = min(timeout - total_time, 10.0) + ret = self._connected.wait(timeout=next_timeout) + total_time = SCIONTime.get_time() - start + if ret: + logging.debug("ZK connection available after %.2fs", total_time) + return True + elif timeout is not None and total_time >= timeout: + logging.debug("ZK connection still unavailable after %.2fs", + total_time) + return False + else: + logging.debug("Still waiting for ZK connection (%.2fs so far)", + total_time) def ensure_path(self, path, abs=False): """ diff --git a/test/lib_zookeeper_test.py b/test/lib_zookeeper_test.py index 2e3970490a..7f5f15e22d 100644 --- a/test/lib_zookeeper_test.py +++ b/test/lib_zookeeper_test.py @@ -372,18 +372,58 @@ class TestZookeeperWaitConnected(BaseZookeeper): Unit tests for lib.zookeeper.Zookeeper.wait_connected """ @patch("lib.zookeeper.Zookeeper.__init__", autospec=True, return_value=None) - def _check(self, timeout, init): + def test_connected(self, init): inst = self._init_basic_setup() + inst.is_connected = create_mock() + # Call + ntools.ok_(inst.wait_connected()) + # Tests + inst.is_connected.assert_called_once_with() + + @patch("lib.zookeeper.SCIONTime.get_time", new_callable=create_mock) + @patch("lib.zookeeper.Zookeeper.__init__", autospec=True, return_value=None) + def test_no_timeout(self, init, get_time): + inst = self._init_basic_setup() + inst.is_connected = create_mock() + inst.is_connected.return_value = False + get_time.side_effect = [0, 10, 20] inst._connected = create_mock(["wait"]) - inst._connected.wait.return_value = 33 + inst._connected.wait.side_effect = [False, True] # Call - ntools.eq_(inst.wait_connected(timeout=timeout), 33) + ntools.ok_(inst.wait_connected(timeout=None)) # Tests - inst._connected.wait.assert_called_once_with(timeout=timeout) + inst._connected.wait.assert_has_calls([call(timeout=10.0)] * 2) + ntools.eq_(inst._connected.wait.call_count, 2) - def test(self): - for timeout in None, 1, 15: - yield self._check, timeout + @patch("lib.zookeeper.SCIONTime.get_time", new_callable=create_mock) + @patch("lib.zookeeper.Zookeeper.__init__", autospec=True, return_value=None) + def test_timeout_success(self, init, get_time): + inst = self._init_basic_setup() + inst.is_connected = create_mock() + inst.is_connected.return_value = False + get_time.side_effect = [0, 10, 20] + inst._connected = create_mock(["wait"]) + inst._connected.wait.side_effect = [False, True] + # Call + ntools.ok_(inst.wait_connected(timeout=15)) + # Tests + inst._connected.wait.assert_has_calls([call(timeout=10.0), + call(timeout=5.0)]) + ntools.eq_(inst._connected.wait.call_count, 2) + + @patch("lib.zookeeper.SCIONTime.get_time", new_callable=create_mock) + @patch("lib.zookeeper.Zookeeper.__init__", autospec=True, return_value=None) + def test_timeout_fail(self, init, get_time): + inst = self._init_basic_setup() + inst.is_connected = create_mock() + inst.is_connected.return_value = False + get_time.side_effect = [0, 10, 20] + inst._connected = create_mock(["wait"]) + inst._connected.wait.side_effect = [False, False] + # Call + ntools.assert_false(inst.wait_connected(timeout=15)) + # Tests + ntools.eq_(inst._connected.wait.call_count, 2) class TestZookeeperEnsurePath(BaseZookeeper):