# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import annotations
import math
import tarfile
import tempfile
import time
from enum import Enum
from logging import Logger, getLogger
from pathlib import Path
from typing import Any, ClassVar
import boto3
from botocore.exceptions import ClientError
from braket.aws import AwsDevice
from braket.aws.aws_session import AwsSession
from braket.aws.queue_information import HybridJobQueueInfo
from braket.jobs import logs
from braket.jobs.config import (
CheckpointConfig,
InstanceConfig,
OutputDataConfig,
S3DataSourceConfig,
StoppingCondition,
)
from braket.jobs.data_persistence import load_job_result
from braket.jobs.metrics_data.cwl_insights_metrics_fetcher import CwlInsightsMetricsFetcher
# TODO: The metrics file has been added to the metrics folder, but we still have to decide
# which names to keep for these files, since all of the metrics are retrieved from CloudWatch.
from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType
from braket.jobs.quantum_job import QuantumJob
from braket.jobs.quantum_job_creation import prepare_quantum_job
[docs]
class AwsQuantumJob(QuantumJob):
    """Amazon Braket implementation of a quantum job."""
    # Job statuses treated as final: the polling code in this class stops waiting once the
    # reported state is one of these.
    TERMINAL_STATES: ClassVar[set[str]] = {"CANCELLED", "COMPLETED", "FAILED"}
    # File name looked up inside the extracted results folder when deserializing results.
    RESULTS_FILENAME = "results.json"
    # Local file name the results archive is downloaded to and later opened from.
    RESULTS_TAR_FILENAME = "model.tar.gz"
    # Log group handed to the log-stream reader when displaying job logs.
    LOG_GROUP = "/aws/braket/jobs"
[docs]
class LogState(Enum):
    """Phases of the log-reading loop driven by `AwsQuantumJob.logs()`."""
    # Still following the logs; the loop keeps polling until the job reaches a terminal state.
    TAILING = "tailing"
    # The job was seen in a terminal state; one more pause-and-read pass follows so that
    # late log entries are still picked up.
    JOB_COMPLETE = "job_complete"
    # The loop exits after the current read.
    COMPLETE = "complete"
[docs]
@classmethod
def create(
    cls,
    device: str,
    source_module: str,
    entry_point: str | None = None,
    image_uri: str | None = None,
    job_name: str | None = None,
    code_location: str | None = None,
    role_arn: str | None = None,
    wait_until_complete: bool = False,
    hyperparameters: dict[str, Any] | None = None,
    input_data: str | dict | S3DataSourceConfig | None = None,
    instance_config: InstanceConfig | None = None,
    distribution: str | None = None,
    stopping_condition: StoppingCondition | None = None,
    output_data_config: OutputDataConfig | None = None,
    copy_checkpoints_from_job: str | None = None,
    checkpoint_config: CheckpointConfig | None = None,
    aws_session: AwsSession | None = None,
    tags: dict[str, str] | None = None,
    logger: Logger = getLogger(__name__),
    quiet: bool = False,
    reservation_arn: str | None = None,
) -> AwsQuantumJob:
    """Create a hybrid job through the Braket CreateJob API and return a handle to it.

    The session is first resolved for the requested device (possibly switching region),
    the request is assembled by `prepare_quantum_job`, and the job is then submitted.

    Args:
        device (str): Device ARN of the QPU device that receives priority quantum task
            queueing once the hybrid job begins running. Each QPU has a separate hybrid
            jobs queue so that only one hybrid job is running at a time. Inside the hybrid
            job instance the string is available as the environment variable
            "AMZN_BRAKET_DEVICE_ARN". For embedded simulators, a string of the form
            "local:<provider>/<simulator_name>" may be given instead.
        source_module (str): Path (absolute, relative or an S3 URI) to a python module to
            be tarred and uploaded. An S3 URI must point to a tar.gz file; any other value
            may name a file or a directory.
        entry_point (str | None): Entry point of the hybrid job, relative to the source
            module, written as `importable.module` or `importable.module:callable`. For
            example, `source_module.submodule:start_here` selects the `start_here` function
            contained in `source_module.submodule`. Required when `source_module` is an
            S3 URI. Default: source_module's name.
        image_uri (str | None): ECR image used to execute the hybrid job. The
            `image_uris.retrieve_image()` function may be used to look up the ECR image
            URIs of the containers supported by Braket.
            Default: `<Braket base image_uri>`.
        job_name (str | None): Name under which the hybrid job is created. Allowed pattern:
            `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`.
            Default: f'{image_uri_type}-{timestamp}'.
        code_location (str | None): S3 prefix URI where custom code will be uploaded.
            Default: f's3://{default_bucket_name}/jobs/{job_name}/script'.
        role_arn (str | None): IAM role ARN used to execute the script.
            Default: IAM role returned by AwsSession's `get_default_jobs_role()`.
        wait_until_complete (bool): When `True`, print the job ARN and tail the hybrid job
            logs until the job completes before returning. Default: `False`.
        hyperparameters (dict[str, Any] | None): Hyperparameters accessible to the hybrid
            job, exposed to it as a dict[str, str]. Other key and value types are accepted
            for convenience; `str()` is called on them before they are passed on.
            Default: None.
        input_data (str | dict | S3DataSourceConfig | None): Information about the training
            data. A dictionary maps channel names to local paths or S3 URIs. Contents found
            at any local paths will be uploaded to S3 at
            f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}. A single
            local path, S3 URI, or S3DataSourceConfig is given the default channel name
            "input". Default: {}.
        instance_config (InstanceConfig | None): Configuration of the instance(s) running
            the classical code of the hybrid job. Default:
            `InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`.
        distribution (str | None): How the hybrid job should be distributed. With
            "data_parallel", the hyperparameters for the hybrid job will be set to use data
            parallelism features for PyTorch or TensorFlow. Default: None.
        stopping_condition (StoppingCondition | None): The maximum length of time, in
            seconds, and the maximum number of quantum tasks that a hybrid job can run
            before being forcefully stopped.
            Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).
        output_data_config (OutputDataConfig | None): Location for the output of the hybrid
            job.
            Default: OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data',
            kmsKeyId=None).
        copy_checkpoints_from_job (str | None): ARN of the hybrid job whose checkpoint
            should be reused. When given, the checkpoint data is copied from the
            checkpoint_config s3Uri of `copy_checkpoints_from_job` to the checkpoint_config
            s3Uri of the new hybrid job, making it available at checkpoint_config.localPath
            during the hybrid job execution. Default: None.
        checkpoint_config (CheckpointConfig | None): Location where checkpoint data is
            stored.
            Default: CheckpointConfig(localPath='/opt/jobs/checkpoints',
            s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints').
        aws_session (AwsSession | None): AwsSession for connecting to AWS Services.
            Default: AwsSession().
        tags (dict[str, str] | None): Key-value pairs for tagging this hybrid job.
            Default: {}.
        logger (Logger): Logger used while the session for the device is being resolved,
            for example to report a change of session region.
            Default: `getLogger(__name__)`.
        quiet (bool): Sets the verbosity of the logger to low and does not report queue
            position. Forwarded to the returned `AwsQuantumJob`. Default: `False`.
        reservation_arn (str | None): The reservation window ARN provided by Braket Direct
            to reserve exclusive usage for the device to run the hybrid job on.
            Default: None.

    Returns:
        AwsQuantumJob: Hybrid job tracking the execution on Amazon Braket.

    Raises:
        ValueError: If the parameters are not valid, including a device that cannot be
            found while the session is being resolved.
    """
    session = AwsQuantumJob._initialize_session(aws_session, device, logger)
    job_request = prepare_quantum_job(
        device=device,
        source_module=source_module,
        entry_point=entry_point,
        image_uri=image_uri,
        job_name=job_name,
        code_location=code_location,
        role_arn=role_arn,
        hyperparameters=hyperparameters,
        input_data=input_data,
        instance_config=instance_config,
        distribution=distribution,
        stopping_condition=stopping_condition,
        output_data_config=output_data_config,
        copy_checkpoints_from_job=copy_checkpoints_from_job,
        checkpoint_config=checkpoint_config,
        aws_session=session,
        tags=tags,
        reservation_arn=reservation_arn,
    )
    created_arn = session.create_job(**job_request)
    new_job = AwsQuantumJob(created_arn, session, quiet)
    if not wait_until_complete:
        return new_job
    # Blocking mode: announce the job, then follow its logs until it finishes.
    print(f"Initializing Braket Job: {created_arn}")
    new_job.logs(wait=True)
    return new_job
def __init__(self, arn: str, aws_session: AwsSession | None = None, quiet: bool = False):
"""Initializes an `AwsQuantumJob`.
Args:
arn (str): The ARN of the hybrid job.
aws_session (AwsSession | None): The `AwsSession` for connecting to AWS services.
Default is `None`, in which case an `AwsSession` object will be created with the
region of the hybrid job.
quiet (bool): Sets the verbosity of the logger to low and does not report queue
position. Default is `False`.
Raises:
ValueError: Supplied region and session region do not match.
"""
self._arn: str = arn
self._quiet = quiet
if aws_session:
if not self._is_valid_aws_session_region_for_job_arn(aws_session, arn):
raise ValueError(
"The aws session region does not match the region for the supplied arn."
)
self._aws_session = aws_session
else:
self._aws_session = AwsQuantumJob._default_session_for_job_arn(arn)
self._metadata = {}
@staticmethod
def _is_valid_aws_session_region_for_job_arn(aws_session: AwsSession, job_arn: str) -> bool:
"""Checks whether the job region and session region match.
Returns:
bool: `True` when the aws_session region matches the job_arn region; otherwise
`False`.
"""
job_region = job_arn.split(":")[3]
return job_region == aws_session.region
@staticmethod
def _default_session_for_job_arn(job_arn: str) -> AwsSession:
    """Build an `AwsSession` located in the region of the given hybrid job.

    Args:
        job_arn (str): The ARN for the quantum hybrid job; its fourth colon-separated
            field is used as the region name.

    Returns:
        AwsSession: `AwsSession` object wrapping a default `boto3.Session` created for
        the hybrid job's region.
    """
    region_name = job_arn.split(":")[3]
    return AwsSession(boto_session=boto3.Session(region_name=region_name))
@property
def arn(self) -> str:
"""str: The ARN (Amazon Resource Name) of the quantum hybrid job."""
return self._arn
@property
def name(self) -> str:
"""str: The name of the quantum job."""
return self.metadata(use_cached_value=True).get("jobName")
@property
def _logs_prefix(self) -> str:
"""str: the prefix for the job logs."""
# jobs ARNs used to contain the job name and use a log prefix of `job-name`
# now job ARNs use a UUID and a log prefix of `job-name/UUID`
return (
f"{self.name}"
if self.arn.endswith(self.name)
else f"{self.name}/{self.arn.split('/')[-1]}"
)
[docs]
def state(self, use_cached_value: bool = False) -> str:
"""The state of the quantum hybrid job.
Args:
use_cached_value (bool): If `True`, uses the value most recently retrieved
value from the Amazon Braket `GetJob` operation. If `False`, calls the
`GetJob` operation to retrieve metadata, which also updates the cached
value. Default = `False`.
Returns:
str: The value of `status` in `metadata()`. This is the value of the `status` key
in the Amazon Braket `GetJob` operation.
See Also:
`metadata()`
"""
return self.metadata(use_cached_value).get("status")
[docs]
def queue_position(self) -> HybridJobQueueInfo:
    """Fetch the queue position details for the hybrid job.

    The `queueInfo` entry of freshly retrieved metadata is read; a position reported as
    the string "None" is converted to a real `None`.

    Returns:
        HybridJobQueueInfo: Instance of HybridJobQueueInfo class representing
        the queue position information for the hybrid job. The queue_position is
        only returned when the hybrid job is not in RUNNING/CANCELLING/TERMINAL states,
        else queue_position is returned as None. If the queue position of the hybrid
        job is greater than 15, we return '>15' as the queue_position return value.

    Examples:
        job status = QUEUED and position is 2 in the queue.
        >>> job.queue_position()
        HybridJobQueueInfo(queue_position='2', message=None)

        job status = QUEUED and position is 18 in the queue.
        >>> job.queue_position()
        HybridJobQueueInfo(queue_position='>15', message=None)

        job status = COMPLETED
        >>> job.queue_position()
        HybridJobQueueInfo(queue_position=None,
        message='Job is in COMPLETED status. AmazonBraket does
        not show queue position for this status.')
    """
    queue_info = self.metadata()["queueInfo"]
    position = queue_info.get("position")
    if position == "None":
        position = None
    return HybridJobQueueInfo(queue_position=position, message=queue_info.get("message"))
[docs]
def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None:
    """Display logs for a given hybrid job, optionally tailing them until hybrid job is
    complete.

    If the output is a tty or a Jupyter cell, it will be color-coded
    based on which instance the log entry is from.

    Every pass of the polling loop begins by sleeping for `poll_interval_seconds`, so a
    call with `wait=False` still pauses once before its single read of the logs.

    Args:
        wait (bool): `True` to keep looking for new log entries until the hybrid job completes;
            otherwise `False`. Default: `False`.
        poll_interval_seconds (int): The interval of time, in seconds, between polling for
            new log entries and hybrid job completion (default: 5).

    Raises:
        exceptions.UnexpectedStatusException: If waiting and the training hybrid job fails.
            NOTE(review): nothing in this method raises this directly; confirm whether
            `logs.flush_log_streams` can.
    """
    # The loop below implements a state machine that alternates between checking the hybrid job
    # status and reading whatever is available in the logs at this point.
    #
    # If wait == TRUE and hybrid job is not completed, the initial state is TAILING
    # Otherwise (wait == FALSE, or the job is already in a terminal state), the initial state
    # is COMPLETE (doesn't matter if the hybrid job really is complete).
    #
    # The state table (every pass performs: Pause, Get status, Read logs):
    #
    # STATE               CONDITION           NEW STATE
    # ----------------    -----------------   ----------------
    # TAILING             Job complete        JOB_COMPLETE
    #                     Else                TAILING
    # JOB_COMPLETE        Any                 COMPLETE
    # COMPLETE            N/A                 Exit
    #
    # Notes:
    # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
    #   Cloudwatch after the job was marked complete.
    job_already_completed = self.state() in AwsQuantumJob.TERMINAL_STATES
    log_state = (
        AwsQuantumJob.LogState.TAILING
        if wait and not job_already_completed
        else AwsQuantumJob.LogState.COMPLETE
    )
    log_group = AwsQuantumJob.LOG_GROUP
    stream_names = []  # The list of log streams
    positions = {}  # The current position in each stream, map of stream name -> position
    instance_count = self.metadata(use_cached_value=True)["instanceConfig"]["instanceCount"]
    has_streams = False
    color_wrap = logs.ColorWrap()
    # Both the previous and the current status are handed to the log reader on each pass.
    previous_state = self.state()
    while True:
        time.sleep(poll_interval_seconds)
        current_state = self.state()
        has_streams = logs.flush_log_streams(
            self._aws_session,
            log_group,
            self._logs_prefix,
            stream_names,
            positions,
            instance_count,
            has_streams,
            color_wrap,
            [previous_state, current_state],
            # In quiet mode no queue position is fetched or passed on.
            None if self._quiet else self.queue_position().queue_position,
        )
        previous_state = current_state
        if log_state == AwsQuantumJob.LogState.COMPLETE:
            break
        if log_state == AwsQuantumJob.LogState.JOB_COMPLETE:
            log_state = AwsQuantumJob.LogState.COMPLETE
        elif current_state in AwsQuantumJob.TERMINAL_STATES:
            log_state = AwsQuantumJob.LogState.JOB_COMPLETE
[docs]
def metrics(
    self,
    metric_type: MetricType = MetricType.TIMESTAMP,
    statistic: MetricStatistic = MetricStatistic.MAX,
) -> dict[str, list[Any]]:
    """Gets all the metrics data, where the keys are the column names, and the values are a list
    containing the values in each row.

    For example, the table:
        timestamp energy
          0         0.1
          1         0.2
    would be represented as:
    { "timestamp" : [0, 1], "energy" : [0.1, 0.2] }
    values may be integers, floats, strings or None.

    The query window is taken from the cached job metadata: it starts at `startedAt`
    (truncated to whole seconds) when that key is present, and ends at `endedAt` (rounded
    up to whole seconds) when the job is in a terminal state and that key is present.

    Args:
        metric_type (MetricType): The type of metrics to get. Default: MetricType.TIMESTAMP.
        statistic (MetricStatistic): The statistic to determine which metric value to use
            when there is a conflict. Default: MetricStatistic.MAX.

    Returns:
        dict[str, list[Any]]: The metrics data.
    """
    metrics_fetcher = CwlInsightsMetricsFetcher(self._aws_session)
    job_metadata = self.metadata(True)
    window_start = (
        int(job_metadata["startedAt"].timestamp()) if "startedAt" in job_metadata else None
    )
    window_end = None
    job_finished = self.state() in AwsQuantumJob.TERMINAL_STATES
    if job_finished and "endedAt" in job_metadata:
        window_end = int(math.ceil(job_metadata["endedAt"].timestamp()))
    return metrics_fetcher.get_metrics_for_job(
        self.name, metric_type, statistic, window_start, window_end, self._logs_prefix
    )
[docs]
def cancel(self) -> str:
"""Cancels the job.
Returns:
str: Indicates the status of the job.
Raises:
ClientError: If there are errors invoking the CancelJob API.
"""
cancellation_response = self._aws_session.cancel_job(self._arn)
return cancellation_response["cancellationStatus"]
[docs]
def result(
    self,
    poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT,
    poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL,
) -> dict[str, Any]:
    """Retrieves the hybrid job result persisted using the `save_job_result` function.

    The results are downloaded into a temporary directory, deserialized, and the
    directory is removed again before returning.

    Args:
        poll_timeout_seconds (float): The polling timeout, in seconds, for `result()`.
            Default: 10 days.
        poll_interval_seconds (float): The polling interval, in seconds, for `result()`.
            Default: 5 seconds.

    Returns:
        dict[str, Any]: Dict specifying the job results. An empty dict is returned when
        the download fails with a "404" error code.

    Raises:
        RuntimeError: if hybrid job is in a FAILED or CANCELLED state.
            NOTE(review): not raised by the code visible in this method; confirm.
        TimeoutError: if hybrid job execution exceeds the polling timeout period.
        ClientError: if the download fails with any error code other than "404".
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        job_name = self.metadata(True)["jobName"]
        try:
            self.download_result(scratch_dir, poll_timeout_seconds, poll_interval_seconds)
        except ClientError as e:
            if e.response["Error"]["Code"] != "404":
                raise e
            # No results object was found for this job: report an empty result.
            return {}
        return AwsQuantumJob._read_and_deserialize_results(scratch_dir, job_name)
@staticmethod
def _read_and_deserialize_results(temp_dir: str, job_name: str) -> dict[str, Any]:
    """Load the results file found at `<temp_dir>/<job_name>/results.json`.

    Args:
        temp_dir (str): Directory into which the job results were extracted.
        job_name (str): Name of the sub-folder holding this job's files.

    Returns:
        dict[str, Any]: Whatever `load_job_result` returns for that file.
    """
    results_file = Path(temp_dir) / job_name / AwsQuantumJob.RESULTS_FILENAME
    return load_job_result(results_file)
[docs]
def download_result(
self,
extract_to: str | None = None,
poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL,
) -> None:
"""Downloads the results from the hybrid job output S3 bucket and extracts the tar.gz
bundle to the location specified by `extract_to`. If no location is specified,
the results are extracted to the current directory.
Args:
extract_to (str | None): The directory to which the results are extracted. The results
are extracted to a folder titled with the hybrid job name within this directory.
Default= `Current working directory`.
poll_timeout_seconds (float): The polling timeout, in seconds, for `download_result()`.
Default: 10 days.
poll_interval_seconds (float): The polling interval, in seconds, for
`download_result()`.Default: 5 seconds.
Raises:
RuntimeError: if hybrid job is in a FAILED or CANCELLED state.
TimeoutError: if hybrid job execution exceeds the polling timeout period.
"""
extract_to = extract_to or Path.cwd()
timeout_time = time.time() + poll_timeout_seconds
job_response = self.metadata(True)
while time.time() < timeout_time:
job_response = self.metadata(True)
job_state = self.state()
if job_state in AwsQuantumJob.TERMINAL_STATES:
output_s3_path = job_response["outputDataConfig"]["s3Path"]
output_s3_uri = f"{output_s3_path}/output/model.tar.gz"
AwsQuantumJob._attempt_results_download(self, output_s3_uri, output_s3_path)
AwsQuantumJob._extract_tar_file(f"{extract_to}/{self.name}")
return
else:
time.sleep(poll_interval_seconds)
raise TimeoutError(
f"{job_response['jobName']}: Polling for job completion "
f"timed out after {poll_timeout_seconds} seconds."
)
def _attempt_results_download(self, output_bucket_uri: str, output_s3_path: str) -> None:
    """Download the results archive from S3 to the local file `model.tar.gz`.

    Args:
        output_bucket_uri (str): S3 URI of the results archive to download.
        output_s3_path (str): Output location of the job; only used in the error message.

    Raises:
        ClientError: A "404" failure is re-raised as a new `ClientError` (operation name
            "HeadObject") whose message names `output_s3_path`; any other failure is
            re-raised as is.
    """
    try:
        self._aws_session.download_from_s3(
            s3_uri=output_bucket_uri, filename=AwsQuantumJob.RESULTS_TAR_FILENAME
        )
    except ClientError as e:
        if e.response["Error"]["Code"] == "404":
            message = f"Error retrieving results, could not find results at '{output_s3_path}'"
            not_found = {"Error": {"Code": "404", "Message": message}}
            raise ClientError(not_found, "HeadObject") from e
        raise e
@staticmethod
def _extract_tar_file(extract_path: str) -> None:
with tarfile.open(AwsQuantumJob.RESULTS_TAR_FILENAME, "r:gz") as tar:
tar.extractall(extract_path)
def __repr__(self) -> str:
return f"AwsQuantumJob('arn':'{self.arn}')"
def __eq__(self, other: AwsQuantumJob) -> bool:
return self.arn == other.arn if isinstance(other, AwsQuantumJob) else False
def __hash__(self) -> int:
return hash(self.arn)
@staticmethod
def _initialize_session(session_value: AwsSession, device: str, logger: Logger) -> AwsSession:
    """Pick the session to use for creating a job on `device`.

    Args:
        session_value (AwsSession): Session supplied by the caller; a falsy value makes a
            default `AwsSession()` be created.
        device (str): Device ARN, or a "local:..." string for an embedded simulator.
        logger (Logger): Logger handed on to the region-resolution helpers.

    Returns:
        AwsSession: The session unchanged for "local:" devices; otherwise the session
        returned by the regional or non-regional device helper, depending on whether
        `AwsDevice.get_device_region` yields a region for the device.
    """
    aws_session = session_value or AwsSession()
    if device.startswith("local:"):
        return aws_session
    if AwsDevice.get_device_region(device):
        return AwsQuantumJob._initialize_regional_device_session(aws_session, device, logger)
    return AwsQuantumJob._initialize_non_regional_device_session(aws_session, device, logger)
@staticmethod
def _initialize_regional_device_session(
    aws_session: AwsSession, device: AwsDevice, logger: Logger
) -> AwsSession:
    """Return a session in the device's region after checking that the device exists.

    Args:
        aws_session (AwsSession): Session to start from. If its region differs from the
            device's region, a copy in the device's region is used instead and the
            change is logged.
        device (AwsDevice): The device identifier; it is passed to
            `AwsDevice.get_device_region` and `get_device` and formatted into the error
            message. NOTE(review): callers in this file pass the device ARN string.
        logger (Logger): Logger that receives the region-change message.

    Returns:
        AwsSession: The session (possibly a regional copy) on which the device was found.

    Raises:
        ValueError: If the lookup fails with "ResourceNotFoundException".
        ClientError: Any other lookup failure, re-raised unchanged.
    """
    device_region = AwsDevice.get_device_region(device)
    current_region = aws_session.region
    if current_region != device_region:
        aws_session = aws_session.copy_session(region=device_region)
        logger.info(f"Changed session region from '{current_region}' to '{device_region}'")
    try:
        aws_session.get_device(device)
    except ClientError as e:
        if e.response["Error"]["Code"] == "ResourceNotFoundException":
            raise ValueError(f"'{device}' not found.") from e
        # Re-raise the original error untouched; chaining it to itself with `from e`
        # would make the exception its own `__cause__`.
        raise
    return aws_session
@staticmethod
def _initialize_non_regional_device_session(
    aws_session: AwsSession, device: AwsDevice, logger: Logger
) -> AwsSession:
    """Find a session on which `device` can be looked up when its region is not known.

    The given session is tried first. If the device is not found there and its
    identifier contains "qpu", copies of the session in every other region listed in
    `AwsDevice.REGIONS` are tried until one succeeds.

    Args:
        aws_session (AwsSession): Session to try first and to copy for other regions.
        device (AwsDevice): The device identifier passed to `get_device`.
            NOTE(review): callers in this file pass the device ARN string.
        logger (Logger): Logger that receives the region-change message.

    Returns:
        AwsSession: The first session on which the device lookup succeeded.

    Raises:
        ValueError: If a non-QPU device is not found in the original region, or a QPU is
            not found in any region.
        ClientError: Any lookup failure other than "ResourceNotFoundException".
    """
    original_region = aws_session.region
    try:
        aws_session.get_device(device)
        return aws_session
    except ClientError as e:
        if e.response["Error"]["Code"] != "ResourceNotFoundException":
            raise e
        if "qpu" not in device:
            raise ValueError(f"Simulator '{device}' not found in '{original_region}'") from e
    # Only QPUs reach this point: look for the device in each of the remaining regions.
    remaining_regions = frozenset(AwsDevice.REGIONS) - {original_region}
    for candidate_region in remaining_regions:
        candidate_session = aws_session.copy_session(region=candidate_region)
        try:
            candidate_session.get_device(device)
            logger.info(
                f"Changed session region from '{original_region}' to '{candidate_session.region}'"
            )
            return candidate_session
        except ClientError as lookup_error:
            if lookup_error.response["Error"]["Code"] != "ResourceNotFoundException":
                raise lookup_error
    raise ValueError(f"QPU '{device}' not found.")