Skip to content

Commit

Permalink
Use trio and async for API calls, make tab updates concurrently (#3
Browse files Browse the repository at this point in the history
…) (#12)

* Add first draft of async conversion

* Add connection limit using httpx; fix asyncio/trio issues

* Remove log messages

* Create course log message wrap function

* Add missing trio requirement

* Add tqdm progress bar for course nav updating
  • Loading branch information
ssciolla authored Jul 14, 2023
1 parent ec73308 commit 3d0f76f
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 150 deletions.
30 changes: 20 additions & 10 deletions migration/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
retry_if_exception_type,
stop_after_attempt
)
import trio


logger = logging.getLogger(__name__)

MAX_ATTEMPT_NUM = 4
MAX_ASYNC_CONNS = 20


class EndpointType(Enum):
Expand All @@ -30,7 +32,7 @@ class GetResponse:


class API:
client: httpx.Client
client: httpx.AsyncClient

def __init__(
self,
Expand All @@ -40,7 +42,13 @@ def __init__(
timeout: int = 10
):
headers = {'Authorization': f'Bearer {key}'}
self.client = httpx.Client(base_url=url + endpoint_type.value, headers=headers, timeout=timeout)
limits = httpx.Limits(max_connections=MAX_ASYNC_CONNS)
self.client = httpx.AsyncClient(
base_url=url + endpoint_type.value,
headers=headers,
timeout=timeout,
limits=limits
)

@staticmethod
def get_next_page_params(resp: httpx.Response) -> dict[str, Any] | None:
Expand All @@ -54,10 +62,11 @@ def get_next_page_params(resp: httpx.Response) -> dict[str, Any] | None:
stop=stop_after_attempt(MAX_ATTEMPT_NUM),
retry=retry_if_exception_type((httpx.HTTPError, JSONDecodeError)),
reraise=True,
before_sleep=before_sleep_log(logger, logging.WARN)
before_sleep=before_sleep_log(logger, logging.WARN),
sleep=trio.sleep
)
def get(self, url: str, params: dict[str, Any] | None = None) -> GetResponse:
resp = self.client.get(url=url, params=params)
async def get(self, url: str, params: dict[str, Any] | None = None) -> GetResponse:
resp = await self.client.get(url=url, params=params)
resp.raise_for_status()
data = resp.json()
next_page_params = self.get_next_page_params(resp)
Expand All @@ -67,14 +76,15 @@ def get(self, url: str, params: dict[str, Any] | None = None) -> GetResponse:
stop=stop_after_attempt(MAX_ATTEMPT_NUM),
retry=retry_if_exception_type((httpx.HTTPError, JSONDecodeError)),
reraise=True,
before_sleep=before_sleep_log(logger, logging.WARN)
before_sleep=before_sleep_log(logger, logging.WARN),
sleep=trio.sleep
)
def put(self, url: str, params: dict[str, Any] | None = None) -> Any:
resp = self.client.put(url=url, params=params)
async def put(self, url: str, params: dict[str, Any] | None = None) -> Any:
resp = await self.client.put(url=url, params=params)
resp.raise_for_status()
return resp.json()

def get_results_from_pages(
async def get_results_from_pages(
self, endpoint: str, params: dict[str, Any] | None = None, page_size: int = 50, limit: int | None = None
) -> list[dict[str, Any]]:
extra_params: dict[str, Any]
Expand All @@ -90,7 +100,7 @@ def get_results_from_pages(

while more_pages:
logger.debug(f'Params: {extra_params}')
get_resp = self.get(url=endpoint, params=extra_params)
get_resp = await self.get(url=endpoint, params=extra_params)
results += get_resp.data
if get_resp.next_page_params is None:
more_pages = False
Expand Down
86 changes: 55 additions & 31 deletions migration/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import os
from contextlib import nullcontext

import trio
from dotenv import load_dotenv
from tqdm import tqdm

from api import API
from data import ExternalTool, ToolMigration
from data import Course, ExternalTool, ToolMigration
from db import DB, Dialect
from exceptions import InvalidToolIdsException
from manager import AccountManagerFactory, CourseManager
Expand All @@ -15,6 +17,15 @@
logger = logging.getLogger(__name__)


class TrioProgress(trio.abc.Instrument):

def __init__(self, total, **kwargs):
self.tqdm = tqdm(total=total, **kwargs)

def task_exited(self, task):
self.tqdm.update(1)


def find_tools_for_migrations(
tools: list[ExternalTool], migrations: list[ToolMigration]
) -> list[tuple[ExternalTool, ExternalTool]]:
Expand All @@ -36,39 +47,51 @@ def find_tools_for_migrations(
return tool_pairs


async def migrate_tool_for_course(api: API, course: Course, source_tool: ExternalTool, target_tool: ExternalTool):
course_manager = CourseManager(course, api)
tabs = await course_manager.get_tool_tabs()
source_tool_tab = CourseManager.find_tab_by_tool_id(source_tool.id, tabs)
target_tool_tab = CourseManager.find_tab_by_tool_id(target_tool.id, tabs)
if source_tool_tab is None or target_tool_tab is None:
raise InvalidToolIdsException(
'One or both of the following tool IDs are not available in this course: ' +
str([source_tool.id, target_tool.id])
)
await course_manager.replace_tool_tab(source_tool_tab, target_tool_tab)


@time_execution
def main(api: API, account_id: int, term_ids: list[int], migrations: list[ToolMigration], db: DB | None = None):

async def main(api: API, account_id: int, term_ids: list[int], migrations: list[ToolMigration], db: DB | None = None):
factory = AccountManagerFactory()
account_manager = factory.get_manager(account_id, api, db)

with api.client, db if db is not None else nullcontext(): # type: ignore
tools = account_manager.get_tools_installed_in_account()
logger.info(f'Number of tools found in account {account_id}: {len(tools)}')

tool_pairs = find_tools_for_migrations(tools, migrations)

# get list of tools available in account
courses = account_manager.get_courses_in_terms(term_ids)
logger.info(f'Number of courses found in account {account_id} for terms {term_ids}: {len(courses)}')

for source_tool, target_tool in tool_pairs:
logger.info(f'Source tool: {source_tool}')
logger.info(f'Target tool: {target_tool}')

for course in courses:
# Replace target tool with source tool in course navigation
course_manager = CourseManager(course, api)
tabs = course_manager.get_tool_tabs()
source_tool_tab = CourseManager.find_tab_by_tool_id(source_tool.id, tabs)
target_tool_tab = CourseManager.find_tab_by_tool_id(target_tool.id, tabs)
if source_tool_tab is None or target_tool_tab is None:
raise InvalidToolIdsException(
'One or both of the following tool IDs are not available in this course: ' +
str([source_tool.id, target_tool.id])
)
course_manager.replace_tool_tab(source_tool_tab, target_tool_tab)

async with api.client:
with db if db is not None else nullcontext(): # type: ignore
tools = await account_manager.get_tools_installed_in_account()
logger.info(f'Number of tools found in account {account_id}: {len(tools)}')

tool_pairs = find_tools_for_migrations(tools, migrations)

# get list of tools available in account
courses = await account_manager.get_courses_in_terms(term_ids)
logger.info(f'Number of courses found in account {account_id} for terms {term_ids}: {len(courses)}')

for source_tool, target_tool in tool_pairs:
logger.info(f'Source tool: {source_tool}')
logger.info(f'Target tool: {target_tool}')

progress = TrioProgress(total=len(courses))
trio.lowlevel.add_instrument(progress)
async with trio.open_nursery() as nursery:
for course in courses:
nursery.start_soon(
migrate_tool_for_course,
api,
course,
source_tool,
target_tool
)
trio.lowlevel.remove_instrument(progress)

if __name__ == '__main__':
# get configuration (either env. variables, cli flags, or direct input)
Expand Down Expand Up @@ -122,7 +145,8 @@ def main(api: API, account_id: int, term_ids: list[int], migrations: list[ToolMi
source_tool_id: int = int(os.getenv('SOURCE_TOOL_ID', '0'))
target_tool_id: int = int(os.getenv('TARGET_TOOL_ID', '0'))

main(
trio.run(
main,
API(api_url, api_key),
account_id,
enrollment_term_ids,
Expand Down
62 changes: 37 additions & 25 deletions migration/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,34 @@ class AccountManagerBase(ABC):
account_id: int

@abstractmethod
def get_tools_installed_in_account(self) -> list[ExternalTool]:
async def get_tools_installed_in_account(self) -> list[ExternalTool]:
pass

@abstractmethod
def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]:
async def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]:
pass


@dataclass
class AccountManager(AccountManagerBase):
api: API

def get_tools_installed_in_account(self) -> list[ExternalTool]:
async def get_tools_installed_in_account(self) -> list[ExternalTool]:
params = {"include_parents": True}
results = self.api.get_results_from_pages(f'/accounts/{self.account_id}/external_tools', params)
results = await self.api.get_results_from_pages(f'/accounts/{self.account_id}/external_tools', params)
tools = [ExternalTool(id=tool_dict['id'], name=tool_dict['name']) for tool_dict in results]
return tools

@time_execution
def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]:
async def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]:
limit_chunks = None
if limit is not None:
limit_chunks = chunk_integer(limit, len(term_ids))

results: list[dict[str, Any]] = []
for i, term_id in enumerate(term_ids):
limit_for_term = limit_chunks[i] if limit_chunks is not None else None
term_results = self.api.get_results_from_pages(
term_results = await self.api.get_results_from_pages(
f'/accounts/{self.account_id}/courses',
params={ 'enrollment_term_id': term_id },
page_size=50,
Expand All @@ -73,20 +73,20 @@ class WarehouseAccountManager(AccountManagerBase):
def __post_init__(self):
self.account_manager = AccountManager(self.account_id, self.api)

def get_tools_installed_in_account(self) -> list[ExternalTool]:
return self.account_manager.get_tools_installed_in_account()
async def get_tools_installed_in_account(self) -> list[ExternalTool]:
return await self.account_manager.get_tools_installed_in_account()

def get_subaccount_ids(self) -> list[int]:
results = self.api.get_results_from_pages(
async def get_subaccount_ids(self) -> list[int]:
results = await self.api.get_results_from_pages(
f'/accounts/{self.account_id}/sub_accounts', { 'recursive': True }
)
sub_account_ids = [result['id'] for result in results]
logger.debug(sub_account_ids)
return sub_account_ids

@time_execution
def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]:
account_ids = [self.account_id] + self.get_subaccount_ids()
async def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]:
account_ids = [self.account_id] + await self.get_subaccount_ids()

conn = self.db.get_connection()
statement = sqlalchemy.text(f'''
Expand Down Expand Up @@ -148,8 +148,11 @@ def convert_data_to_tool_tab(cls, data: dict[str, Any]) -> ExternalToolTab:
position=data['position']
)

def get_tool_tabs(self) -> list[ExternalToolTab]:
results = self.api.get_results_from_pages(f'/courses/{self.course.id}/tabs')
def create_course_log_message(self, message: str) -> str:
return f'{self.course}\n{message}\n- - -'

async def get_tool_tabs(self) -> list[ExternalToolTab]:
results = await self.api.get_results_from_pages(f'/courses/{self.course.id}/tabs')

tabs: list[ExternalToolTab] = []
for result in results:
Expand All @@ -158,43 +161,52 @@ def get_tool_tabs(self) -> list[ExternalToolTab]:
tabs.append(CourseManager.convert_data_to_tool_tab(result))
return tabs

def update_tool_tab(self, tab: ExternalToolTab, is_hidden: bool, position: int | None = None):
async def update_tool_tab(self, tab: ExternalToolTab, is_hidden: bool, position: int | None = None):
params: dict[str, Any] = { "hidden": is_hidden }
if position is not None:
params.update({ "position": position })

result = self.api.put(
result = await self.api.put(
f'/courses/{self.course.id}/tabs/{tab.id}',
params=params
)
logger.debug(result)
return CourseManager.convert_data_to_tool_tab(result)

def replace_tool_tab(
async def replace_tool_tab(
self, source_tab: ExternalToolTab, target_tab: ExternalToolTab
) -> tuple[ExternalToolTab, ExternalToolTab]:
logger.debug([source_tab, target_tab])

# Source tool is hidden in course, don't do anything
if source_tab.is_hidden:
logger.debug(f'Skipping replacement for {[source_tab, target_tab]}; source tool is hidden.')
logger.debug(self.create_course_log_message(
f'Skipping replacement for {[source_tab, target_tab]}; source tool is hidden.'
))
return (source_tab, target_tab)
else:
if not target_tab.is_hidden:
logger.warning(
logger.warning(self.create_course_log_message(
f'Both tools ({[source_tab, target_tab]}) are currently available. ' +
'Rolling back will hide the target tool!'
)
logger.debug((f'Skipping update for {target_tab}; tool is already available.'))
))
logger.debug(self.create_course_log_message(
f'Skipping update for {target_tab}; tool is already available.'
))
new_target_tab = target_tab
else:
target_position = source_tab.position
new_target_tab = self.update_tool_tab(tab=target_tab, is_hidden=False, position=target_position)
logger.info(f"Made available target tool in course's navigation: {new_target_tab}")
new_target_tab = await self.update_tool_tab(tab=target_tab, is_hidden=False, position=target_position)
logger.info(self.create_course_log_message(
f"Made available target tool in course's navigation: {new_target_tab}"
))

# Always hide the source tool if it's available
new_source_tab = self.update_tool_tab(tab=source_tab, is_hidden=True)
logger.info(f"Hid source tool in course's navigation: {new_source_tab}")
new_source_tab = await self.update_tool_tab(tab=source_tab, is_hidden=True)
logger.info(self.create_course_log_message(
f"Hid source tool in course's navigation: {new_source_tab}"
))

return (new_source_tab, new_target_tab)


Expand Down
Loading

0 comments on commit 3d0f76f

Please sign in to comment.