diff --git a/03_finetune_swift.ipynb b/03_finetune_swift.ipynb
index da0bbb0..d0bdafb 100644
--- a/03_finetune_swift.ipynb
+++ b/03_finetune_swift.ipynb
@@ -63,7 +63,7 @@
" get_aws_account_id_for_region,\n",
" is_docker_installed,\n",
" is_docker_compose_installed,\n",
- " check_and_enable_docker_access_sagemaker_studio\n",
+ " enable_docker_access_sagemaker_studio\n",
")"
]
},
@@ -340,7 +340,7 @@
"metadata": {},
"outputs": [],
"source": [
- "use_local_mode = True # Set to true to run on local instance\n",
+ "use_local_mode = False # Set to true to run on local instance\n",
"instance_type = \"local_gpu\" if use_local_mode else \"ml.g6.8xlarge\" # \"ml.g6.12xlarge\" "
]
},
@@ -354,7 +354,9 @@
" border-radius: 5px; \n",
" max-width: 100%;\n",
" background: #f0fbff;\">\n",
- " Note: If you run into out of memory errors during training use a larger instance type. For example if you are training a larger model or with more data. If you are using local mode then you will need a GPU on your local machine, for example running inside a SageMaker Studio JuypterLab on a ml.g6.8xlarge instance.\n",
+ " Note: If you run into out of memory errors during training use a larger instance type. For example if you are training a larger model or with more data. \n",
+ "
\n",
+ "If you are using local mode then you will need a GPU on your local machine, for example running inside a SageMaker Studio JuypterLab on a ml.g6.8xlarge instance.\n",
""
]
},
@@ -373,7 +375,7 @@
"metadata": {},
"outputs": [],
"source": [
- "use_spot = True"
+ "use_spot = True # Set to False to use on-demand instances instead of spot instances."
]
},
{
@@ -423,7 +425,7 @@
"metadata": {},
"outputs": [],
"source": [
- "checkpoint_loc = None"
+ "checkpoint_loc = None # Default: None - No checkpointing. Re-run cell above \"Note box\" to set checkpoint_loc "
]
},
{
@@ -724,24 +726,26 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2d40f9f3-ee0d-4ba0-8472-f6bfde86503c",
+ "id": "39563e42-1a02-4c14-af64-68101a055633",
"metadata": {},
"outputs": [],
"source": [
- "if use_local_mode and (not is_docker_installed() or not is_docker_compose_installed()): \n",
- " # we need docker and docker-compose for LocalMode execution\n",
- " !bash docker-artifacts/01_docker_install.sh"
+ "if use_local_mode:\n",
+ " # for local mode we need docker enabled\n",
+ " # if in SageMaker Studio the following function ensures that docker access is enabled\n",
+ " enable_docker_access_sagemaker_studio(session)"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "39563e42-1a02-4c14-af64-68101a055633",
+ "id": "2d40f9f3-ee0d-4ba0-8472-f6bfde86503c",
"metadata": {},
"outputs": [],
"source": [
- "# if in SageMaker Studio check if docker access is enabled if it is not enable it\n",
- "check_and_enable_docker_access_sagemaker_studio(use_local_mode, session)"
+ "if use_local_mode and (not is_docker_installed() or not is_docker_compose_installed()): \n",
+ " # we need docker and docker-compose for LocalMode execution\n",
+ " !bash docker-artifacts/01_docker_install.sh"
]
},
{
diff --git a/04_run_batch_inference.ipynb b/04_run_batch_inference.ipynb
index 00e5ddf..c34def0 100644
--- a/04_run_batch_inference.ipynb
+++ b/04_run_batch_inference.ipynb
@@ -65,10 +65,7 @@
" get_sagemaker_distribution, \n",
" SageMakerDistribution, \n",
" get_python_version, \n",
- " get_aws_account_id_for_region,\n",
- " is_docker_installed,\n",
- " is_docker_compose_installed,\n",
- " check_and_enable_docker_access_sagemaker_studio\n",
+ " get_aws_account_id_for_region\n",
")"
]
},
@@ -389,6 +386,26 @@
"sagemaker_dist_uri = f\"{sagemaker_distr_account_id}.dkr.ecr.{region}.amazonaws.com/sagemaker-distribution-prod:{sm_distro_version.image_version}-gpu\""
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "04305c64-7ae5-4ffc-be2f-fe0e6e1a398b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "instance_type = \"ml.g6e.xlarge\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d0bb0ade-103c-4746-bd48-524a0fe54de5",
+ "metadata": {},
+ "source": [
+ "
\n",
+ " Note: If you run into out of memory errors during inference use a larger instance type. For example if you have trained a larger model. \n",
+ "
"
+ ]
+ },
{
"cell_type": "markdown",
"id": "202dfb1b-322c-470d-9191-29490a7ffee6",
@@ -521,7 +542,7 @@
" # RoleArn: \n",
" S3RootUri: {s3_root_uri}\n",
" ImageUri: {sagemaker_dist_uri} \n",
- " InstanceType: ml.g6e.4xlarge\n",
+ " InstanceType: ml.g6e.4xlarge # default instance type to use\n",
" Dependencies: requirements.txt\n",
" IncludeLocalWorkDir: true\n",
" PreExecutionCommands:\n",
@@ -610,6 +631,22 @@
"## Batch Inference Function"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e9050ae0-98eb-48c9-8d76-320a5b4ee86d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "inference_kwargs = {\n",
+ " \"model_id\":model_config.model_id,\n",
+ " \"model_type\":model_config.model_type,\n",
+ " \"dataset_s3\":dataset_s3_uri,\n",
+ " \"test_data_path\":\"conversations_test_swift_format.json\",\n",
+ " \"guided_decoding\":guided_decoding\n",
+ "}"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -618,11 +655,11 @@
"outputs": [],
"source": [
"@remote(\n",
- " instance_type=\"ml.g6e.xlarge\", # Powerful GPU for fast inference\n",
+ " instance_type=instance_type, # Powerful GPU for fast inference\n",
" instance_count=1, # Single instance for cost efficiency\n",
" volume_size=200, # Large volume for model and data storage\n",
" job_name_prefix=job_name_prefix,\n",
- " # use_spot_instances=True, # Cost efficient inference. Inference can be restarted if no spot capacity. \n",
+ " # use_spot_instances=True, # Enable cost efficient inference. Inference can be restarted if no spot capacity. \n",
" max_wait_time_in_seconds=172800, # 48 hours max wait\n",
" max_runtime_in_seconds=172800, # 48 hours max runtime\n",
")\n",
@@ -710,28 +747,12 @@
},
{
"cell_type": "markdown",
- "id": "6eb0b491-7f9f-4b83-98d1-daedabe3f311",
+ "id": "259ae020-c35d-4470-95ce-0043a1b61893",
"metadata": {},
"source": [
"## Run Batch Inference"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e9050ae0-98eb-48c9-8d76-320a5b4ee86d",
- "metadata": {},
- "outputs": [],
- "source": [
- "inference_kwargs = {\n",
- " \"model_id\":model_config.model_id,\n",
- " \"model_type\":model_config.model_type,\n",
- " \"dataset_s3\":dataset_s3_uri,\n",
- " \"test_data_path\":\"conversations_test_swift_format.json\",\n",
- " \"guided_decoding\":guided_decoding\n",
- "}"
- ]
- },
{
"cell_type": "code",
"execution_count": null,
diff --git a/06_deploy_model_endpoint.ipynb b/06_deploy_model_endpoint.ipynb
index 55884b4..8dfe450 100644
--- a/06_deploy_model_endpoint.ipynb
+++ b/06_deploy_model_endpoint.ipynb
@@ -241,7 +241,7 @@
"metadata": {},
"outputs": [],
"source": [
- "REPO_NAME = \"swift-json-vlm-container-finetuned\"\n",
+ "REPO_NAME = \"sagemaker-json-vlm-finetuned\"\n",
"os.environ['REPO_NAME'] = REPO_NAME\n",
"os.environ[\"S3_MODEL_URI\"]=s3_model_uri"
]
diff --git a/utils/training_image.py b/utils/training_image.py
index 1370622..38c99ee 100644
--- a/utils/training_image.py
+++ b/utils/training_image.py
@@ -129,33 +129,33 @@ def load_json_file(file_path):
DOCKER_ENABLED = "ENABLED"
DOCKER_DISABLED = "DISABLED"
-def check_and_enable_docker_access_sagemaker_studio(use_local_mode, session):
- if use_local_mode:
- resource_metadata = load_json_file(SAGEMAKER_STUDIO_METADATA)
- if resource_metadata:
- docker_access_disabled = True
- docker_access = None
- try:
- domain_id = resource_metadata["DomainId"]
- sm_client = session.boto_session.client("sagemaker")
- domain = sm_client.describe_domain(DomainId=domain_id)
- if docker_access_disabled := ((docker_access := domain["DomainSettings"]["DockerSettings"]["EnableDockerAccess"]) == DOCKER_DISABLED):
- print("Docker disabled on SageMaker Studio domain. Trying to enable docker access...")
- sm_client.update_domain(
- DomainId=domain_id,
- DomainSettingsForUpdate={
- 'DockerSettings': {
- 'EnableDockerAccess': DOCKER_ENABLED
- }
+def enable_docker_access_sagemaker_studio(session):
+ resource_metadata = load_json_file(SAGEMAKER_STUDIO_METADATA)
+ if resource_metadata:
+        docker_access_disabled = True
+        docker_access = None
+        domain_id = resource_metadata.get("DomainId")
+        try:
+ sm_client = session.boto_session.client("sagemaker")
+ domain = sm_client.describe_domain(DomainId=domain_id)
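+            # if the domain's current EnableDockerAccess setting is DISABLED, attempt to enable it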
+ if docker_access_disabled := ((docker_access := domain["DomainSettings"]["DockerSettings"]["EnableDockerAccess"]) == DOCKER_DISABLED):
+ print("Docker disabled on SageMaker Studio domain. Trying to enable docker access...")
+ sm_client.update_domain(
+ DomainId=domain_id,
+ DomainSettingsForUpdate={
+ 'DockerSettings': {
+ 'EnableDockerAccess': DOCKER_ENABLED
}
- )
- time.sleep(4)
- domain = sm_client.describe_domain(DomainId=domain_id)
- docker_access = domain["DomainSettings"]["DockerSettings"]["EnableDockerAccess"]
- docker_access_disabled = (docker_access == DOCKER_DISABLED)
- except Exception as e:
- print(e)
-
- print(f"SageMaker Studio domain ({domain_id}) docker access: {docker_access}")
- if docker_access_disabled:
- print("Failed to enable Docker Access on SageMaker Studio domain. Please enable it manually or ask your administrator. Docker access is required to run in local mode. https://docs.aws.amazon.com/sagemaker/latest/dg/studio-updated-local-get-started.html#studio-updated-local-enable")
\ No newline at end of file
+ }
+ )
+ time.sleep(4)
+ domain = sm_client.describe_domain(DomainId=domain_id)
+ docker_access = domain["DomainSettings"]["DockerSettings"]["EnableDockerAccess"]
+ docker_access_disabled = (docker_access == DOCKER_DISABLED)
+ except Exception as e:
+ print(e)
+
+ print(f"SageMaker Studio domain ({domain_id}) docker access: {docker_access}")
+ if docker_access_disabled:
+ print("Failed to enable Docker Access on SageMaker Studio domain. Please enable it manually or ask your administrator. Docker access is required to run in local mode. https://docs.aws.amazon.com/sagemaker/latest/dg/studio-updated-local-get-started.html#studio-updated-local-enable")
\ No newline at end of file