Source code for braket.aws.aws_session

# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import annotations

import itertools
import os
import os.path
import re
from functools import cache
from pathlib import Path
from typing import Any, NamedTuple, Optional

import backoff
import boto3
from botocore import awsrequest, client
from botocore.config import Config
from botocore.exceptions import ClientError

import braket._schemas as braket_schemas
import braket._sdk as braket_sdk
from braket.tracking.tracking_context import active_trackers, broadcast_event
from braket.tracking.tracking_events import _TaskCreationEvent, _TaskStatusEvent


class AwsSession:
    """Manage interactions with AWS services."""
    class S3DestinationFolder(NamedTuple):
        """A `NamedTuple` for an S3 bucket and object key."""

        bucket: str
        key: str
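    # A short usage sketch ("amzn-braket-my-bucket" is a hypothetical bucket name):
    #
    #     folder = AwsSession.S3DestinationFolder("amzn-braket-my-bucket", "my/key")
    #     folder.bucket  # "amzn-braket-my-bucket"
    #     folder.key     # "my/key"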
    def __init__(
        self,
        boto_session: boto3.Session | None = None,
        braket_client: client | None = None,
        config: Config | None = None,
        default_bucket: str | None = None,
    ):
        """Initializes an `AwsSession`.

        Args:
            boto_session (boto3.Session | None): A boto3 session object.
            braket_client (client | None): A boto3 Braket client.
            config (Config | None): A botocore Config object.
            default_bucket (str | None): The name of the default bucket of the AWS Session.

        Raises:
            ValueError: invalid boto_session or braket_client.
        """
        if (
            boto_session
            and braket_client
            and boto_session.region_name != braket_client.meta.region_name
        ):
            raise ValueError(
                "Boto Session region and Braket Client region must match and currently "
                f"they do not: Boto Session region is '{boto_session.region_name}', but "
                f"Braket Client region is '{braket_client.meta.region_name}'."
            )

        self._config = config

        if braket_client:
            self.boto_session = boto_session or boto3.Session(
                region_name=braket_client.meta.region_name
            )
            self.braket_client = braket_client
        else:
            self.boto_session = boto_session or boto3.Session(
                region_name=os.environ.get("AWS_REGION")
            )
            self.braket_client = self.boto_session.client(
                "braket", config=self._config, endpoint_url=os.environ.get("BRAKET_ENDPOINT")
            )
        self._update_user_agent()
        self._custom_default_bucket = bool(default_bucket)
        self._default_bucket = default_bucket or os.environ.get("AMZN_BRAKET_OUT_S3_BUCKET")

        self.braket_client.meta.events.register(
            "before-sign.braket.CreateQuantumTask", self._add_cost_tracker_count_handler
        )
        self.braket_client.meta.events.register(
            "before-sign.braket", self._add_braket_user_agents_handler
        )

        self._iam = None
        self._s3 = None
        self._sts = None
        self._logs = None
        self._ecr = None
        self._account_id = None

    @property
    def region(self) -> str:
        return self.boto_session.region_name

    @property
    def account_id(self) -> str:
        """Gets the caller's account number.

        Returns:
            str: The account number of the caller.
        """
        if not self._account_id:
            self._account_id = self.sts_client.get_caller_identity()["Account"]
        return self._account_id

    @property
    def iam_client(self) -> client:
        """Gets the IAM client.

        Returns:
            client: The IAM Client.
        """
        if not self._iam:
            self._iam = self.boto_session.client("iam", region_name=self.region)
        return self._iam

    @property
    def s3_client(self) -> client:
        """Gets the S3 client.

        Returns:
            client: The S3 Client.
        """
        if not self._s3:
            self._s3 = self.boto_session.client("s3", region_name=self.region)
        return self._s3

    @property
    def sts_client(self) -> client:
        """Gets the STS client.

        Returns:
            client: The STS Client.
        """
        if not self._sts:
            self._sts = self.boto_session.client("sts", region_name=self.region)
        return self._sts

    @property
    def logs_client(self) -> client:
        """Gets the CloudWatch logs client.

        Returns:
            client: The CloudWatch logs Client.
        """
        if not self._logs:
            self._logs = self.boto_session.client("logs", region_name=self.region)
        return self._logs

    @property
    def ecr_client(self) -> client:
        """Gets the ECR client.

        Returns:
            client: The ECR Client.
        """
        if not self._ecr:
            self._ecr = self.boto_session.client("ecr", region_name=self.region)
        return self._ecr

    def _update_user_agent(self) -> None:
        """Updates the `User-Agent` header forwarded by boto3 to include the braket-sdk,
        braket-schemas and the notebook instance version. The header is a string of space
        delimited values (for example: "Boto3/1.14.43 Python/3.7.9 Botocore/1.17.44").
        """

        def _notebook_instance_version() -> str:
            # TODO: Replace with lifecycle configuration version once we have a way
            # to access those
            nbi_metadata_path = "/opt/ml/metadata/resource-metadata.json"
            return "0" if os.path.exists(nbi_metadata_path) else "None"

        self._braket_user_agents = (
            f"BraketSdk/{braket_sdk.__version__} "
            f"BraketSchemas/{braket_schemas.__version__} "
            f"NotebookInstance/{_notebook_instance_version()}"
        )
    def add_braket_user_agent(self, user_agent: str) -> None:
        """Appends the `user_agent` value to the User-Agent header, if it does not yet
        exist in the header. This method is typically only relevant for libraries
        integrating with the Amazon Braket SDK.

        Args:
            user_agent (str): The user agent value to append to the header.
        """
        if user_agent not in self._braket_user_agents:
            self._braket_user_agents = f"{self._braket_user_agents} {user_agent}"
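    # Example (a sketch; "MyLibrary/1.2.3" is a hypothetical user agent value):
    #
    #     session = AwsSession()
    #     session.add_braket_user_agent("MyLibrary/1.2.3")
    #     # Subsequent Braket API calls now carry "MyLibrary/1.2.3" in the
    #     # User-Agent header; appending the same value again is a no-op.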
    def _add_braket_user_agents_handler(self, request: awsrequest.AWSRequest, **kwargs) -> None:
        try:
            initial_user_agent = request.headers["User-Agent"]
            request.headers.replace_header(
                "User-Agent", f"{initial_user_agent} {self._braket_user_agents}"
            )
        except KeyError:
            request.headers.add_header("User-Agent", self._braket_user_agents)

    @staticmethod
    def _add_cost_tracker_count_handler(request: awsrequest.AWSRequest, **kwargs) -> None:
        request.headers.add_header("Braket-Trackers", str(len(active_trackers())))

    #
    # Quantum Tasks
    #
    def cancel_quantum_task(self, arn: str) -> None:
        """Cancel the quantum task.

        Args:
            arn (str): The ARN of the quantum task to cancel.
        """
        response = self.braket_client.cancel_quantum_task(quantumTaskArn=arn)
        broadcast_event(_TaskStatusEvent(arn=arn, status=response["cancellationStatus"]))
    def create_quantum_task(self, **boto3_kwargs) -> str:
        """Create a quantum task.

        Args:
            **boto3_kwargs: Keyword arguments for the Amazon Braket `CreateQuantumTask`
                operation.

        Returns:
            str: The ARN of the quantum task.
        """
        # Add job token to request, if available.
        job_token = os.getenv("AMZN_BRAKET_JOB_TOKEN")
        if job_token:
            boto3_kwargs["jobToken"] = job_token
        response = self.braket_client.create_quantum_task(**boto3_kwargs)
        broadcast_event(
            _TaskCreationEvent(
                arn=response["quantumTaskArn"],
                shots=boto3_kwargs["shots"],
                is_job_task=(job_token is not None),
                device=boto3_kwargs["deviceArn"],
            )
        )
        return response["quantumTaskArn"]
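    # A call sketch (the bucket, key prefix, and `circuit_ir_json` payload are
    # placeholders; see the Amazon Braket `CreateQuantumTask` API for the full
    # parameter list):
    #
    #     task_arn = session.create_quantum_task(
    #         deviceArn="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
    #         outputS3Bucket="amzn-braket-my-bucket",
    #         outputS3KeyPrefix="tasks",
    #         shots=100,
    #         action=circuit_ir_json,  # serialized program (IR) as a JSON string
    #     )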
    def create_job(self, **boto3_kwargs) -> str:
        """Create a quantum hybrid job.

        Args:
            **boto3_kwargs: Keyword arguments for the Amazon Braket `CreateJob` operation.

        Returns:
            str: The ARN of the hybrid job.
        """
        response = self.braket_client.create_job(**boto3_kwargs)
        return response["jobArn"]
    @staticmethod
    def _should_giveup(err: Exception) -> bool:
        return not (
            isinstance(err, ClientError)
            and err.response["Error"]["Code"]
            in [
                "ResourceNotFoundException",
                "ThrottlingException",
            ]
        )
    @backoff.on_exception(
        backoff.expo,
        ClientError,
        max_tries=3,
        jitter=backoff.full_jitter,
        giveup=_should_giveup.__func__,
    )
    def get_quantum_task(self, arn: str) -> dict[str, Any]:
        """Gets the quantum task.

        Args:
            arn (str): The ARN of the quantum task to get.

        Returns:
            dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
        """
        response = self.braket_client.get_quantum_task(
            quantumTaskArn=arn, additionalAttributeNames=["QueueInfo"]
        )
        broadcast_event(_TaskStatusEvent(arn=response["quantumTaskArn"], status=response["status"]))
        return response
    def get_default_jobs_role(self) -> str:
        """Returns the role ARN for the default hybrid jobs role created in the
        Amazon Braket Console. It picks the first role it finds with the `RoleName`
        prefix `AmazonBraketJobsExecutionRole` and a `PathPrefix` of `/service-role/`.

        Returns:
            str: The ARN of the default IAM role for hybrid jobs execution created in
            the Amazon Braket console.

        Raises:
            RuntimeError: If no roles can be found with the prefix
                `/service-role/AmazonBraketJobsExecutionRole`.
        """
        roles_paginator = self.iam_client.get_paginator("list_roles")
        for page in roles_paginator.paginate(PathPrefix="/service-role/"):
            for role in page.get("Roles", []):
                if role["RoleName"].startswith("AmazonBraketJobsExecutionRole"):
                    return role["Arn"]
        raise RuntimeError(
            "No default jobs roles found. Please create a role using the "
            "Amazon Braket console or supply a custom role."
        )
    @backoff.on_exception(
        backoff.expo,
        ClientError,
        max_tries=3,
        jitter=backoff.full_jitter,
        giveup=_should_giveup.__func__,
    )
    def get_job(self, arn: str) -> dict[str, Any]:
        """Gets the hybrid job.

        Args:
            arn (str): The ARN of the hybrid job to get.

        Returns:
            dict[str, Any]: The response from the Amazon Braket `GetJob` operation.
        """
        return self.braket_client.get_job(jobArn=arn, additionalAttributeNames=["QueueInfo"])
    def cancel_job(self, arn: str) -> dict[str, Any]:
        """Cancel the hybrid job.

        Args:
            arn (str): The ARN of the hybrid job to cancel.

        Returns:
            dict[str, Any]: The response from the Amazon Braket `CancelJob` operation.
        """
        return self.braket_client.cancel_job(jobArn=arn)
    def retrieve_s3_object_body(self, s3_bucket: str, s3_object_key: str) -> str:
        """Retrieve the S3 object body.

        Args:
            s3_bucket (str): The S3 bucket name.
            s3_object_key (str): The S3 object key within the `s3_bucket`.

        Returns:
            str: The body of the S3 object.
        """
        s3 = self.boto_session.resource("s3", config=self._config)
        obj = s3.Object(s3_bucket, s3_object_key)
        return obj.get()["Body"].read().decode("utf-8")
    def upload_to_s3(self, filename: str, s3_uri: str) -> None:
        """Upload a file to S3.

        Args:
            filename (str): Local file to be uploaded.
            s3_uri (str): The S3 URI where the file will be uploaded.
        """
        bucket, key = self.parse_s3_uri(s3_uri)
        self.s3_client.upload_file(filename, bucket, key)
    def upload_local_data(self, local_prefix: str, s3_prefix: str) -> None:
        """Upload local data matching a prefix to a corresponding location in S3.

        Args:
            local_prefix (str): A prefix designating files to be uploaded to S3. All files
                beginning with local_prefix will be uploaded.
            s3_prefix (str): The corresponding S3 prefix that will replace the local prefix
                when the data is uploaded. This will be an S3 URI and should include the
                bucket (e.g. 's3://my-bucket/my/prefix-').

        Example:
            local_prefix = "input", s3_prefix = "s3://my-bucket/dir/input" will upload:

            - 'input.csv' to 's3://my-bucket/dir/input.csv'
            - 'input-2.csv' to 's3://my-bucket/dir/input-2.csv'
            - 'input/data.txt' to 's3://my-bucket/dir/input/data.txt'
            - 'input-dir/data.csv' to 's3://my-bucket/dir/input-dir/data.csv'

            but will not upload:

            - 'my-input.csv'
            - 'my-dir/input.csv'

            To match all files within the directory "input" and upload them into
            "s3://my-bucket/input", provide local_prefix = "input/" and
            s3_prefix = "s3://my-bucket/input/"
        """
        # support absolute paths
        if Path(local_prefix).is_absolute():
            base_dir = Path(Path(local_prefix).anchor)
            relative_prefix = str(Path(local_prefix).relative_to(base_dir))
        else:
            base_dir = Path()
            relative_prefix = local_prefix
        for file in itertools.chain(
            # files that match the prefix
            base_dir.glob(f"{relative_prefix}*"),
            # files inside of directories that match the prefix
            base_dir.glob(f"{relative_prefix}*/**/*"),
        ):
            if file.is_file():
                s3_uri = str(file.as_posix()).replace(
                    str(Path(local_prefix).as_posix()), s3_prefix
                )
                self.upload_to_s3(str(file), s3_uri)
    def download_from_s3(self, s3_uri: str, filename: str) -> None:
        """Download a file from S3.

        Args:
            s3_uri (str): The S3 URI from where the file will be downloaded.
            filename (str): Filename to save the file to.
        """
        bucket, key = self.parse_s3_uri(s3_uri)
        self.s3_client.download_file(bucket, key, filename)
    def copy_s3_object(self, source_s3_uri: str, destination_s3_uri: str) -> None:
        """Copy an S3 object from one location to another. Does nothing if the source
        and destination URIs are the same.

        Args:
            source_s3_uri (str): S3 URI pointing to the object to be copied.
            destination_s3_uri (str): S3 URI where the object will be copied to.
        """
        if source_s3_uri == destination_s3_uri:
            return

        source_bucket, source_key = self.parse_s3_uri(source_s3_uri)
        destination_bucket, destination_key = self.parse_s3_uri(destination_s3_uri)

        self.s3_client.copy(
            {
                "Bucket": source_bucket,
                "Key": source_key,
            },
            destination_bucket,
            destination_key,
        )
    def copy_s3_directory(self, source_s3_path: str, destination_s3_path: str) -> None:
        """Copy all objects from a specified directory in S3. Does nothing if the source
        and destination URIs are the same. Preserves the nesting structure, and does not
        overwrite other files at the destination unless they share a name with a file
        being copied.

        Args:
            source_s3_path (str): S3 URI pointing to the directory to be copied.
            destination_s3_path (str): S3 URI where the contents of the source_s3_path
                directory will be copied to.
        """
        if source_s3_path == destination_s3_path:
            return

        source_bucket, source_prefix = AwsSession.parse_s3_uri(source_s3_path)
        destination_bucket, destination_prefix = AwsSession.parse_s3_uri(destination_s3_path)

        source_keys = self.list_keys(source_bucket, source_prefix)

        for key in source_keys:
            self.s3_client.copy(
                {
                    "Bucket": source_bucket,
                    "Key": key,
                },
                destination_bucket,
                key.replace(source_prefix, destination_prefix, 1),
            )
    def list_keys(self, bucket: str, prefix: str) -> list[str]:
        """Lists keys matching a prefix in a bucket.

        Args:
            bucket (str): Bucket to be queried.
            prefix (str): The S3 path prefix to be matched.

        Returns:
            list[str]: A list of all keys matching the prefix in the bucket.
        """
        list_objects = self.s3_client.list_objects_v2(
            Bucket=bucket,
            Prefix=prefix,
        )
        keys = [obj["Key"] for obj in list_objects["Contents"]]
        while list_objects["IsTruncated"]:
            list_objects = self.s3_client.list_objects_v2(
                Bucket=bucket,
                Prefix=prefix,
                ContinuationToken=list_objects["NextContinuationToken"],
            )
            keys += [obj["Key"] for obj in list_objects["Contents"]]
        return keys
    def default_bucket(self) -> str:
        """Returns the name of the default bucket of the AWS Session. In the following
        order of priority, it will return either the parameter `default_bucket` set during
        initialization of the AwsSession (if not None), the bucket being used by the
        currently running Braket Hybrid Job (if invoked inside of a Braket Hybrid Job), or
        a default value of "amazon-braket-<aws session region>-<aws account id>". Except in
        the case of a user-specified bucket name, this method will create the default
        bucket if it does not exist.

        Returns:
            str: Name of the default bucket.
        """
        if self._default_bucket:
            return self._default_bucket
        default_bucket = f"amazon-braket-{self.region}-{self.account_id}"

        self._create_s3_bucket_if_it_does_not_exist(bucket_name=default_bucket, region=self.region)

        self._default_bucket = default_bucket
        return self._default_bucket
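    # Example (a sketch; the region and account id shown are placeholders):
    #
    #     session = AwsSession()
    #     session.default_bucket()  # e.g. "amazon-braket-us-west-2-123456789012",
    #                               # created on first use if it does not exist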
    def _create_s3_bucket_if_it_does_not_exist(self, bucket_name: str, region: str) -> None:
        """Creates an S3 Bucket if it does not exist. Also swallows a few common exceptions
        that indicate that the bucket already exists or that it is being created.

        Args:
            bucket_name (str): Name of the S3 bucket to be created.
            region (str): The region in which to create the bucket.

        Raises:
            botocore.exceptions.ClientError: If S3 throws an unexpected exception during
                bucket creation. If the exception is due to the bucket already existing or
                already being created, no exception is raised.
        """
        try:
            if region == "us-east-1":
                # 'us-east-1' cannot be specified because it is the default region:
                # https://github.com/boto/boto3/issues/125
                self.s3_client.create_bucket(Bucket=bucket_name)
            else:
                self.s3_client.create_bucket(
                    Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
                )
            self.s3_client.put_public_access_block(
                Bucket=bucket_name,
                PublicAccessBlockConfiguration={
                    "BlockPublicAcls": True,
                    "IgnorePublicAcls": True,
                    "BlockPublicPolicy": True,
                    "RestrictPublicBuckets": True,
                },
            )
            self.s3_client.put_bucket_policy(
                Bucket=bucket_name,
                Policy=f"""{{
                    "Version": "2012-10-17",
                    "Statement": [
                        {{
                            "Effect": "Allow",
                            "Principal": {{
                                "Service": [
                                    "braket.amazonaws.com"
                                ]
                            }},
                            "Action": "s3:*",
                            "Resource": [
                                "arn:aws:s3:::{bucket_name}",
                                "arn:aws:s3:::{bucket_name}/*"
                            ]
                        }}
                    ]
                }}""",
            )
        except ClientError as e:
            error_code = e.response["Error"]["Code"]
            message = e.response["Error"]["Message"]
            if error_code == "BucketAlreadyOwnedByYou" or (
                error_code == "OperationAborted" and "conflicting conditional operation" in message
            ):
                # The bucket already exists in this account, or a concurrent create is in
                # progress; either way there is nothing to do.
                pass
            elif error_code == "BucketAlreadyExists":
                raise ValueError(
                    f"Provided default bucket '{bucket_name}' already exists "
                    f"for another account. Please supply an alternative bucket "
                    f"name via the `AwsSession` constructor."
                ) from None
            else:
                raise
    def get_device(self, arn: str) -> dict[str, Any]:
        """Calls the Amazon Braket `get_device` API to retrieve device metadata.

        Args:
            arn (str): The ARN of the device.

        Returns:
            dict[str, Any]: The response from the Amazon Braket `GetDevice` operation.
        """
        return self.braket_client.get_device(deviceArn=arn)
    def search_devices(
        self,
        arns: Optional[list[str]] = None,
        names: Optional[list[str]] = None,
        types: Optional[list[str]] = None,
        statuses: Optional[list[str]] = None,
        provider_names: Optional[list[str]] = None,
    ) -> list[dict[str, Any]]:
        """Get devices based on filters. The result is the AND of all the filters
        `arns`, `names`, `types`, `statuses`, `provider_names`.

        Args:
            arns (Optional[list[str]]): device ARN filter, default is `None`.
            names (Optional[list[str]]): device name filter, default is `None`.
            types (Optional[list[str]]): device type filter, default is `None`.
            statuses (Optional[list[str]]): device status filter, default is `None`.
                When `None` is used, RETIRED devices will not be returned. To include
                RETIRED devices in the results, use a filter that includes "RETIRED"
                for this parameter.
            provider_names (Optional[list[str]]): provider name list, default is `None`.

        Returns:
            list[dict[str, Any]]: The response from the Amazon Braket `SearchDevices`
            operation.
        """
        filters = []
        if arns:
            filters.append({"name": "deviceArn", "values": arns})
        paginator = self.braket_client.get_paginator("search_devices")
        page_iterator = paginator.paginate(filters=filters, PaginationConfig={"MaxItems": 100})
        results = []
        for page in page_iterator:
            for result in page["devices"]:
                if names and result["deviceName"] not in names:
                    continue
                if types and result["deviceType"] not in types:
                    continue
                if statuses and result["deviceStatus"] not in statuses:
                    continue
                if statuses is None and result["deviceStatus"] == "RETIRED":
                    continue
                if provider_names and result["providerName"] not in provider_names:
                    continue
                results.append(result)
        return results
    @staticmethod
    def is_s3_uri(string: str) -> bool:
        """Determines if a given string is an S3 URI.

        Args:
            string (str): The string to check.

        Returns:
            bool: True if the given string is an S3 URI.
        """
        try:
            AwsSession.parse_s3_uri(string)
        except ValueError:
            return False
        return True
    @staticmethod
    def parse_s3_uri(s3_uri: str) -> tuple[str, str]:
        """Parse an S3 URI to get the bucket and key.

        Args:
            s3_uri (str): S3 URI.

        Returns:
            tuple[str, str]: Bucket and key tuple.

        Raises:
            ValueError: If the provided string is not a valid S3 URI.
        """
        try:
            # Object URL e.g. https://my-bucket.s3.us-west-2.amazonaws.com/my/key
            # S3 URI e.g. s3://my-bucket/my/key
            s3_uri_match = re.match(r"^https://([^./]+)\.[sS]3\.[^/]+/(.+)$", s3_uri) or re.match(
                r"^[sS]3://([^./]+)/(.+)$", s3_uri
            )
            if s3_uri_match is None:
                raise AssertionError
            bucket, key = s3_uri_match.groups()
            return bucket, key
        except (AssertionError, ValueError) as e:
            raise ValueError(f"Not a valid S3 uri: {s3_uri}") from e
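    # Both accepted forms parse to the same (bucket, key) pair, e.g.:
    #
    #     AwsSession.parse_s3_uri("s3://my-bucket/my/key")
    #     # -> ("my-bucket", "my/key")
    #     AwsSession.parse_s3_uri("https://my-bucket.s3.us-west-2.amazonaws.com/my/key")
    #     # -> ("my-bucket", "my/key")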
    @staticmethod
    def construct_s3_uri(bucket: str, *dirs: str) -> str:
        """Create an S3 URI given a bucket and path.

        Args:
            bucket (str): S3 bucket name.
            *dirs (str): Directories to be appended in the resulting S3 URI.

        Returns:
            str: S3 URI.

        Raises:
            ValueError: If the provided arguments are not valid to generate an S3 URI.
        """
        if not dirs:
            raise ValueError(f"Not a valid S3 location: s3://{bucket}")
        return f"s3://{bucket}/{'/'.join(dirs)}"
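    # Example:
    #
    #     AwsSession.construct_s3_uri("my-bucket", "dir", "subdir", "file.txt")
    #     # -> "s3://my-bucket/dir/subdir/file.txt"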
    def describe_log_streams(
        self,
        log_group: str,
        log_stream_prefix: str,
        limit: Optional[int] = None,
        next_token: Optional[str] = None,
    ) -> dict[str, Any]:
        """Describes CloudWatch log streams in a log group with a given prefix.

        Args:
            log_group (str): Name of the log group.
            log_stream_prefix (str): Prefix for log streams to include.
            limit (Optional[int]): Limit for number of log streams returned.
                Default is 50.
            next_token (Optional[str]): The token for the next set of items to return.
                Would have been received in a previous call.

        Returns:
            dict[str, Any]: Dictionary containing logStreams and nextToken.
        """
        log_stream_args = {
            "logGroupName": log_group,
            "logStreamNamePrefix": log_stream_prefix,
            "orderBy": "LogStreamName",
        }

        if limit:
            log_stream_args["limit"] = limit

        if next_token:
            log_stream_args["nextToken"] = next_token

        return self.logs_client.describe_log_streams(**log_stream_args)
    def get_log_events(
        self,
        log_group: str,
        log_stream: str,
        start_time: int,
        start_from_head: bool = True,
        next_token: Optional[str] = None,
    ) -> dict[str, Any]:
        """Gets CloudWatch log events from a given log stream.

        Args:
            log_group (str): Name of the log group.
            log_stream (str): Name of the log stream.
            start_time (int): Timestamp that indicates a start time to include log events.
            start_from_head (bool): Bool indicating to return oldest events first.
                Default is True.
            next_token (Optional[str]): The token for the next set of items to return.
                Would have been received in a previous call.

        Returns:
            dict[str, Any]: Dictionary containing events, nextForwardToken, and
            nextBackwardToken.
        """
        log_events_args = {
            "logGroupName": log_group,
            "logStreamName": log_stream,
            "startTime": start_time,
            "startFromHead": start_from_head,
        }

        if next_token:
            log_events_args["nextToken"] = next_token

        return self.logs_client.get_log_events(**log_events_args)
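    # A paging sketch over a hybrid job's log stream (the log group and stream
    # names below are placeholders):
    #
    #     events = session.get_log_events(
    #         "/aws/braket/jobs", "my-job/algo-1-0123456789", start_time=0
    #     )
    #     for event in events["events"]:
    #         print(event["message"])
    #     # Pass events["nextForwardToken"] back in as `next_token` to continue.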
    def copy_session(
        self,
        region: Optional[str] = None,
        max_connections: Optional[int] = None,
    ) -> AwsSession:
        """Creates a new AwsSession based on the region.

        Args:
            region (Optional[str]): Name of the region. Default = `None`.
            max_connections (Optional[int]): The maximum number of connections in the
                Boto3 connection pool. Default = `None`.

        Returns:
            AwsSession: A new session based on the region and boto config parameters.
        """
        config = Config(max_pool_connections=max_connections) if max_connections else None
        session_region = self.boto_session.region_name
        new_region = region or session_region

        # Note that this method does not copy a custom Braket endpoint URL, since those
        # are region-specific. If you have an endpoint that you wish to be used by copied
        # AwsSessions (e.g. for task batching), please use the `BRAKET_ENDPOINT`
        # environment variable.

        creds = self.boto_session.get_credentials()
        default_bucket = self._default_bucket if self._custom_default_bucket else None
        profile_name = self.boto_session.profile_name
        profile_name = profile_name if profile_name != "default" else None
        if creds.method == "explicit":
            boto_session = boto3.Session(
                aws_access_key_id=creds.access_key,
                aws_secret_access_key=creds.secret_key,
                aws_session_token=creds.token,
                region_name=new_region,
                profile_name=profile_name,
            )
        elif creds.method == "env":
            boto_session = boto3.Session(region_name=new_region)
        else:
            boto_session = boto3.Session(
                region_name=new_region,
                profile_name=profile_name,
            )
        copied_session = AwsSession(
            boto_session=boto_session, config=config, default_bucket=default_bucket
        )
        # Preserve user_agent information
        copied_session._braket_user_agents = self._braket_user_agents
        return copied_session
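    # Example (a sketch): pin a copy of the session to another region, e.g. for a
    # device that is only available there:
    #
    #     us_east_session = session.copy_session(region="us-east-1")
    #     us_east_session.region  # "us-east-1"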
    @cache
    def get_full_image_tag(self, image_uri: str) -> str:
        """Get the verbose image tag from an image URI.

        Args:
            image_uri (str): Image URI to get the tag for.

        Returns:
            str: Verbose image tag for the given image.
        """
        registry = image_uri.split(".")[0]
        repository, tag = image_uri.split("/")[-1].split(":")

        # get image digest of latest image
        digest = self.ecr_client.batch_get_image(
            registryId=registry,
            repositoryName=repository,
            imageIds=[{"imageTag": tag}],
        )["images"][0]["imageId"]["imageDigest"]

        # get all images matching digest (same image, different tags)
        images = self.ecr_client.batch_get_image(
            registryId=registry,
            repositoryName=repository,
            imageIds=[{"imageDigest": digest}],
        )["images"]

        # find the tag with the python version info
        for image in images:
            if re.search(r"py\d\d+", tag := image["imageId"]["imageTag"]):
                return tag

        raise ValueError("Full image tag missing.")