From cf9774622ab28c657ceca41456e6b240911fc20b Mon Sep 17 00:00:00 2001 From: zjgemi Date: Fri, 17 Nov 2023 14:35:24 +0800 Subject: [PATCH] fix: support retry multiple steps in a workflow Signed-off-by: zjgemi --- src/dflow/argo_objects.py | 32 ++------------------------------ src/dflow/main.py | 3 +-- src/dflow/workflow.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/dflow/argo_objects.py b/src/dflow/argo_objects.py index 2dfec1c5..e5e6c6a0 100644 --- a/src/dflow/argo_objects.py +++ b/src/dflow/argo_objects.py @@ -207,37 +207,9 @@ def modify_output_artifact( self.outputs.artifacts[name].modified = {"old_key": old_key} def retry(self): - from .workflow import Workflow, get_argo_api_client + from .workflow import Workflow wf = Workflow(id=self.workflow) - assert wf.query_status() == "Running" - logger.info("Suspend workflow %s..." % self.workflow) - wf.suspend() - time.sleep(5) - - logger.info("Query workflow %s..." % self.workflow) - wf_info = wf.query().recover() - nodes = wf_info["status"]["nodes"] - patch = {"status": {"nodes": {}}} - patch["status"]["nodes"][self.id] = {"phase": "Pending"} - for node in nodes.values(): - if node["name"] != self.name and self.name.startswith( - node["name"]) and node["phase"] == "Failed": - patch["status"]["nodes"][node["id"]] = {"phase": "Running"} - - logger.info("Delete pod of step %s..." % self.id) - self.delete_pod() - with get_argo_api_client() as api_client: - logger.info("Update workflow %s..." % self.workflow) - api_client.call_api( - '/api/v1/workflows/%s/%s' % ( - config["namespace"], self.workflow), - 'PUT', response_type='object', - header_params=config["http_headers"], - body={"patch": json.dumps(patch)}, - _return_http_data_only=True) - - logger.info("Resume workflow %s..." % self.workflow) - wf.resume() + wf.retry_steps([self.id]) def get_pod(self): assert self.type == "Pod" diff --git a/src/dflow/main.py b/src/dflow/main.py index c3911f1b..31b12469 100644 --- a/src/dflow/main.py +++ b/src/dflow/main.py @@ -446,8 +446,7 @@ def main(): wf_id = args.ID wf = Workflow(id=wf_id) if args.step is not None: - step = wf.query_step(id=args.step)[0] - step.retry() + wf.retry_steps(args.step.split(",")) else: wf.retry() elif args.command == "stop": diff --git a/src/dflow/workflow.py b/src/dflow/workflow.py index 63226f7f..0e82c3a7 100644 --- a/src/dflow/workflow.py +++ b/src/dflow/workflow.py @@ -1198,6 +1198,40 @@ def suspend(self) -> None: '/api/v1/workflows/%s/%s/suspend' % (self.namespace, self.id), 'PUT', header_params=config["http_headers"]) + def retry_steps(self, step_ids): + assert self.query_status() == "Running" + logger.info("Suspend workflow %s..." % self.id) + self.suspend() + time.sleep(5) + + logger.info("Query workflow %s..." % self.id) + wf_info = self.query().recover() + nodes = wf_info["status"]["nodes"] + patch = {"status": {"nodes": {}}} + for step_id in step_ids: + step = ArgoStep(nodes[step_id], self.id) + patch["status"]["nodes"][step_id] = {"phase": "Pending"} + for node in nodes.values(): + if node["name"] != step.name and step.name.startswith( + node["name"]) and node["phase"] == "Failed": + patch["status"]["nodes"][node["id"]] = {"phase": "Running"} + + logger.info("Delete pod of step %s..." % step_id) + step.delete_pod() + + with get_argo_api_client() as api_client: + logger.info("Update workflow %s..." % self.id) + api_client.call_api( + '/api/v1/workflows/%s/%s' % ( + config["namespace"], self.id), + 'PUT', response_type='object', + header_params=config["http_headers"], + body={"patch": json.dumps(patch)}, + _return_http_data_only=True) + + logger.info("Resume workflow %s..." % self.id) + self.resume() + def get_argo_api_client(host=None, token=None): if host is None: