Source code for braket.jobs.image_uris

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json
import os
from enum import Enum
from functools import cache


class Framework(str, Enum):
    """Enumeration of the frameworks that have pre-built Amazon Braket containers.

    Because the enum mixes in ``str``, every member compares equal to (and can be
    used anywhere as) its upper-case string value.
    """

    BASE = "BASE"
    PL_TENSORFLOW = "PL_TENSORFLOW"
    PL_PYTORCH = "PL_PYTORCH"


def built_in_images(region: str) -> set[str]:
    """Collects the URIs of every built-in Braket image for a region.

    Args:
        region (str): The AWS region to look up images in.

    Returns:
        set[str]: One ECR image URI per supported framework.
    """
    images: set[str] = set()
    # Walk every framework in declaration order; any unsupported region error
    # raised by the lookup propagates to the caller unchanged.
    for fw in Framework:
        images.add(retrieve_image(fw, region))
    return images


@cache
def retrieve_image(framework: Framework, region: str) -> str:
    """Looks up the ECR URI of the pre-built Braket image for a framework and region.

    Results are memoized per ``(framework, region)`` argument pair.

    Args:
        framework (Framework): The framework whose image is wanted. A plain string
            equal to one of the member values is also accepted and validated.
        region (str): The AWS region that hosts the Docker image.

    Returns:
        str: The ECR URI of the matching Amazon Braket Docker image, using the
        ``latest`` tag.

    Raises:
        ValueError: If the framework is not a supported value, or the region is not
            supported for that framework.
    """
    # Round-tripping through the enum both validates and normalizes the input.
    validated = Framework(framework)
    framework_config = _config_for_framework(validated)
    registry_id = _registry_for_region(framework_config, region)
    repository = framework_config["repository"]
    host = f"{registry_id}.dkr.ecr.{region}.amazonaws.com"
    return f"{host}/{repository}:latest"


def _config_for_framework(framework: Framework) -> dict[str, str]:
    """Loads the JSON config for the given framework.

    The config is read from the ``image_uri_config`` directory next to this module,
    from a file named after the lower-cased framework value
    (e.g. ``base.json`` for ``"BASE"``).

    Args:
        framework (Framework): The framework whose config needs to be loaded.

    Returns:
        dict[str, str]: Dict that contains the configuration for the specified framework.
    """
    fname = os.path.join(
        os.path.dirname(__file__), "image_uri_config", f"{framework.lower()}.json"
    )
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale-dependent default encoding (e.g. cp1252 on Windows).
    with open(fname, encoding="utf-8") as f:
        return json.load(f)


def _registry_for_region(config: dict[str, str], region: str) -> str:
    """Retrieves the registry for the specified region from the configuration.

    Args:
        config (dict[str, str]): Dict containing the framework configuration; its
            ``"supported_regions"`` entry is checked for membership and its
            ``"registry"`` entry is returned.
        region (str): str that specifies the region for which the registry is retrieved.

    Returns:
        str: str that specifies the registry for the supplied region.

    Raises:
        ValueError: If the supplied region is invalid or not supported.
    """
    if region not in (supported_regions := config["supported_regions"]):
        raise ValueError(
            f"Unsupported region: {region}. You may need to upgrade your SDK version for newer "
            f"regions. Supported region(s): {supported_regions}"
        )
    return config["registry"]