Skip to content

Commit

Permalink
Update SM Python SDK for PT 2.3.0 SM DLC
Browse files Browse the repository at this point in the history
  • Loading branch information
rohithn1 committed Jun 3, 2024
1 parent b68a810 commit 393a426
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@
"2.1.0",
"2.1.2",
"2.2.0",
"2.3.0",
],
}

Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,7 @@ def test_validate_smdataparallel_args_not_raises():
("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled),
("ml.p3.16xlarge", "pytorch", "2.1.0", "py310", smdataparallel_enabled),
("ml.p3.16xlarge", "pytorch", "2.2.0", "py310", smdataparallel_enabled),
("ml.p3.16xlarge", "pytorch", "2.3.0", "py311", smdataparallel_enabled),
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
Expand Down Expand Up @@ -991,7 +992,7 @@ def test_validate_torch_distributed_not_raises():

# Case 3: Distribution is torch_distributed enabled, supported framework and instances
torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
torch_distributed_gpu_supported_fw_versions = ["1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.2.0"]
torch_distributed_gpu_supported_fw_versions = ["1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.2.0", "2.3.0"]
for framework_version in torch_distributed_gpu_supported_fw_versions:
fw_utils.validate_torch_distributed_distribution(
instance_type="ml.p3.8xlarge",
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
can_model_package_source_uri_autopopulate,
_resolve_routing_config,
)
import sagemaker.utils
from tests.unit.sagemaker.workflow.helpers import CustomStep
from sagemaker.workflow.parameters import ParameterString, ParameterInteger

Expand Down Expand Up @@ -392,6 +393,7 @@ def test_set_nested_value():


def test_get_short_version():
assert sagemaker.utils.get_short_version("2.3.0") == "2.3"
assert sagemaker.utils.get_short_version("2.2.0") == "2.2"
assert sagemaker.utils.get_short_version("2.2") == "2.2"
assert sagemaker.utils.get_short_version("2.1.0") == "2.1"
Expand Down

0 comments on commit 393a426

Please sign in to comment.