diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py
index 0584d88168..f9d54bbd73 100644
--- a/flytekit/core/python_auto_container.py
+++ b/flytekit/core/python_auto_container.py
@@ -228,10 +228,12 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain
             cpu_request=self.resources.requests.cpu,
             gpu_request=self.resources.requests.gpu,
             memory_request=self.resources.requests.mem,
+            oom_reserved_memory_request=self.resources.requests.oom_reserved_mem,
             ephemeral_storage_limit=self.resources.limits.ephemeral_storage,
             cpu_limit=self.resources.limits.cpu,
             gpu_limit=self.resources.limits.gpu,
             memory_limit=self.resources.limits.mem,
+            oom_reserved_memory_limit=self.resources.limits.oom_reserved_mem,
         )
 
     def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod:
diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py
index c911bdb161..5fd51f6cc2 100644
--- a/flytekit/core/resources.py
+++ b/flytekit/core/resources.py
@@ -40,6 +40,7 @@ class Resources(DataClassJSONMixin):
     mem: Optional[Union[str, int]] = None
     gpu: Optional[Union[str, int]] = None
     ephemeral_storage: Optional[Union[str, int]] = None
+    oom_reserved_mem: Optional[Union[str, int]] = None
 
     def __post_init__(self):
         def _check_cpu(value):
@@ -58,6 +59,7 @@ def _check_others(value):
         _check_others(self.mem)
         _check_others(self.gpu)
         _check_others(self.ephemeral_storage)
+        _check_others(self.oom_reserved_mem)
 
 
 @dataclass
@@ -85,6 +87,13 @@ def _convert_resources_to_resource_entries(resources: Resources) -> List[_Resour
                 value=str(resources.ephemeral_storage),
             )
         )
+    if resources.oom_reserved_mem is not None:
+        resource_entries.append(
+            _ResourceEntry(
+                name=_ResourceName.OOM_RESERVED_MEMORY,
+                value=str(resources.oom_reserved_mem),
+            )
+        )
     return resource_entries
 
 
@@ -154,6 +163,7 @@ def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resour
         "mem": "memory",
         "gpu": k8s_gpu_resource_key,
         "ephemeral_storage": "ephemeral-storage",
+        "oom_reserved_mem": "oom-reserved-memory",
     }
 
     k8s_pod_resources = {}
diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py
index 9f1967d2f9..0dd2587c52 100644
--- a/flytekit/core/utils.py
+++ b/flytekit/core/utils.py
@@ -69,10 +69,12 @@ def _get_container_definition(
     cpu_request: Optional[Union[str, int, float]] = None,
     gpu_request: Optional[Union[str, int]] = None,
     memory_request: Optional[Union[str, int]] = None,
+    oom_reserved_memory_request: Optional[Union[str, int]] = None,
     ephemeral_storage_limit: Optional[Union[str, int]] = None,
     cpu_limit: Optional[Union[str, int, float]] = None,
     gpu_limit: Optional[Union[str, int]] = None,
     memory_limit: Optional[Union[str, int]] = None,
+    oom_reserved_memory_limit: Optional[Union[str, int]] = None,
     environment: Optional[Dict[str, str]] = None,
 ) -> "task_models.Container":
     ephemeral_storage_limit = ephemeral_storage_limit
@@ -83,6 +85,8 @@ def _get_container_definition(
     gpu_request = gpu_request
     memory_limit = memory_limit
     memory_request = memory_request
+    oom_reserved_memory_limit = oom_reserved_memory_limit
+    oom_reserved_memory_request = oom_reserved_memory_request
 
     from flytekit.models import task as task_models
 
@@ -101,6 +105,13 @@ def _get_container_definition(
         requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.GPU, gpu_request))
     if memory_request:
         requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.MEMORY, memory_request))
+    if oom_reserved_memory_request:
+        requests.append(
+            task_models.Resources.ResourceEntry(
+                task_models.Resources.ResourceName.OOM_RESERVED_MEMORY,
+                oom_reserved_memory_request,
+            )
+        )
 
     limits = []
     if ephemeral_storage_limit:
@@ -116,6 +127,13 @@ def _get_container_definition(
         limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.GPU, gpu_limit))
     if memory_limit:
         limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.MEMORY, memory_limit))
+    if oom_reserved_memory_limit:
+        limits.append(
+            task_models.Resources.ResourceEntry(
+                task_models.Resources.ResourceName.OOM_RESERVED_MEMORY,
+                oom_reserved_memory_limit,
+            )
+        )
 
     if environment is None:
         environment = {}
diff --git a/flytekit/models/task.py b/flytekit/models/task.py
index 88430aa28a..5dc176296e 100644
--- a/flytekit/models/task.py
+++ b/flytekit/models/task.py
@@ -28,6 +28,7 @@ class ResourceName(object):
         GPU = _core_task.Resources.GPU
         MEMORY = _core_task.Resources.MEMORY
         EPHEMERAL_STORAGE = _core_task.Resources.EPHEMERAL_STORAGE
+        OOM_RESERVED_MEMORY = _core_task.Resources.OOM_RESERVED_MEMORY
 
     class ResourceEntry(_common.FlyteIdlEntity):
         def __init__(self, name, value):