# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import annotations
import functools
import importlib.util
import inspect
import re
import shutil
import sys
import tempfile
import warnings
from collections.abc import Callable, Iterable
from logging import Logger, getLogger
from pathlib import Path
from types import ModuleType
from typing import Any
import cloudpickle
from braket.aws.aws_session import AwsSession
from braket.jobs._entry_point_template import run_entry_point, symlink_input_data
from braket.jobs.config import (
CheckpointConfig,
InstanceConfig,
OutputDataConfig,
S3DataSourceConfig,
StoppingCondition,
)
from braket.jobs.image_uris import Framework, built_in_images, retrieve_image
from braket.jobs.quantum_job import QuantumJob
from braket.jobs.quantum_job_creation import _generate_default_job_name
def hybrid_job(
    *,
    device: str | None,
    include_modules: str | ModuleType | Iterable[str | ModuleType] | None = None,
    dependencies: str | Path | list[str] | None = None,
    local: bool = False,
    job_name: str | None = None,
    image_uri: str | None = None,
    input_data: str | dict | S3DataSourceConfig | None = None,
    wait_until_complete: bool = False,
    instance_config: InstanceConfig | None = None,
    distribution: str | None = None,
    copy_checkpoints_from_job: str | None = None,
    checkpoint_config: CheckpointConfig | None = None,
    role_arn: str | None = None,
    stopping_condition: StoppingCondition | None = None,
    output_data_config: OutputDataConfig | None = None,
    aws_session: AwsSession | None = None,
    tags: dict[str, str] | None = None,
    logger: Logger = getLogger(__name__),
    quiet: bool | None = None,
    reservation_arn: str | None = None,
) -> Callable:
    """Defines a hybrid job by decorating the entry point function. The job will be created
    when the decorated function is called; the call's arguments are serialized together with
    the function and also logged as hyperparameters.

    The job created will be a `LocalQuantumJob` when `local` is set to `True`, otherwise an
    `AwsQuantumJob`. The following parameters will be ignored when running a job with
    `local` set to `True`: `wait_until_complete`, `instance_config`, `distribution`,
    `copy_checkpoints_from_job`, `stopping_condition`, `tags`, `logger`, and `quiet`.

    The local Python version is validated against the container image as soon as
    `hybrid_job(...)` is called, i.e. at decoration time.

    Args:
        device (str | None): Device ARN of the QPU device that receives priority quantum
            task queueing once the hybrid job begins running. Each QPU has a separate hybrid jobs
            queue so that only one hybrid job is running at a time. The device string is accessible
            in the hybrid job instance as the environment variable "AMZN_BRAKET_DEVICE_ARN".
            When using embedded simulators, you may provide the device argument as string of the
            form: "local:<provider>/<simulator_name>" or `None`.
        include_modules (str | ModuleType | Iterable[str | ModuleType] | None): Either a
            single module or module name or a list of module or module names referring to local
            modules to be included. Any references to members of these modules in the hybrid job
            algorithm code will be serialized as part of the algorithm code. Default: `[]`
        dependencies (str | Path | list[str] | None): Path (absolute or relative) to a
            requirements.txt file, or alternatively a list of strings, with each string being a
            `requirement specifier <https://pip.pypa.io/en/stable/reference/requirement-specifiers/
            #requirement-specifiers>`_, to be used for the hybrid job.
        local (bool): Whether to use local mode for the hybrid job. Default: `False`
        job_name (str | None): A string that specifies the name with which the job is created.
            Allowed pattern for job name: `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`. Defaults to
            f'{decorated-function-name}-{timestamp}'.
        image_uri (str | None): A str that specifies the ECR image to use for executing the job.
            `retrieve_image()` function may be used for retrieving the ECR image URIs
            for the containers supported by Braket. Default: `<Braket base image_uri>`.
        input_data (str | dict | S3DataSourceConfig | None): Information about the training
            data. Dictionary maps channel names to local paths or S3 URIs. Contents found
            at any local paths will be uploaded to S3 at
            f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}'. If a local
            path, S3 URI, or S3DataSourceConfig is provided, it will be given a default
            channel name "input".
            Default: {}.
        wait_until_complete (bool): `True` if we should wait until the job completes.
            This would tail the job logs as it waits. Otherwise `False`. Ignored if using
            local mode. Default: `False`.
        instance_config (InstanceConfig | None): Configuration of the instance(s) for running the
            classical code for the hybrid job. Default:
            `InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`.
        distribution (str | None): A str that specifies how the job should be distributed.
            If set to "data_parallel", the hyperparameters for the job will be set to use data
            parallelism features for PyTorch or TensorFlow. Default: `None`.
        copy_checkpoints_from_job (str | None): A str that specifies the job ARN whose
            checkpoint you want to use in the current job. Specifying this value will copy
            over the checkpoint data from `copy_checkpoints_from_job`'s checkpoint_config
            s3Uri to the current job's checkpoint_config s3Uri, making it available at
            checkpoint_config.localPath during the job execution. Default: `None`
        checkpoint_config (CheckpointConfig | None): Configuration that specifies the
            location where checkpoint data is stored.
            Default: `CheckpointConfig(localPath='/opt/jobs/checkpoints',
            s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints')`.
        role_arn (str | None): A str providing the IAM role ARN used to execute the
            script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`.
        stopping_condition (StoppingCondition | None): The maximum length of time, in seconds,
            and the maximum number of tasks that a job can run before being forcefully stopped.
            Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).
        output_data_config (OutputDataConfig | None): Specifies the location for the output of
            the job.
            Default: `OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data',
            kmsKeyId=None)`.
        aws_session (AwsSession | None): AwsSession for connecting to AWS Services.
            Default: AwsSession()
        tags (dict[str, str] | None): Dict specifying the key-value pairs for tagging this job.
            Default: {}.
        logger (Logger): Logger object with which to write logs, such as task statuses
            while waiting for task to be in a terminal state. Default: `getLogger(__name__)`
        quiet (bool | None): Sets the verbosity of the logger to low and does not report queue
            position. Default is `False`.
        reservation_arn (str | None): the reservation window arn provided by Braket
            Direct to reserve exclusive usage for the device to run the hybrid job on.
            Default: None.

    Returns:
        Callable: a decorator. Calling the function it decorates creates the hybrid job and
        returns it.
    """
    # Fail fast when the decorator is applied rather than when the job is submitted.
    _validate_python_version(image_uri, aws_session)

    def _hybrid_job(entry_point: Callable) -> Callable:
        @functools.wraps(entry_point)
        def job_wrapper(*args: Any, **kwargs: Any) -> QuantumJob:
            """Create a hybrid job that runs ``entry_point(*args, **kwargs)``.

            Args:
                *args (Any): Arbitrary arguments.
                **kwargs (Any): Arbitrary keyword arguments.

            Returns:
                QuantumJob: the created hybrid job.
            """
            # Included modules are pickled by value only while serializing the entry point;
            # the temporary source directory is created under the current working directory.
            with _IncludeModules(include_modules), tempfile.TemporaryDirectory(
                dir="", prefix="decorator_job_"
            ) as temp_dir:
                temp_dir_path = Path(temp_dir)
                entry_point_file_path = Path("entry_point.py")
                with open(temp_dir_path / entry_point_file_path, "w") as entry_point_file:
                    template = "\n".join(
                        [
                            _process_input_data(input_data),
                            _serialize_entry_point(entry_point, args, kwargs),
                        ]
                    )
                    entry_point_file.write(template)

                if dependencies:
                    _process_dependencies(dependencies, temp_dir_path)

                job_args = {
                    "device": device or "local:none/none",
                    "source_module": temp_dir,
                    # Since Python 3.12, mkdtemp() returns an absolute path even for a relative
                    # ``dir``, so only the final path component is a valid module name.
                    "entry_point": (
                        f"{temp_dir_path.name}.{entry_point_file_path.stem}:"
                        f"{entry_point.__name__}"
                    ),
                    "wait_until_complete": wait_until_complete,
                    "job_name": job_name or _generate_default_job_name(func=entry_point),
                    "hyperparameters": _log_hyperparameters(entry_point, args, kwargs),
                    "logger": logger,
                }
                optional_args = {
                    "image_uri": image_uri,
                    "input_data": input_data,
                    "instance_config": instance_config,
                    "distribution": distribution,
                    "checkpoint_config": checkpoint_config,
                    "copy_checkpoints_from_job": copy_checkpoints_from_job,
                    "role_arn": role_arn,
                    "stopping_condition": stopping_condition,
                    "output_data_config": output_data_config,
                    "aws_session": aws_session,
                    "tags": tags,
                    "quiet": quiet,
                    "reservation_arn": reservation_arn,
                }
                # Only forward options the caller actually set, so the create() defaults apply.
                job_args.update(
                    {key: value for key, value in optional_args.items() if value is not None}
                )
                return _create_job(job_args, local)

        return job_wrapper

    return _hybrid_job
def _validate_python_version(image_uri: str | None, aws_session: AwsSession | None = None) -> None:
    """Check, at job definition time, that local and container Python versions agree.

    A custom ``image_uri`` (one not listed by ``built_in_images`` for the session's region)
    cannot be inspected, so only a notice is printed. Otherwise the Braket base image is used
    when ``image_uri`` is not given, and the ``-py<major><minor>-`` part of its full image tag
    is compared with the running interpreter's major/minor version.

    Args:
        image_uri (str | None): The job's container image URI, or None for the base image.
        aws_session (AwsSession | None): Session used for region and image tag lookups.
            Default: AwsSession().

    Raises:
        RuntimeError: if the local major.minor version differs from the container's.
    """
    session = aws_session or AwsSession()

    if image_uri and image_uri not in built_in_images(session.region):
        # Custom image: nothing to compare against, so warn and stop here.
        print(
            "Skipping python version validation, make sure versions match "
            "between local environment and container."
        )
        return

    resolved_uri = image_uri or retrieve_image(Framework.BASE, session.region)
    full_tag = session.get_full_image_tag(resolved_uri)
    # NOTE(review): assumes the tag always contains "-py<d><dd>-"; if not, re.search returns
    # None and the .groups() call raises AttributeError.
    container_major, container_minor = re.search(r"-py(\d)(\d+)-", full_tag).groups()

    local_major, local_minor = sys.version_info.major, sys.version_info.minor
    if (local_major, local_minor) != (int(container_major), int(container_minor)):
        raise RuntimeError(
            "Python version must match between local environment and container. "
            f"Client is running Python {local_major}.{local_minor} "
            f"locally, but container uses Python {container_major}.{container_minor}."
        )
def _process_dependencies(dependencies: str | Path | list[str], temp_dir: Path) -> None:
if isinstance(dependencies, (str, Path)):
# requirements file
shutil.copy(Path(dependencies).resolve(), temp_dir / "requirements.txt")
else:
# list of packages
with open(temp_dir / "requirements.txt", "w") as f:
f.write("\n".join(dependencies))
class _IncludeModules:
def __init__(self, modules: str | ModuleType | Iterable[str | ModuleType] = None):
modules = modules or []
if isinstance(modules, (str, ModuleType)):
modules = [modules]
self._modules = [
(importlib.import_module(module) if isinstance(module, str) else module)
for module in modules
]
def __enter__(self):
"""Register included modules with cloudpickle to be pickled by value"""
for module in self._modules:
cloudpickle.register_pickle_by_value(module)
def __exit__(self, exc_type, exc_val, exc_tb):
"""Unregister included modules with cloudpickle to be pickled by value"""
for module in self._modules:
cloudpickle.unregister_pickle_by_value(module)
def _serialize_entry_point(entry_point: Callable, args: tuple, kwargs: dict) -> str:
    """Render the entry-point source for one call of the decorated function.

    The call ``entry_point(*args, **kwargs)`` is frozen with ``functools.partial``,
    pickled with cloudpickle, and substituted into the ``run_entry_point`` template
    together with the function's name.

    Args:
        entry_point (Callable): The decorated function.
        args (tuple): Positional arguments of the call.
        kwargs (dict): Keyword arguments of the call.

    Returns:
        str: the formatted ``run_entry_point`` template.

    Raises:
        RuntimeError: if cloudpickle fails to serialize the frozen call.
    """
    frozen_call = functools.partial(entry_point, *args, **kwargs)
    try:
        pickled_call = cloudpickle.dumps(frozen_call)
    except Exception as err:
        raise RuntimeError(
            "Serialization failed for decorator hybrid job. If you are referencing "
            "an object from outside the function scope, either directly or through "
            "function parameters, try instantiating the object inside the decorated "
            "function instead."
        ) from err

    template_fields = {
        "serialized": pickled_call,
        "function_name": entry_point.__name__,
    }
    return run_entry_point.format(**template_fields)
def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> dict:
    """Capture the decorated function's call arguments as sanitized hyperparameters.

    Arguments are bound to ``entry_point``'s signature and defaults are filled in.
    Positional-or-keyword and keyword-only parameters are recorded under their own names,
    and the entries collected by a ``**kwargs`` parameter are merged in under their keys.
    Positional-only and ``*args`` parameters are not recorded; a warning is emitted for
    each such parameter. Every recorded value is passed through ``_sanitize``.

    Args:
        entry_point (Callable): The decorated function.
        args (tuple): Positional arguments of the call.
        kwargs (dict): Keyword arguments of the call.

    Returns:
        dict: hyperparameter name -> sanitized string value.

    Raises:
        TypeError: if the arguments do not match ``entry_point``'s signature.
    """
    sig = inspect.signature(entry_point)
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()

    named_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    captured = {}
    for name, value in bound.arguments.items():
        kind = sig.parameters[name].kind
        if kind == inspect.Parameter.VAR_KEYWORD:
            # flatten **kwargs into top-level hyperparameters
            captured.update(**value)
        elif kind in named_kinds:
            captured[name] = value
        else:
            warnings.warn(
                "Positional only arguments will not be logged to the hyperparameters file.",
                stacklevel=1,
            )
    return {key: _sanitize(val) for key, val in captured.items()}
def _sanitize(hyperparameter: Any) -> str:
"""Sanitize forbidden characters from hp strings"""
string_hp = str(hyperparameter)
sanitized = (
string_hp
# replace forbidden characters with close matches
.replace("\n", " ")
.replace("$", "?")
.replace("(", "{")
.replace("&", "+")
.replace("`", "'")
# not technically forbidden, but to avoid mismatched parens
.replace(")", "}")
)
# max allowed length for a hyperparameter is 2500
if len(sanitized) > 2500:
# show as much as possible, including the final 20 characters
return f"{sanitized[:2500 - 23]}...{sanitized[-20:]}"
return sanitized
def _process_input_data(input_data: str | dict | S3DataSourceConfig | None) -> str:
    """Render the entry-point code that symlinks local input data into place.

    ``input_data`` is normalized to a channel -> source dict (a non-dict value becomes the
    "input" channel). S3 sources are skipped, with a printed notice. Each local channel is
    classified as a prefix, directory, or file channel, and the ``symlink_input_data``
    template is formatted with that classification.

    Logic chart for how the service moves files into the data directory on the instance:
        input data matches exactly one file: cwd/filename -> channel/filename
        input data matches exactly one directory: cwd/dirname/* -> channel/*
        else (multiple matches, possibly including exact):
            cwd/prefix_match -> channel/prefix_match, for each match

    Args:
        input_data (str | dict | S3DataSourceConfig | None): The hybrid job's input data.
            Dict values may be local paths (str or Path), S3 URIs, or S3DataSourceConfig.

    Returns:
        str: the formatted ``symlink_input_data`` template. The caller prepends it to the
        generated entry-point source.
    """
    input_data = input_data or {}
    if not isinstance(input_data, dict):
        input_data = {"input": input_data}

    def matches(prefix: str | Path) -> list[str]:
        # Coerce to str: dict values may be Path objects, and str.startswith rejects Path.
        prefix = str(prefix)
        return [str(path) for path in Path(prefix).parent.iterdir() if str(path).startswith(prefix)]

    def is_prefix(path: str) -> bool:
        # Several sibling matches, or no exact match at all, means "treat as a prefix".
        return len(matches(path)) > 1 or not Path(path).exists()

    prefix_channels = set()
    directory_channels = set()
    file_channels = set()

    for channel, data in input_data.items():
        if AwsSession.is_s3_uri(str(data)) or isinstance(data, S3DataSourceConfig):
            channel_arg = f'channel="{channel}"' if channel != "input" else ""
            print(
                "Input data channels mapped to an S3 source will not be available in "
                f"the working directory. Use `get_input_data_dir({channel_arg})` to read "
                f"input data from S3 source inside the job container."
            )
        elif is_prefix(str(data)):
            prefix_channels.add(channel)
        elif Path(data).is_dir():
            directory_channels.add(channel)
        else:
            file_channels.add(channel)

    local_channels = prefix_channels | directory_channels | file_channels
    return symlink_input_data.format(
        prefix_matches={channel: matches(input_data[channel]) for channel in prefix_channels},
        input_data_items=[
            (channel, data) for channel, data in input_data.items() if channel in local_channels
        ],
        prefix_channels=prefix_channels,
        directory_channels=directory_channels,
    )
def _create_job(job_args: dict[str, Any], local: bool = False) -> QuantumJob:
    """Create an AWS or local hybrid job from the assembled keyword arguments.

    Args:
        job_args (dict[str, Any]): Keyword arguments for the job's ``create`` classmethod.
            In local mode, the AWS-only entries are removed from this dict in place.
        local (bool): If True, create a ``LocalQuantumJob``; otherwise an ``AwsQuantumJob``.
            Default: False.

    Returns:
        QuantumJob: the created hybrid job.
    """
    if local:
        from braket.jobs.local import LocalQuantumJob

        # These are the parameters hybrid_job documents as ignored in local mode; drop
        # them so they are not forwarded to LocalQuantumJob.create. "quiet" was missing
        # from this list, so an explicit quiet= reached the local create() call.
        for aws_only_arg in (
            "wait_until_complete",
            "copy_checkpoints_from_job",
            "instance_config",
            "distribution",
            "stopping_condition",
            "tags",
            "logger",
            "quiet",
        ):
            job_args.pop(aws_only_arg, None)
        return LocalQuantumJob.create(**job_args)

    from braket.aws import AwsQuantumJob

    return AwsQuantumJob.create(**job_args)