diff --git a/migration/api.py b/migration/api.py index 8fe131d..3f7e992 100644 --- a/migration/api.py +++ b/migration/api.py @@ -12,11 +12,13 @@ retry_if_exception_type, stop_after_attempt ) +import trio logger = logging.getLogger(__name__) MAX_ATTEMPT_NUM = 4 +MAX_ASYNC_CONNS = 20 class EndpointType(Enum): @@ -30,7 +32,7 @@ class GetResponse: class API: - client: httpx.Client + client: httpx.AsyncClient def __init__( self, @@ -40,7 +42,13 @@ def __init__( timeout: int = 10 ): headers = {'Authorization': f'Bearer {key}'} - self.client = httpx.Client(base_url=url + endpoint_type.value, headers=headers, timeout=timeout) + limits = httpx.Limits(max_connections=MAX_ASYNC_CONNS) + self.client = httpx.AsyncClient( + base_url=url + endpoint_type.value, + headers=headers, + timeout=timeout, + limits=limits + ) @staticmethod def get_next_page_params(resp: httpx.Response) -> dict[str, Any] | None: @@ -54,10 +62,11 @@ def get_next_page_params(resp: httpx.Response) -> dict[str, Any] | None: stop=stop_after_attempt(MAX_ATTEMPT_NUM), retry=retry_if_exception_type((httpx.HTTPError, JSONDecodeError)), reraise=True, - before_sleep=before_sleep_log(logger, logging.WARN) + before_sleep=before_sleep_log(logger, logging.WARN), + sleep=trio.sleep ) - def get(self, url: str, params: dict[str, Any] | None = None) -> GetResponse: - resp = self.client.get(url=url, params=params) + async def get(self, url: str, params: dict[str, Any] | None = None) -> GetResponse: + resp = await self.client.get(url=url, params=params) resp.raise_for_status() data = resp.json() next_page_params = self.get_next_page_params(resp) @@ -67,14 +76,15 @@ def get(self, url: str, params: dict[str, Any] | None = None) -> GetResponse: stop=stop_after_attempt(MAX_ATTEMPT_NUM), retry=retry_if_exception_type((httpx.HTTPError, JSONDecodeError)), reraise=True, - before_sleep=before_sleep_log(logger, logging.WARN) + before_sleep=before_sleep_log(logger, logging.WARN), + sleep=trio.sleep ) - def put(self, url: str, params: dict[str, Any] | None = None) -> Any: - resp = self.client.put(url=url, params=params) + async def put(self, url: str, params: dict[str, Any] | None = None) -> Any: + resp = await self.client.put(url=url, params=params) resp.raise_for_status() return resp.json() - def get_results_from_pages( + async def get_results_from_pages( self, endpoint: str, params: dict[str, Any] | None = None, page_size: int = 50, limit: int | None = None ) -> list[dict[str, Any]]: extra_params: dict[str, Any] @@ -90,7 +100,7 @@ def get_results_from_pages( while more_pages: logger.debug(f'Params: {extra_params}') - get_resp = self.get(url=endpoint, params=extra_params) + get_resp = await self.get(url=endpoint, params=extra_params) results += get_resp.data if get_resp.next_page_params is None: more_pages = False diff --git a/migration/main.py b/migration/main.py index 787fa2c..a09dfe0 100644 --- a/migration/main.py +++ b/migration/main.py @@ -2,10 +2,12 @@ import os from contextlib import nullcontext +import trio from dotenv import load_dotenv +from tqdm import tqdm from api import API -from data import ExternalTool, ToolMigration +from data import Course, ExternalTool, ToolMigration from db import DB, Dialect from exceptions import InvalidToolIdsException from manager import AccountManagerFactory, CourseManager @@ -15,6 +17,15 @@ logger = logging.getLogger(__name__) +class TrioProgress(trio.abc.Instrument): + + def __init__(self, total, **kwargs): + self.tqdm = tqdm(total=total, **kwargs) + + def task_exited(self, task): + self.tqdm.update(1) + + def find_tools_for_migrations( tools: list[ExternalTool], migrations: list[ToolMigration] ) -> list[tuple[ExternalTool, ExternalTool]]: @@ -36,39 +47,51 @@ def find_tools_for_migrations( return tool_pairs +async def migrate_tool_for_course(api: API, course: Course, source_tool: ExternalTool, target_tool: ExternalTool): + course_manager = CourseManager(course, api) + tabs = await course_manager.get_tool_tabs() + source_tool_tab = CourseManager.find_tab_by_tool_id(source_tool.id, tabs) + target_tool_tab = CourseManager.find_tab_by_tool_id(target_tool.id, tabs) + if source_tool_tab is None or target_tool_tab is None: + raise InvalidToolIdsException( + 'One or both of the following tool IDs are not available in this course: ' + + str([source_tool.id, target_tool.id]) + ) + await course_manager.replace_tool_tab(source_tool_tab, target_tool_tab) + + @time_execution -def main(api: API, account_id: int, term_ids: list[int], migrations: list[ToolMigration], db: DB | None = None): - +async def main(api: API, account_id: int, term_ids: list[int], migrations: list[ToolMigration], db: DB | None = None): factory = AccountManagerFactory() account_manager = factory.get_manager(account_id, api, db) - with api.client, db if db is not None else nullcontext(): # type: ignore - tools = account_manager.get_tools_installed_in_account() - logger.info(f'Number of tools found in account {account_id}: {len(tools)}') - - tool_pairs = find_tools_for_migrations(tools, migrations) - - # get list of tools available in account - courses = account_manager.get_courses_in_terms(term_ids) - logger.info(f'Number of courses found in account {account_id} for terms {term_ids}: {len(courses)}') - - for source_tool, target_tool in tool_pairs: - logger.info(f'Source tool: {source_tool}') - logger.info(f'Target tool: {target_tool}') - - for course in courses: - # Replace target tool with source tool in course navigation - course_manager = CourseManager(course, api) - tabs = course_manager.get_tool_tabs() - source_tool_tab = CourseManager.find_tab_by_tool_id(source_tool.id, tabs) - target_tool_tab = CourseManager.find_tab_by_tool_id(target_tool.id, tabs) - if source_tool_tab is None or target_tool_tab is None: - raise InvalidToolIdsException( - 'One or both of the following tool IDs are not available in this course: ' + - str([source_tool.id, target_tool.id]) - ) - course_manager.replace_tool_tab(source_tool_tab, target_tool_tab) - + async with api.client: + with db if db is not None else nullcontext(): # type: ignore + tools = await account_manager.get_tools_installed_in_account() + logger.info(f'Number of tools found in account {account_id}: {len(tools)}') + + tool_pairs = find_tools_for_migrations(tools, migrations) + + # get list of tools available in account + courses = await account_manager.get_courses_in_terms(term_ids) + logger.info(f'Number of courses found in account {account_id} for terms {term_ids}: {len(courses)}') + + for source_tool, target_tool in tool_pairs: + logger.info(f'Source tool: {source_tool}') + logger.info(f'Target tool: {target_tool}') + + progress = TrioProgress(total=len(courses)) + trio.lowlevel.add_instrument(progress) + async with trio.open_nursery() as nursery: + for course in courses: + nursery.start_soon( + migrate_tool_for_course, + api, + course, + source_tool, + target_tool + ) + trio.lowlevel.remove_instrument(progress) if __name__ == '__main__': # get configuration (either env. variables, cli flags, or direct input) @@ -122,7 +145,8 @@ def main(api: API, account_id: int, term_ids: list[int], migrations: list[ToolMi source_tool_id: int = int(os.getenv('SOURCE_TOOL_ID', '0')) target_tool_id: int = int(os.getenv('TARGET_TOOL_ID', '0')) - main( + trio.run( + main, API(api_url, api_key), account_id, enrollment_term_ids, diff --git a/migration/manager.py b/migration/manager.py index 8c82c5a..22b439d 100644 --- a/migration/manager.py +++ b/migration/manager.py @@ -19,11 +19,11 @@ class AccountManagerBase(ABC): account_id: int @abstractmethod - def get_tools_installed_in_account(self) -> list[ExternalTool]: + async def get_tools_installed_in_account(self) -> list[ExternalTool]: pass @abstractmethod - def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]: + async def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]: pass @@ -31,14 +31,14 @@ def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> class AccountManager(AccountManagerBase): api: API - def get_tools_installed_in_account(self) -> list[ExternalTool]: + async def get_tools_installed_in_account(self) -> list[ExternalTool]: params = {"include_parents": True} - results = self.api.get_results_from_pages(f'/accounts/{self.account_id}/external_tools', params) + results = await self.api.get_results_from_pages(f'/accounts/{self.account_id}/external_tools', params) tools = [ExternalTool(id=tool_dict['id'], name=tool_dict['name']) for tool_dict in results] return tools @time_execution - def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]: + async def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]: limit_chunks = None if limit is not None: limit_chunks = chunk_integer(limit, len(term_ids)) @@ -46,7 +46,7 @@ def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> results: list[dict[str, Any]] = [] for i, term_id in enumerate(term_ids): limit_for_term = limit_chunks[i] if limit_chunks is not None else None - term_results = self.api.get_results_from_pages( + term_results = await self.api.get_results_from_pages( f'/accounts/{self.account_id}/courses', params={ 'enrollment_term_id': term_id }, page_size=50, @@ -73,11 +73,11 @@ class WarehouseAccountManager(AccountManagerBase): def __post_init__(self): self.account_manager = AccountManager(self.account_id, self.api) - def get_tools_installed_in_account(self) -> list[ExternalTool]: - return self.account_manager.get_tools_installed_in_account() + async def get_tools_installed_in_account(self) -> list[ExternalTool]: + return await self.account_manager.get_tools_installed_in_account() - def get_subaccount_ids(self) -> list[int]: - results = self.api.get_results_from_pages( + async def get_subaccount_ids(self) -> list[int]: + results = await self.api.get_results_from_pages( f'/accounts/{self.account_id}/sub_accounts', { 'recursive': True } ) sub_account_ids = [result['id'] for result in results] @@ -85,8 +85,8 @@ def get_subaccount_ids(self) -> list[int]: return sub_account_ids @time_execution - def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]: - account_ids = [self.account_id] + self.get_subaccount_ids() + async def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]: + account_ids = [self.account_id] + await self.get_subaccount_ids() conn = self.db.get_connection() statement = sqlalchemy.text(f''' @@ -148,8 +148,11 @@ def convert_data_to_tool_tab(cls, data: dict[str, Any]) -> ExternalToolTab: position=data['position'] ) - def get_tool_tabs(self) -> list[ExternalToolTab]: - results = self.api.get_results_from_pages(f'/courses/{self.course.id}/tabs') + def create_course_log_message(self, message: str) -> str: + return f'{self.course}\n{message}\n- - -' + + async def get_tool_tabs(self) -> list[ExternalToolTab]: + results = await self.api.get_results_from_pages(f'/courses/{self.course.id}/tabs') tabs: list[ExternalToolTab] = [] for result in results: @@ -158,43 +161,52 @@ def get_tool_tabs(self) -> list[ExternalToolTab]: tabs.append(CourseManager.convert_data_to_tool_tab(result)) return tabs - def update_tool_tab(self, tab: ExternalToolTab, is_hidden: bool, position: int | None = None): + async def update_tool_tab(self, tab: ExternalToolTab, is_hidden: bool, position: int | None = None): params: dict[str, Any] = { "hidden": is_hidden } if position is not None: params.update({ "position": position }) - result = self.api.put( + result = await self.api.put( f'/courses/{self.course.id}/tabs/{tab.id}', params=params ) logger.debug(result) return CourseManager.convert_data_to_tool_tab(result) - def replace_tool_tab( + async def replace_tool_tab( self, source_tab: ExternalToolTab, target_tab: ExternalToolTab ) -> tuple[ExternalToolTab, ExternalToolTab]: logger.debug([source_tab, target_tab]) # Source tool is hidden in course, don't do anything if source_tab.is_hidden: - logger.debug(f'Skipping replacement for {[source_tab, target_tab]}; source tool is hidden.') + logger.debug(self.create_course_log_message( + f'Skipping replacement for {[source_tab, target_tab]}; source tool is hidden.' + )) return (source_tab, target_tab) else: if not target_tab.is_hidden: - logger.warning( + logger.warning(self.create_course_log_message( f'Both tools ({[source_tab, target_tab]}) are currently available. ' + 'Rolling back will hide the target tool!' - ) - logger.debug((f'Skipping update for {target_tab}; tool is already available.')) + )) + logger.debug(self.create_course_log_message( + f'Skipping update for {target_tab}; tool is already available.' + )) new_target_tab = target_tab else: target_position = source_tab.position - new_target_tab = self.update_tool_tab(tab=target_tab, is_hidden=False, position=target_position) - logger.info(f"Made available target tool in course's navigation: {new_target_tab}") + new_target_tab = await self.update_tool_tab(tab=target_tab, is_hidden=False, position=target_position) + logger.info(self.create_course_log_message( + f"Made available target tool in course's navigation: {new_target_tab}" + )) # Always hide the source tool if it's available - new_source_tab = self.update_tool_tab(tab=source_tab, is_hidden=True) - logger.info(f"Hid source tool in course's navigation: {new_source_tab}") + new_source_tab = await self.update_tool_tab(tab=source_tab, is_hidden=True) + logger.info(self.create_course_log_message( + f"Hid source tool in course's navigation: {new_source_tab}" + )) + return (new_source_tab, new_target_tab) diff --git a/migration/tests.py b/migration/tests.py index 4dfa12b..5248851 100644 --- a/migration/tests.py +++ b/migration/tests.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch import httpx +import trio from dotenv import load_dotenv from data import Course, ExternalTool, ExternalToolTab, ToolMigration @@ -20,12 +21,12 @@ logger = logging.getLogger(__name__) -class APITestCase(unittest.TestCase): +class APITestCase(unittest.IsolatedAsyncioTestCase): """ Integration/unit tests for API class """ - def setUp(self) -> None: + async def setUp(self) -> None: self.api_url: str = os.getenv('API_URL', '') api_key: str = os.getenv('API_KEY', '') self.api = API(self.api_url, api_key) @@ -45,17 +46,19 @@ def test_get_next_page_params_with_no_next_page(self): params = API.get_next_page_params(mock_response) self.assertEqual(params, {'page': ['2'], 'per_page': ['5']}) - def test_get_results_from_pages(self): - with self.api.client: - results = self.api.get_results_from_pages(f'/accounts/{self.account_id}/courses', page_size=5) + async def test_get_results_from_pages(self): + async with self.api.client: + results = await self.api.get_results_from_pages(f'/accounts/{self.account_id}/courses', page_size=5) self.assertTrue(len(results) > 1) - def test_get_results_from_pages_with_limit(self): - with self.api.client: - results = self.api.get_results_from_pages(f'/accounts/{self.account_id}/courses', page_size=5, limit=2) + async def test_get_results_from_pages_with_limit(self): + async with self.api.client: + results = await self.api.get_results_from_pages( + f'/accounts/{self.account_id}/courses', page_size=5, limit=2 + ) self.assertTrue(len(results) == 2) - def test_get_retries_on_http_error(self): + async def test_get_retries_on_http_error(self): request = MagicMock(httpx.Request, autospec=True, url=self.course_url) resp = httpx.Response( status_code=httpx.codes.BAD_GATEWAY, @@ -69,12 +72,12 @@ def test_get_retries_on_http_error(self): with patch.object(self.api.client, 'get', autospec=True) as mock_get_call: mock_get_call.side_effect = [resp, expected_resp] - with self.api.client: - result = self.api.get(self.course_url) + async with self.api.client: + result = await self.api.get(self.course_url) self.assertEqual(self.course_data, result.data) self.assertEqual(mock_get_call.call_count, 2) - def test_get_retries_on_decode_error(self): + async def test_get_retries_on_decode_error(self): request = MagicMock(httpx.Request, autospec=True, url=self.course_url) bad_json_resp = httpx.Response( status_code=httpx.codes.OK, @@ -88,12 +91,12 @@ def test_get_retries_on_decode_error(self): ) with patch.object(self.api.client, 'get', autospec=True) as mock_get_call: mock_get_call.side_effect = [bad_json_resp, expected_resp] - with self.api.client: - result = self.api.get(self.course_url) + async with self.api.client: + result = await self.api.get(self.course_url) self.assertEqual(self.course_data, result.data) self.assertEqual(mock_get_call.call_count, 2) - def test_put_retries_until_failure(self): + async def test_put_retries_until_failure(self): request = MagicMock(httpx.Request, autospec=True, url=self.course_url) bad_resp = httpx.Response( status_code=httpx.codes.BAD_GATEWAY, @@ -101,13 +104,13 @@ def test_put_retries_until_failure(self): ) with patch.object(self.api.client, 'put', autospec=True) as mock_put_call: mock_put_call.side_effect = [bad_resp, bad_resp, bad_resp, bad_resp] - with self.api.client: + async with self.api.client: with self.assertRaises(httpx.HTTPStatusError): - self.api.put(self.course_url, params={ "name": "Test Course!" }) + await self.api.put(self.course_url, params={ "name": "Test Course!" }) self.assertEqual(mock_put_call.call_count, 4) -class AccountManagerTestCase(unittest.TestCase): +class AccountManagerTestCase(unittest.IsolatedAsyncioTestCase): """ Integration tests for AccountManager class """ @@ -119,19 +122,19 @@ def setUp(self) -> None: self.enrollment_term_ids: list[int] = convert_csv_to_int_list(os.getenv('ENROLLMENT_TERM_IDS', '0')) self.api = API(api_url, api_key) - def test_manager_get_tools(self): - with self.api.client: + async def test_manager_get_tools(self): + async with self.api.client: manager = AccountManager(self.test_account_id, self.api) - tools = manager.get_tools_installed_in_account() + tools = await manager.get_tools_installed_in_account() self.assertTrue(len(tools) > 0) for tool in tools: logger.debug(tool) self.assertTrue(isinstance(tool, ExternalTool)) - def test_manager_get_courses_in_single_term(self): - with self.api.client: + async def test_manager_get_courses_in_single_term(self): + async with self.api.client: manager = AccountManager(self.test_account_id, self.api) - courses = manager.get_courses_in_terms([self.enrollment_term_ids[0]], 150) + courses = await manager.get_courses_in_terms([self.enrollment_term_ids[0]], 150) self.assertTrue(len(courses) > 0) term_ids: list[int] = [] for course in courses: @@ -140,10 +143,10 @@ def test_manager_get_courses_in_single_term(self): term_id_set = set(term_ids) self.assertTrue(len(term_id_set) == 1) - def test_manager_get_courses_in_multiple_terms(self): - with self.api.client: + async def test_manager_get_courses_in_multiple_terms(self): + async with self.api.client: manager = AccountManager(self.test_account_id, self.api) - courses = manager.get_courses_in_terms(self.enrollment_term_ids) + courses = await manager.get_courses_in_terms(self.enrollment_term_ids) self.assertTrue(len(courses) > 0) term_ids: list[int] = [] for course in courses: @@ -152,17 +155,17 @@ def test_manager_get_courses_in_multiple_terms(self): term_id_set = set(term_ids) self.assertTrue(len(term_id_set) > 1) - def test_manager_get_courses_with_limit(self): - with self.api.client: + async def test_manager_get_courses_with_limit(self): + async with self.api.client: manager = AccountManager(self.test_account_id, self.api) - courses = manager.get_courses_in_terms(self.enrollment_term_ids, 50) + courses = await manager.get_courses_in_terms(self.enrollment_term_ids, 50) self.assertTrue(len(courses) > 0) for course in courses: self.assertTrue(isinstance(course, Course)) self.assertTrue(len(courses) <= 50) -class WarehouseAccountManagerTestCase(unittest.TestCase): +class WarehouseAccountManagerTestCase(unittest.IsolatedAsyncioTestCase): """ Integration tests for WarehouseAccountManager class """ @@ -184,18 +187,19 @@ def setUp(self) -> None: self.db = DB(Dialect.POSTGRES, wh_db_params) self.test_account_id = int(os.getenv('TEST_ACCOUNT_ID', 0)) - def test_get_subaccount_ids(self): - with self.api.client: + async def test_get_subaccount_ids(self): + async with self.api.client: manager = WarehouseAccountManager(account_id=self.test_account_id, db=self.db, api=self.api) - subaccount_ids = manager.get_subaccount_ids() + subaccount_ids = await manager.get_subaccount_ids() self.assertTrue(len(subaccount_ids) > 0) for subaccount_id in subaccount_ids: self.assertIsInstance(subaccount_id, int) - def test_manager_get_courses_in_single_term(self): - with self.db, self.api.client: - manager = WarehouseAccountManager(account_id=self.test_account_id, db=self.db, api=self.api) - courses = manager.get_courses_in_terms([self.enrollment_term_ids[0]], 150) + async def test_manager_get_courses_in_single_term(self): + with self.db: + async with self.api.client: + manager = WarehouseAccountManager(account_id=self.test_account_id, db=self.db, api=self.api) + courses = await manager.get_courses_in_terms([self.enrollment_term_ids[0]], 150) self.assertTrue(len(courses) > 0) term_ids: list[int] = [] for course in courses: @@ -204,10 +208,11 @@ def test_manager_get_courses_in_single_term(self): term_id_set = set(term_ids) self.assertTrue(len(term_id_set) == 1) - def test_manager_get_courses_in_multiple_terms(self): - with self.db, self.api.client: - manager = WarehouseAccountManager(account_id=self.test_account_id, db=self.db, api=self.api) - courses = manager.get_courses_in_terms(self.enrollment_term_ids) + async def test_manager_get_courses_in_multiple_terms(self): + with self.db: + async with self.api.client: + manager = WarehouseAccountManager(account_id=self.test_account_id, db=self.db, api=self.api) + courses = await manager.get_courses_in_terms(self.enrollment_term_ids) self.assertTrue(len(courses) > 0) term_ids: list[int] = [] for course in courses: @@ -216,22 +221,23 @@ def test_manager_get_courses_in_multiple_terms(self): term_id_set = set(term_ids) self.assertTrue(len(term_id_set) > 1) - def test_manager_get_courses_with_limit(self): - with self.db, self.api.client: - manager = WarehouseAccountManager(self.test_account_id, self.db, api=self.api) - courses = manager.get_courses_in_terms(self.enrollment_term_ids, 50) + async def test_manager_get_courses_with_limit(self): + with self.db: + async with self.api.client: + manager = WarehouseAccountManager(self.test_account_id, self.db, api=self.api) + courses = await manager.get_courses_in_terms(self.enrollment_term_ids, 50) self.assertTrue(len(courses) > 0) for course in courses: self.assertTrue(isinstance(course, Course)) self.assertTrue(len(courses) <= 50) -class CourseManagerTestCase(unittest.TestCase): +class CourseManagerTestCase(unittest.IsolatedAsyncioTestCase): """ Integration/unit tests for CourseManager class """ - def setUp(self): + async def asyncSetUp(self): api_url: str = os.getenv('API_URL', '') api_key: str = os.getenv('API_KEY', '') self.api = API(api_url, api_key) @@ -255,8 +261,8 @@ def setUp(self): setup_api = API(api_url, api_key) setup_course_manager = CourseManager(course, setup_api) - with setup_api.client: - tabs_before = setup_course_manager.get_tool_tabs() + async with setup_api.client: + tabs_before = await setup_course_manager.get_tool_tabs() source_tab = CourseManager.find_tab_by_tool_id(self.source_tool_id, tabs_before) target_tab = CourseManager.find_tab_by_tool_id(self.target_tool_id, tabs_before) if source_tab is None or target_tab is None: @@ -275,56 +281,56 @@ def test_find_tab_by_tool_id_returns_none(self): tab = CourseManager.find_tab_by_tool_id(100000, [self.test_external_tool_tab]) self.assertTrue(tab is None) - def test_manager_gets_tool_tabs_in_course(self): - with self.api.client: - tabs = self.course_manager.get_tool_tabs() + async def test_manager_gets_tool_tabs_in_course(self): + async with self.api.client: + tabs = await self.course_manager.get_tool_tabs() self.assertTrue(len(tabs) > 0) for tab in tabs: self.assertTrue(isinstance(tab, ExternalToolTab)) - def test_update_tool_tab_with_position(self): - with self.api.client: - tabs = self.course_manager.get_tool_tabs() + async def test_update_tool_tab_with_position(self): + async with self.api.client: + tabs = await self.course_manager.get_tool_tabs() source_tab = CourseManager.find_tab_by_tool_id(self.source_tool_id, tabs) if source_tab is None: raise InvalidToolIdsException(f'Tool with ID {self.source_tool_id} is not available in this course') - new_tab = self.course_manager.update_tool_tab(source_tab, is_hidden=not source_tab.is_hidden, position=5) + new_tab = await self.course_manager.update_tool_tab(source_tab, is_hidden=not source_tab.is_hidden, position=5) self.assertNotEqual(new_tab.is_hidden, source_tab.is_hidden) self.assertEqual(new_tab.position, 5) - def test_manager_replace_tool_tab_skips_if_source_hidden_and_target_available(self): - with self.api.client: + async def test_manager_replace_tool_tab_skips_if_source_hidden_and_target_available(self): + async with self.api.client: # Set up - old_source_tab = self.course_manager.update_tool_tab(tab=self.source_tab, is_hidden=True) - old_target_tab = self.course_manager.update_tool_tab(tab=self.target_tab, is_hidden=False) + old_source_tab = await self.course_manager.update_tool_tab(tab=self.source_tab, is_hidden=True) + old_target_tab = await self.course_manager.update_tool_tab(tab=self.target_tab, is_hidden=False) - new_source_tab, new_target_tab = self.course_manager.replace_tool_tab( + new_source_tab, new_target_tab = await self.course_manager.replace_tool_tab( old_source_tab, old_target_tab ) self.assertEqual(old_source_tab, new_source_tab) self.assertEqual(old_target_tab, new_target_tab) - def test_manager_replace_tool_tab_skips_if_source_hidden_and_target_hidden(self): - with self.api.client: + async def test_manager_replace_tool_tab_skips_if_source_hidden_and_target_hidden(self): + async with self.api.client: # Set up - old_source_tab = self.course_manager.update_tool_tab(tab=self.source_tab, is_hidden=True) - old_target_tab = self.course_manager.update_tool_tab(tab=self.target_tab, is_hidden=True) + old_source_tab = await self.course_manager.update_tool_tab(tab=self.source_tab, is_hidden=True) + old_target_tab = await self.course_manager.update_tool_tab(tab=self.target_tab, is_hidden=True) - new_source_tab, new_target_tab = self.course_manager.replace_tool_tab( + new_source_tab, new_target_tab = await self.course_manager.replace_tool_tab( old_source_tab, old_target_tab ) self.assertEqual(old_source_tab, new_source_tab) self.assertEqual(old_target_tab, new_target_tab) - def test_manager_replace_tool_tab_fully_replaces_source_with_target(self): - with self.api.client: + async def test_manager_replace_tool_tab_fully_replaces_source_with_target(self): + async with self.api.client: # Set up - old_source_tab = self.course_manager.update_tool_tab(tab=self.source_tab, is_hidden=False, position=5) - old_target_tab = self.course_manager.update_tool_tab(tab=self.target_tab, is_hidden=True) + old_source_tab = await self.course_manager.update_tool_tab(tab=self.source_tab, is_hidden=False, position=5) + old_target_tab = await self.course_manager.update_tool_tab(tab=self.target_tab, is_hidden=True) - new_source_tab, new_target_tab = self.course_manager.replace_tool_tab( + new_source_tab, new_target_tab = await self.course_manager.replace_tool_tab( old_source_tab, old_target_tab ) @@ -332,13 +338,13 @@ def test_manager_replace_tool_tab_fully_replaces_source_with_target(self): self.assertFalse(new_target_tab.is_hidden) self.assertEqual(old_source_tab.position, new_target_tab.position) - def test_manager_replace_tool_tab_only_hides_source_if_target_available(self): - with self.api.client: + async def test_manager_replace_tool_tab_only_hides_source_if_target_available(self): + async with self.api.client: # Set up - old_source_tab = self.course_manager.update_tool_tab(tab=self.source_tab, is_hidden=False) - old_target_tab = self.course_manager.update_tool_tab(tab=self.target_tab, is_hidden=False) + old_source_tab = await self.course_manager.update_tool_tab(tab=self.source_tab, is_hidden=False) + old_target_tab = await self.course_manager.update_tool_tab(tab=self.target_tab, is_hidden=False) - new_source_tab, new_target_tab = self.course_manager.replace_tool_tab( + new_source_tab, new_target_tab = await self.course_manager.replace_tool_tab( old_source_tab, old_target_tab ) @@ -405,7 +411,7 @@ def sleep(duration: int): self.assertRegex(cm.output[0], re.compile(r'sleep took \d+\.\d+ seconds to complete\.')) -class MainTestCase(unittest.TestCase): +class MainTestCase(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: api_url: str = os.getenv('API_URL', '') @@ -417,15 +423,16 @@ def setUp(self) -> None: self.source_tool_id: int = int(os.getenv('SOURCE_TOOL_ID', '0')) self.target_tool_id: int = int(os.getenv('TARGET_TOOL_ID', '0')) - def test_find_tool_ids_for_migrations_raises_exception_when_tool_ids_are_invalid(self): - with self.api.client: + async def test_find_tool_ids_for_migrations_raises_exception_when_tool_ids_are_invalid(self): + async with self.api.client: account_manager = AccountManager(self.account_id, self.api) - tools = account_manager.get_tools_installed_in_account() + tools = await account_manager.get_tools_installed_in_account() with self.assertRaises(InvalidToolIdsException): find_tools_for_migrations(tools, [ToolMigration(100000000, 100000001)]) def test_main_migrates_tool_successfully(self): - main( + trio.run( + main, self.api, self.account_id, self.enrollment_term_ids, @@ -447,4 +454,7 @@ def test_main_migrates_tool_successfully(self): httpcore_level = logging.getLogger('httpcore') httpcore_level.setLevel(http_log_level) + asyncio_logger = logging.getLogger('asyncio') + asyncio_logger.setLevel(logging.ERROR) + unittest.main() diff --git a/migration/utils.py b/migration/utils.py index f8cab3e..6e5b31f 100644 --- a/migration/utils.py +++ b/migration/utils.py @@ -52,9 +52,9 @@ def chunk_integer(value: int, num_chunks: int) -> list[int]: def time_execution(callable: Callable) -> Callable: @functools.wraps(callable) - def wrapper(*args, **kwargs): + async def wrapper(*args, **kwargs): start = time.time() - result = callable(*args, **kwargs) + result = await callable(*args, **kwargs) end = time.time() delta = end - start logger.info(f'{callable.__qualname__} took {delta} seconds to complete.') diff --git a/requirements.txt b/requirements.txt index 3c792bb..1b628c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,5 @@ psycopg2-binary==2.9.6 python-dotenv==1.0.0 SQLAlchemy==1.4.48 tenacity==8.2.2 +tqdm==4.65.0 +trio==0.22.1