Skip to content

Commit

Permalink
raise exception if model init fails and destroy model folder only on …
Browse files Browse the repository at this point in the history
…failure
  • Loading branch information
denniswittich committed Dec 12, 2024
1 parent dbda78d commit cff43b4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
5 changes: 3 additions & 2 deletions learning_loop_node/detector/detector_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ def load_model(self):
logging.info('Loading model from %s', GLOBALS.data_folder)
model_info = ModelInformation.load_from_disk(f'{GLOBALS.data_folder}/model')
if model_info is None:
logging.warning('No model found')
logging.error('No model found')
self._model_info = None
return
raise Exception('No model found')
try:
self._model_info = model_info
self.init()
logging.info('Successfully loaded model %s', self._model_info)
except Exception:
self._model_info = None
logging.error('Could not init model %s', model_info)
raise

Expand Down
23 changes: 15 additions & 8 deletions learning_loop_node/detector/detector_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def get_model_version_response(self) -> ModelVersionResponse:

async def set_model_version_mode(self, version_control_mode: str) -> None:

self.log.info('Setting model version mode to %s', version_control_mode)

if version_control_mode == 'follow_loop':
self.version_control = VersionMode.FollowLoop
elif version_control_mode == 'pause':
Expand Down Expand Up @@ -324,13 +326,13 @@ async def _check_for_update(self) -> None:
with step_into(GLOBALS.data_folder):
model_symlink = 'model'
target_model_folder = f'models/{self.target_model.version}'
shutil.rmtree(target_model_folder, ignore_errors=True)
os.makedirs(target_model_folder)

await self.data_exchanger.download_model(target_model_folder,
Context(organization=self.organization,
project=self.project),
self.target_model.id, self.detector_logic.model_format)
if not os.path.exists(target_model_folder):
os.makedirs(target_model_folder)
await self.data_exchanger.download_model(target_model_folder,
Context(organization=self.organization,
project=self.project),
self.target_model.id,
self.detector_logic.model_format)
try:
os.unlink(model_symlink)
os.remove(model_symlink)
Expand All @@ -339,7 +341,12 @@ async def _check_for_update(self) -> None:
os.symlink(target_model_folder, model_symlink)
self.log.info('Updated symlink for model to %s', os.readlink(model_symlink))

self.detector_logic.load_model()
try:
self.detector_logic.load_model()
except Exception:
self.log.exception('Could not load model, will retry download on next check')
shutil.rmtree(target_model_folder, ignore_errors=True)
return
try:
await self.sync_status_with_learning_loop()
except Exception:
Expand Down

0 comments on commit cff43b4

Please sign in to comment.