Skip to content

Commit 047079b

Browse files
committed
feature: add emr-serverless step for SageMaker Pipelines
1 parent 462bed0 commit 047079b

File tree

3 files changed

+322
-1
lines changed

3 files changed

+322
-1
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The step definitions for EMR Serverless workflow."""
14+
from __future__ import absolute_import
15+
16+
from typing import Any, Dict, List, Union, Optional
17+
18+
from sagemaker.workflow.entities import (
19+
RequestType,
20+
)
21+
from sagemaker.workflow.properties import (
22+
Properties,
23+
)
24+
from sagemaker.workflow.retry import StepRetryPolicy
25+
from sagemaker.workflow.step_collections import StepCollection
26+
from sagemaker.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum, CacheConfig
27+
28+
29+
class EMRServerlessJobConfig:
    """Holds the job-run settings that an EMR Serverless step submits."""

    def __init__(
        self,
        job_driver: Dict,
        execution_role_arn: str,
        configuration_overrides: Optional[Dict] = None,
        execution_timeout_minutes: Optional[int] = None,
        name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ):  # pylint: disable=too-many-positional-arguments
        """Store the settings for an EMR Serverless job run.

        Args:
            job_driver (Dict): The job driver for the job run.
            execution_role_arn (str): The execution role ARN for the job run.
            configuration_overrides (Dict, optional): Configuration overrides for the job run.
            execution_timeout_minutes (int, optional): The maximum duration for the job run.
            name (str, optional): The optional job run name.
            tags (Dict[str, str], optional): The tags assigned to the job run.
        """
        # Required settings.
        self.job_driver = job_driver
        self.execution_role_arn = execution_role_arn
        # Optional settings; a value of None is left out of the request dict.
        self.configuration_overrides = configuration_overrides
        self.execution_timeout_minutes = execution_timeout_minutes
        self.name = name
        self.tags = tags

    def to_request(self, application_id: Optional[str] = None) -> RequestType:
        """Build the request dict for this job configuration.

        Args:
            application_id (str, optional): When not None, it is included in the
                result under the ``applicationId`` key.

        Returns:
            RequestType: A dict that always contains ``executionRoleArn`` and
            ``jobDriver``, plus one entry for every optional setting that is not None.
        """
        request = {"executionRoleArn": self.execution_role_arn, "jobDriver": self.job_driver}
        # Listed in the order the keys must appear in the resulting dict.
        optional_entries = (
            ("applicationId", application_id),
            ("configurationOverrides", self.configuration_overrides),
            ("executionTimeoutMinutes", self.execution_timeout_minutes),
            ("name", self.name),
            ("tags", self.tags),
        )
        request.update({key: value for key, value in optional_entries if value is not None})
        return request
72+
73+
74+
# Error raised by EMRServerlessStep when both application_id and application_config
# are supplied. `{step_name}` is filled in with str.format. application_id defaults
# to None, so the caller can either omit it or pass None.
ERR_STR_WITH_BOTH_APP_ID_AND_APP_CONFIG = (
    "EMRServerlessStep {step_name} cannot have both application_id and application_config. "
    "To use EMRServerlessStep with application_config, "
    "omit application_id or set it to None."
)

# Error raised by EMRServerlessStep when neither application_id nor
# application_config is supplied. `{step_name}` is filled in with str.format.
ERR_STR_WITHOUT_APP_ID_AND_APP_CONFIG = (
    "EMRServerlessStep {step_name} must have either application_id or application_config"
)
83+
84+
85+
class EMRServerlessStep(ConfigurableRetryStep):
    """EMR Serverless step for workflow with configurable retry policies."""

    def __init__(
        self,
        name: str,
        display_name: str,
        description: str,
        job_config: EMRServerlessJobConfig,
        application_id: Optional[str] = None,
        application_config: Optional[Dict[str, Any]] = None,
        depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
        cache_config: Optional[CacheConfig] = None,
        retry_policies: Optional[List[StepRetryPolicy]] = None,
    ):  # pylint: disable=too-many-positional-arguments
        """Constructs an `EMRServerlessStep`.

        Exactly one of `application_id` and `application_config` must be provided.

        Args:
            name (str): The name of the EMR Serverless step.
            display_name (str): The display name of the EMR Serverless step.
            description (str): The description of the EMR Serverless step.
            job_config (EMRServerlessJobConfig): Job configuration for the EMR Serverless job.
            application_id (str, optional): The ID of the existing EMR Serverless application.
            application_config (Dict[str, Any], optional): Configuration for creating a new
                EMR Serverless application.
            depends_on (List[Union[str, Step, StepCollection]], optional): A list of
                `Step`/`StepCollection` names or `Step` instances or `StepCollection` instances
                that this `EMRServerlessStep` depends on.
            cache_config (CacheConfig, optional): A `sagemaker.workflow.steps.CacheConfig` instance.
            retry_policies (List[StepRetryPolicy], optional): A list of retry policies.

        Raises:
            ValueError: If neither or both of `application_id` and `application_config`
                are provided.
        """
        super().__init__(
            name=name,
            step_type=StepTypeEnum.EMR_SERVERLESS,
            display_name=display_name,
            description=description,
            depends_on=depends_on,
            retry_policies=retry_policies,
        )

        if application_id is None and application_config is None:
            raise ValueError(ERR_STR_WITHOUT_APP_ID_AND_APP_CONFIG.format(step_name=name))

        if application_id is not None and application_config is not None:
            raise ValueError(ERR_STR_WITH_BOTH_APP_ID_AND_APP_CONFIG.format(step_name=name))

        emr_serverless_args = {
            "ExecutionRoleArn": job_config.execution_role_arn,  # Top-level role (used by backend)
            "JobConfig": job_config.to_request(
                application_id
            ),  # Role also in JobConfig (structure requirement)
        }

        # The validation above guarantees exactly one of the two is set.
        if application_id is not None:
            emr_serverless_args["ApplicationId"] = application_id
        else:
            emr_serverless_args["ApplicationConfig"] = application_config

        self.args = emr_serverless_args
        self.cache_config = cache_config

        self._properties = Properties(
            step_name=name,
            step=self,
            shape_name="GetJobRunResponse",
            service_name="emr-serverless",
        )

    @property
    def arguments(self) -> RequestType:
        """The arguments dict that is used to call EMR Serverless APIs."""
        return self.args

    @property
    def properties(self) -> Properties:
        """A Properties object representing the EMR Serverless GetJobRunResponse model."""
        return self._properties

    def to_request(self) -> RequestType:
        """Return the parent class's request dict, with the cache configuration merged in.

        The cache configuration is only merged in when `cache_config` is set.
        """
        request_dict = super().to_request()
        if self.cache_config:
            request_dict.update(self.cache_config.config)
        return request_dict

sagemaker-mlops/src/sagemaker/mlops/workflow/steps.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class StepTypeEnum(Enum):
5757
QUALITY_CHECK = "QualityCheck"
5858
CLARIFY_CHECK = "ClarifyCheck"
5959
EMR = "EMR"
60+
EMR_SERVERLESS = "EMRServerless"
6061
FAIL = "Fail"
6162
AUTOML = "AutoML"
6263

@@ -785,4 +786,4 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
785786
self.properties.TrainingJobSummaries[top_k].TrainingJobName,
786787
"output/model.tar.gz",
787788
],
788-
)
789+
)
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""Unit tests for EMR Serverless step."""
2+
3+
from __future__ import absolute_import
4+
5+
import pytest
6+
from sagemaker.workflow.emr_serverless_step import EMRServerlessStep
7+
from sagemaker.workflow.emr_serverless_step import EMRServerlessJobConfig
8+
9+
10+
class TestEMRServerlessJobConfig:
    """Test EMRServerlessJobConfig class."""

    def test_job_config_structure(self):
        """to_request emits role, driver and overrides, and nothing for unset options."""
        role_arn = "arn:aws:iam::123456789012:role/EMRServerlessRole"
        entry_point = "s3://bucket/script.py"

        job_config = EMRServerlessJobConfig(
            job_driver={"sparkSubmit": {"entryPoint": entry_point}},
            execution_role_arn=role_arn,
            configuration_overrides={
                "applicationConfiguration": [
                    {
                        "classification": "spark-defaults",
                        "properties": {"spark.sql.adaptive.enabled": "true"},
                    }
                ]
            },
        )

        # The expected dict is written out independently of the constructor arguments.
        assert job_config.to_request() == {
            "executionRoleArn": role_arn,
            "jobDriver": {"sparkSubmit": {"entryPoint": entry_point}},
            "configurationOverrides": {
                "applicationConfiguration": [
                    {
                        "classification": "spark-defaults",
                        "properties": {"spark.sql.adaptive.enabled": "true"},
                    }
                ]
            },
        }
41+
42+
43+
class TestEMRServerlessStep:
    """Test EMRServerlessStep class."""

    @staticmethod
    def _common_kwargs():
        """Return the constructor keyword arguments shared by every test in this class."""
        return {
            "name": "test-step",
            "display_name": "Test Step",
            "description": "Test Description",
            "job_config": EMRServerlessJobConfig(
                job_driver={"sparkSubmit": {"entryPoint": "s3://bucket/script.py"}},
                execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole",
            ),
        }

    def test_existing_application_step(self):
        """An application_id appears both at the top level and inside JobConfig."""
        step = EMRServerlessStep(application_id="app-123", **self._common_kwargs())

        assert step.arguments == {
            "ExecutionRoleArn": "arn:aws:iam::123456789012:role/EMRServerlessRole",
            "ApplicationId": "app-123",
            "JobConfig": {
                "applicationId": "app-123",
                "executionRoleArn": "arn:aws:iam::123456789012:role/EMRServerlessRole",
                "jobDriver": {"sparkSubmit": {"entryPoint": "s3://bucket/script.py"}},
            },
        }

    def test_new_application_step(self):
        """An application_config is passed through and JobConfig has no applicationId."""
        step = EMRServerlessStep(
            application_config={
                "name": "test-application",
                "releaseLabel": "emr-6.15.0",
                "type": "SPARK",
            },
            **self._common_kwargs(),
        )

        assert step.arguments == {
            "ExecutionRoleArn": "arn:aws:iam::123456789012:role/EMRServerlessRole",
            "ApplicationConfig": {
                "name": "test-application",
                "releaseLabel": "emr-6.15.0",
                "type": "SPARK",
            },
            "JobConfig": {
                "executionRoleArn": "arn:aws:iam::123456789012:role/EMRServerlessRole",
                "jobDriver": {"sparkSubmit": {"entryPoint": "s3://bucket/script.py"}},
            },
        }

    def test_validation_errors(self):
        """Supplying neither or both of application_id/application_config raises ValueError."""
        common = self._common_kwargs()

        # Should raise error when neither provided
        with pytest.raises(
            ValueError, match="must have either application_id or application_config"
        ):
            EMRServerlessStep(**common)

        # Should raise error when both provided
        with pytest.raises(
            ValueError, match="cannot have both application_id and application_config"
        ):
            EMRServerlessStep(
                application_id="app-123",
                application_config={"name": "test-app"},
                **common,
            )

    def test_to_request(self):
        """The request dict carries the step's name, type, display name and description."""
        step = EMRServerlessStep(application_id="app-123", **self._common_kwargs())

        request = step.to_request()
        assert request["Name"] == "test-step"
        assert request["Type"] == "EMRServerless"
        assert "Arguments" in request
        assert request["DisplayName"] == "Test Step"
        assert request["Description"] == "Test Description"

0 commit comments

Comments
 (0)