Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam-D-Lewis committed May 9, 2023
1 parent 1adae15 commit 83c5bbf
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 13 deletions.
15 changes: 13 additions & 2 deletions nebari_workflow_controller/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
get_container_keep_portions,
get_keycloak_user,
get_spec_keep_portions,
get_spec_to_container_portions,
get_user_pod_spec,
mutate_template,
process_unhandled_exception,
Expand Down Expand Up @@ -119,7 +120,7 @@ def validate(request=Body(...)):

@app.post("/mutate")
def mutate(request=Body(...)):
logger.warn(f"Received request: \n\n{request}")
logger.debug(f"Received request: \n\n{request}")
return_response = partial(
base_return_response,
apiVersion=request["apiVersion"],
Expand All @@ -143,6 +144,9 @@ def mutate(request=Body(...)):

container_keep_portions = get_container_keep_portions(user_pod_spec, api)
spec_keep_portions = get_spec_keep_portions(user_pod_spec, api)
spec_to_container_portions = get_spec_to_container_portions(
user_pod_spec, api
)

if spec["kind"] == "Workflow":
templates = modified_spec["spec"]["templates"]
Expand All @@ -152,7 +156,12 @@ def mutate(request=Body(...)):
raise Exception("Only expecting Workflow or CronWorkflow")

for template in templates:
mutate_template(container_keep_portions, spec_keep_portions, template)
mutate_template(
container_keep_portions,
spec_keep_portions,
template,
spec_to_container_portions,
)

patch = jsonpatch.JsonPatch.from_diff(spec, modified_spec)
return return_response(
Expand All @@ -162,5 +171,7 @@ def mutate(request=Body(...)):
)
else:
return return_response(True)
except NWFCUnsupportedException as e:
return return_response(False, message=str(e))
except Exception as e:
return process_unhandled_exception(e, return_response, logger)
57 changes: 51 additions & 6 deletions nebari_workflow_controller/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,15 @@ def get_user_pod_spec(keycloak_user):
return jupyter_pod_spec


def get_spec_to_container_portions(user_pod_spec, api):
return [
(
api.sanitize_for_serialization(user_pod_spec.spec.node_selector),
"nodeSelector",
)
]


def get_spec_keep_portions(user_pod_spec, api):
return [
(
Expand Down Expand Up @@ -211,6 +220,25 @@ def get_spec_keep_portions(user_pod_spec, api):
]


def recursive_dict_merge(greater_dict, lesser_dict, path=None):
"merges lesser_dict into greater_dict and assigns to greater_dict, greater_dict value takes precedence in case of conflict between greater and lesser dict"
if path is None:
path = []
for key in lesser_dict:
if key in greater_dict:
if isinstance(greater_dict[key], dict) and isinstance(
lesser_dict[key], dict
):
recursive_dict_merge(
greater_dict[key], lesser_dict[key], path + [str(key)]
)
else:
pass
else:
greater_dict[key] = lesser_dict[key]
return greater_dict


def get_container_keep_portions(user_pod_spec, api):
return [
(user_pod_spec.spec.containers[0].image, "image"),
Expand Down Expand Up @@ -246,17 +274,19 @@ def get_container_keep_portions(user_pod_spec, api):
]


def mutate_template(container_keep_portions, spec_keep_portions, template):
def mutate_template(
container_keep_portions,
spec_keep_portions,
template,
spec_to_container_portions=None,
):
for value, key in container_keep_portions:
if "container" not in template:
continue

if isinstance(value, dict):
if key in template["container"]:
template["container"][key] = {
**template["container"][key],
**value,
}
recursive_dict_merge(template["container"][key], value)
else:
template["container"][key] = value
elif isinstance(value, list):
Expand All @@ -270,7 +300,7 @@ def mutate_template(container_keep_portions, spec_keep_portions, template):
for value, key in spec_keep_portions:
if isinstance(value, dict):
if key in template:
template[key] = {**template[key], **value}
recursive_dict_merge(template[key], value)
else:
template[key] = value
elif isinstance(value, list):
Expand All @@ -280,3 +310,18 @@ def mutate_template(container_keep_portions, spec_keep_portions, template):
template[key] = value
else:
template[key] = value

if spec_to_container_portions:
for value, key in spec_to_container_portions:
if isinstance(value, dict):
if key in template["container"]:
recursive_dict_merge(template["container"][key], value)
else:
template["container"][key] = value
elif isinstance(value, list):
if key in template["container"]:
template["container"][key].append(value)
else:
template["container"][key] = value
else:
template["container"][key] = value
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from nebari_workflow_controller.models import KeycloakGroup, KeycloakUser

os.environ["NAMESPACE"] = "default"
os.environ["NAMESPACE"] = "dev"


def _valid_request_paths():
Expand Down
16 changes: 14 additions & 2 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_validate(request_file, allowed, mocked_get_keycloak_user):
with open(request_file) as f:
request = yaml.load(f, Loader=yaml.FullLoader)
response = validate(request)
print(response)

assert response["response"]["allowed"] == allowed
if not allowed:
assert response["response"]["status"]["message"]
Expand All @@ -45,7 +45,9 @@ def test_mutate_template_doesnt_error(request_templates, jupyterlab_pod_spec):
@pytest.mark.parametrize(
"request_file", ["tests/test_data/requests/valid/jupyterflow-override-example.yaml"]
)
def test_mutate2(request_file, mocked_get_keycloak_user_info, mocked_get_user_pod_spec):
def test_mutate_check_content(
request_file, mocked_get_keycloak_user, mocked_get_user_pod_spec
):
with open(request_file) as f:
request = yaml.load(f, Loader=yaml.FullLoader)
response = mutate(request)
Expand Down Expand Up @@ -77,3 +79,13 @@ def test_mutate2(request_file, mocked_get_keycloak_user_info, mocked_get_user_po
},
]:
assert volume in mutated_spec["spec"]["templates"][0]["volumes"]

assert mutated_spec["spec"]["templates"][0]["container"]["nodeSelector"] == {
"mylabel": "myValue",
"cloud.google.com/gke-nodepool": "user",
}

assert mutated_spec["spec"]["templates"][0]["container"]["resources"] == {
"requests": {"cpu": "3000m", "memory": "5368709120"},
"limits": {"cpu": "2", "memory": "8589934592"},
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ request:
- '-c'
args:
- conda run -n nebari-git-dask python script.py
resources: {}
nodeSpec:
resources:
requests:
cpu: '3000m'
nodeSelector:
mylabel: myValue
entrypoint: argosay
uid: c1bba5c6-2189-41ff-9487-be504c04487b
Expand Down

0 comments on commit 83c5bbf

Please sign in to comment.