Skip to content

Commit

Permalink
Merge pull request #707 from deepmodeling/zjgemi
Browse files Browse the repository at this point in the history
fix: support retrying multiple steps in a workflow
  • Loading branch information
zjgemi authored Nov 17, 2023
2 parents 6b20a2d + cf97746 commit f812ab8
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 32 deletions.
32 changes: 2 additions & 30 deletions src/dflow/argo_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,37 +207,9 @@ def modify_output_artifact(
self.outputs.artifacts[name].modified = {"old_key": old_key}

def retry(self):
    """Retry this failed step within its running workflow.

    Delegates to ``Workflow.retry_steps`` so that the whole
    suspend / patch-node-statuses / delete-pod / resume cycle is
    implemented in one place and shared with multi-step retries.
    """
    # Imported inside the function (not at module top) — presumably to
    # avoid a circular import between argo_objects and workflow; verify
    # before hoisting.
    from .workflow import Workflow
    wf = Workflow(id=self.workflow)
    wf.retry_steps([self.id])

def get_pod(self):
assert self.type == "Pod"
Expand Down
3 changes: 1 addition & 2 deletions src/dflow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,7 @@ def main():
wf_id = args.ID
wf = Workflow(id=wf_id)
if args.step is not None:
step = wf.query_step(id=args.step)[0]
step.retry()
wf.retry_steps(args.step.split(","))
else:
wf.retry()
elif args.command == "stop":
Expand Down
34 changes: 34 additions & 0 deletions src/dflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,40 @@ def suspend(self) -> None:
'/api/v1/workflows/%s/%s/suspend' % (self.namespace, self.id),
'PUT', header_params=config["http_headers"])

def retry_steps(self, step_ids):
    """Retry the given failed steps of this running workflow.

    The workflow is suspended, its node statuses are patched so that
    each listed step becomes ``Pending`` again (and each failed
    enclosing node becomes ``Running``), the steps' pods are deleted,
    and the workflow is resumed.

    Parameters
    ----------
    step_ids : list of str
        IDs of the workflow nodes (steps) to retry.

    Raises
    ------
    AssertionError
        If the workflow is not currently in the ``Running`` phase.
    """
    # NOTE(review): assert is stripped under ``python -O``; kept here so
    # callers still see AssertionError on a non-running workflow, as
    # before.
    assert self.query_status() == "Running"
    logger.info("Suspend workflow %s..." % self.id)
    self.suspend()
    # Give the controller time to observe the suspension before we
    # patch node statuses underneath it.
    time.sleep(5)

    logger.info("Query workflow %s..." % self.id)
    wf_info = self.query().recover()
    nodes = wf_info["status"]["nodes"]
    patch = {"status": {"nodes": {}}}
    for step_id in step_ids:
        step = ArgoStep(nodes[step_id], self.id)
        # Reset the step itself so it gets re-scheduled.
        patch["status"]["nodes"][step_id] = {"phase": "Pending"}
        # Nodes whose name is a strict prefix of the step's name are its
        # enclosing groups; failed ones are flipped back to Running —
        # presumably so the controller re-enters them (confirm against
        # Argo's node-naming scheme).
        for node in nodes.values():
            if node["name"] != step.name and step.name.startswith(
                    node["name"]) and node["phase"] == "Failed":
                patch["status"]["nodes"][node["id"]] = {"phase": "Running"}

        logger.info("Delete pod of step %s..." % step_id)
        step.delete_pod()

    with get_argo_api_client() as api_client:
        logger.info("Update workflow %s..." % self.id)
        api_client.call_api(
            # Use self.namespace — as the suspend/resume endpoints do —
            # rather than config["namespace"], so workflows living
            # outside the globally configured namespace are patched
            # correctly.
            '/api/v1/workflows/%s/%s' % (self.namespace, self.id),
            'PUT', response_type='object',
            header_params=config["http_headers"],
            body={"patch": json.dumps(patch)},
            _return_http_data_only=True)

    logger.info("Resume workflow %s..." % self.id)
    self.resume()


def get_argo_api_client(host=None, token=None):
if host is None:
Expand Down

0 comments on commit f812ab8

Please sign in to comment.