# Source code for braket.jobs.local.local_job_container_setup

# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json
import tempfile
from collections.abc import Iterable
from logging import Logger, getLogger
from pathlib import Path
from typing import Any

from braket.aws.aws_session import AwsSession
from braket.jobs.local.local_job_container import _LocalJobContainer


def setup_container(
    container: "_LocalJobContainer", aws_session: "AwsSession", **creation_kwargs: str
) -> dict[str, str]:
    """Sets up a container with prerequisites for running a Braket Hybrid Job.

    The prerequisites are based on the options the customer has chosen for the hybrid job.
    Similarly, any environment variables that are needed during runtime will be returned by
    this function.

    Args:
        container (_LocalJobContainer): The container that will run the braket hybrid job.
        aws_session (AwsSession): AwsSession for connecting to AWS Services.
        **creation_kwargs (str): Arbitrary keyword arguments.

    Returns:
        dict[str, str]: A dictionary of environment variables that reflect Braket Hybrid Jobs
        options requested by the customer.
    """
    logger = getLogger(__name__)
    _create_expected_paths(container, **creation_kwargs)
    run_environment_variables = {}
    run_environment_variables.update(_get_env_credentials(aws_session, logger))
    run_environment_variables.update(
        _get_env_script_mode_config(creation_kwargs["algorithmSpecification"]["scriptModeConfig"])
    )
    run_environment_variables.update(_get_env_default_vars(aws_session, **creation_kwargs))
    # Hyperparameters and input data are optional; the matching env variables are only
    # added when the customer actually supplied them (the copy helpers report this).
    if _copy_hyperparameters(container, **creation_kwargs):
        run_environment_variables.update(_get_env_hyperparameters())
    if _copy_input_data_list(container, aws_session, **creation_kwargs):
        run_environment_variables.update(_get_env_input_data())
    return run_environment_variables
def _create_expected_paths(container: _LocalJobContainer, **creation_kwargs: str) -> None: """Creates the basic paths required for Braket Hybrid Jobs to run. Args: container(_LocalJobContainer): The container that will run the braket hybrid job. **creation_kwargs (str): Arbitrary keyword arguments. """ container.makedir("/opt/ml/model") container.makedir(creation_kwargs["checkpointConfig"]["localPath"]) def _get_env_credentials(aws_session: AwsSession, logger: Logger) -> dict[str, str]: """Gets the account credentials from boto so they can be added as environment variables to the running container. Args: aws_session (AwsSession): AwsSession for connecting to AWS Services. logger (Logger): Logger object with which to write logs. Default is `getLogger(__name__)` Returns: dict[str, str]: The set of key/value pairs that should be added as environment variables to the running container. """ credentials = aws_session.boto_session.get_credentials() if credentials.token is None: logger.info("Using the long-lived AWS credentials found in session") return { "AWS_ACCESS_KEY_ID": str(credentials.access_key), "AWS_SECRET_ACCESS_KEY": str(credentials.secret_key), } logger.warning( "Using the short-lived AWS credentials found in session. They might expire while running." ) return { "AWS_ACCESS_KEY_ID": str(credentials.access_key), "AWS_SECRET_ACCESS_KEY": str(credentials.secret_key), "AWS_SESSION_TOKEN": str(credentials.token), } def _get_env_script_mode_config(script_mode_config: dict[str, str]) -> dict[str, str]: """Gets the environment variables related to the customer script mode config. Args: script_mode_config (dict[str, str]): The values for scriptModeConfig in the boto3 input parameters for running a Braket Hybrid Job. Returns: dict[str, str]: The set of key/value pairs that should be added as environment variables to the running container. 
""" result = { "AMZN_BRAKET_SCRIPT_S3_URI": script_mode_config["s3Uri"], "AMZN_BRAKET_SCRIPT_ENTRY_POINT": script_mode_config["entryPoint"], } if "compressionType" in script_mode_config: result["AMZN_BRAKET_SCRIPT_COMPRESSION_TYPE"] = script_mode_config["compressionType"] return result def _get_env_default_vars(aws_session: AwsSession, **creation_kwargs: str) -> dict[str, str]: """This function gets the remaining 'simple' env variables, that don't require any additional logic to determine what they are or when they should be added as env variables. Args: aws_session (AwsSession): AwsSession for connecting to AWS Services. **creation_kwargs (str): Arbitrary keyword arguments. Returns: dict[str, str]: The set of key/value pairs that should be added as environment variables to the running container. """ job_name = creation_kwargs["jobName"] bucket, location = AwsSession.parse_s3_uri(creation_kwargs["outputDataConfig"]["s3Path"]) return { "AWS_DEFAULT_REGION": aws_session.region, "AMZN_BRAKET_JOB_NAME": job_name, "AMZN_BRAKET_DEVICE_ARN": creation_kwargs["deviceConfig"]["device"], "AMZN_BRAKET_JOB_RESULTS_DIR": "/opt/braket/model", "AMZN_BRAKET_CHECKPOINT_DIR": creation_kwargs["checkpointConfig"]["localPath"], "AMZN_BRAKET_OUT_S3_BUCKET": bucket, "AMZN_BRAKET_TASK_RESULTS_S3_URI": f"s3://{bucket}/jobs/{job_name}/tasks", "AMZN_BRAKET_JOB_RESULTS_S3_PATH": str(Path(location, job_name, "output").as_posix()), } def _get_env_hyperparameters() -> dict[str, str]: """Gets the env variable for hyperparameters. This should only be added if the customer has provided hyperpameters to the hybrid job. Returns: dict[str, str]: The set of key/value pairs that should be added as environment variables to the running container. """ return { "AMZN_BRAKET_HP_FILE": "/opt/braket/input/config/hyperparameters.json", } def _get_env_input_data() -> dict[str, str]: """Gets the env variable for input data. This should only be added if the customer has provided input data to the hybrid job. 
Returns: dict[str, str]: The set of key/value pairs that should be added as environment variables to the running container. """ return { "AMZN_BRAKET_INPUT_DIR": "/opt/braket/input/data", } def _copy_hyperparameters(container: _LocalJobContainer, **creation_kwargs: str) -> bool: """If hyperpameters are present, this function will store them as a JSON object in the container in the appropriate location on disk. Args: container(_LocalJobContainer): The container to save hyperparameters to. **creation_kwargs (str): Arbitrary keyword arguments. Returns: bool: True if any hyperparameters were copied to the container. """ if "hyperParameters" not in creation_kwargs: return False hyperparameters = creation_kwargs["hyperParameters"] with tempfile.TemporaryDirectory() as temp_dir: file_path = Path(temp_dir, "hyperparameters.json") with open(file_path, "w") as write_file: json.dump(hyperparameters, write_file) container.copy_to(str(file_path), "/opt/ml/input/config/hyperparameters.json") return True def _download_input_data( aws_session: AwsSession, download_dir: str, input_data: dict[str, Any], ) -> None: """Downloads input data for a hybrid job. Args: aws_session (AwsSession): AwsSession for connecting to AWS Services. download_dir (str): The directory path to download to. input_data (dict[str, Any]): One of the input data in the boto3 input parameters for running a Braket Hybrid Job. Raises: ValueError: File already exists. RuntimeError: The item is not found. """ # If s3 prefix is the full name of a directory and all keys are inside # that directory, the contents of said directory will be copied into a # directory with the same name as the channel. This behavior is the same # whether or not s3 prefix ends with a "/". Moreover, if s3 prefix ends # with a "/", this is certainly the behavior to expect, since it can only # match a directory. 
# If s3 prefix matches any files exactly, or matches as a prefix of any # files or directories, then all files and directories matching s3 prefix # will be copied into a directory with the same name as the channel. channel_name = input_data["channelName"] s3_uri_prefix = input_data["dataSource"]["s3DataSource"]["s3Uri"] bucket, prefix = AwsSession.parse_s3_uri(s3_uri_prefix) s3_keys = aws_session.list_keys(bucket, prefix) top_level = prefix if _is_dir(prefix, s3_keys) else str(Path(prefix).parent) found_item = False try: Path(download_dir, channel_name).mkdir() except FileExistsError as e: raise ValueError( f"Duplicate channel names not allowed for input data: {channel_name}" ) from e for s3_key in s3_keys: relative_key = Path(s3_key).relative_to(top_level) download_path = Path(download_dir, channel_name, relative_key) if not s3_key.endswith("/"): download_path.parent.mkdir(parents=True, exist_ok=True) aws_session.download_from_s3( AwsSession.construct_s3_uri(bucket, s3_key), str(download_path) ) found_item = True if not found_item: raise RuntimeError(f"No data found for channel '{channel_name}'") def _is_dir(prefix: str, keys: Iterable[str]) -> bool: """Determine whether the prefix refers to a directory. Args: prefix (str): The prefix to check. keys (Iterable[str]): The set of paths to check. Returns: bool: True if the prefix refers to a directory. """ if prefix.endswith("/"): return True return all(key.startswith(f"{prefix}/") for key in keys) def _copy_input_data_list( container: _LocalJobContainer, aws_session: AwsSession, **creation_kwargs: str ) -> bool: """If the input data list is not empty, this function will download the input files and store them in the container. Args: container (_LocalJobContainer): The container to save input data to. aws_session (AwsSession): AwsSession for connecting to AWS Services. **creation_kwargs (str): Arbitrary keyword arguments. Returns: bool: True if any input data was copied to the container. 
""" if "inputDataConfig" not in creation_kwargs: return False input_data_list = creation_kwargs["inputDataConfig"] with tempfile.TemporaryDirectory() as temp_dir: for input_data in input_data_list: _download_input_data(aws_session, temp_dir, input_data) container.copy_to(temp_dir, "/opt/ml/input/data/") return bool(input_data_list)