Source code for braket.jobs.metrics_data.cwl_metrics_fetcher

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import time
from logging import Logger, getLogger
from typing import Union

from braket.aws.aws_session import AwsSession
from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType
from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser


class CwlMetricsFetcher:
    """Retrieves the algorithm metrics that a hybrid job wrote to its log streams.

    Log streams are looked up under `LOG_GROUP_NAME` through the logs client of the
    supplied `AwsSession`. Every log event whose message is flagged as a metrics
    message is handed to a `LogMetricsParser`, which produces the final result.
    """

    # Log group that is searched for the job's log streams.
    LOG_GROUP_NAME = "/aws/braket/jobs"

    def __init__(
        self,
        aws_session: AwsSession,
        poll_timeout_seconds: float = 10,
        logger: Logger = getLogger(__name__),
    ):
        """Initializes a `CwlMetricsFetcher`.

        Args:
            aws_session (AwsSession): AwsSession to connect to AWS with.
            poll_timeout_seconds (float): The polling timeout for retrieving the metrics,
                in seconds. Default: 10 seconds.
            logger (Logger): Logger object with which to write logs, such as the warning
                emitted when retrieval times out before all metrics were read.
                Default is `getLogger(__name__)`
        """
        self._poll_timeout_seconds = poll_timeout_seconds
        self._logger = logger
        self._logs_client = aws_session.logs_client

    @staticmethod
    def _is_metrics_message(message: str) -> bool:
        """Returns true if a given message is designated as containing Metrics.

        Args:
            message (str): The message to check. May be empty or None.

        Returns:
            bool: True if the given message is designated as containing Metrics;
            False otherwise (including for an empty or missing message).
        """
        return "Metrics -" in message if message else False

    def _parse_metrics_from_log_stream(
        self,
        stream_name: str,
        timeout_time: float,
        parser: LogMetricsParser,
    ) -> None:
        """Synchronously retrieves the algorithm metrics logged in a given hybrid job
        log stream.

        Args:
            stream_name (str): The name of the log stream.
            timeout_time (float) : We stop getting metrics if the current time is beyond
                the timeout time.
            parser (LogMetricsParser) : The CWL metrics parser. It is updated in place
                with every metrics message found in the stream.
        """
        kwargs = {
            "logGroupName": self.LOG_GROUP_NAME,
            "logStreamName": stream_name,
            "startFromHead": True,
            "limit": 10000,
        }

        previous_token = None
        while time.time() < timeout_time:
            response = self._logs_client.get_log_events(**kwargs)
            # Treat a response without an "events" entry as an empty page instead of
            # failing on iteration over None.
            for event in response.get("events") or []:
                message = event.get("message")
                if self._is_metrics_message(message):
                    parser.parse_log_message(event.get("timestamp"), message)
            next_token = response.get("nextForwardToken")
            # No forward token, or the same token as on the previous page, means there
            # is nothing further to read from this stream.
            if not next_token or next_token == previous_token:
                return
            previous_token = next_token
            kwargs["nextToken"] = next_token
        self._logger.warning("Timed out waiting for all metrics. Data may be incomplete.")

    def _get_log_streams_for_job(self, job_name: str, timeout_time: float) -> list[str]:
        """Retrieves the list of log streams relevant to a hybrid job.

        Args:
            job_name (str): The name of the hybrid job.
            timeout_time (float) : Metrics cease getting streamed if the current time
                exceeds the timeout time.

        Returns:
            list[str]: A list of log stream names for the given hybrid job. The list may
            be incomplete if the timeout was reached before all pages were read.
        """
        kwargs = {
            "logGroupName": self.LOG_GROUP_NAME,
            "logStreamNamePrefix": f"{job_name}/algo-",
        }
        log_streams = []
        while time.time() < timeout_time:
            response = self._logs_client.describe_log_streams(**kwargs)
            if streams := response.get("logStreams"):
                for stream in streams:
                    # Entries without a stream name are skipped.
                    if name := stream.get("logStreamName"):
                        log_streams.append(name)
            if next_token := response.get("nextToken"):
                kwargs["nextToken"] = next_token
            else:
                # No further page: the listing is complete.
                return log_streams
        self._logger.warning("Timed out waiting for all metrics. Data may be incomplete.")
        return log_streams

    def get_metrics_for_job(
        self,
        job_name: str,
        metric_type: MetricType = MetricType.TIMESTAMP,
        statistic: MetricStatistic = MetricStatistic.MAX,
    ) -> dict[str, list[Union[str, float, int]]]:
        """Synchronously retrieves all the algorithm metrics logged by a given Hybrid Job.

        Args:
            job_name (str): The name of the Hybrid Job. The name must be exact to ensure
                only the relevant metrics are retrieved.
            metric_type (MetricType): The type of metrics to get. Default is
                MetricType.TIMESTAMP.
            statistic (MetricStatistic): The statistic to determine which metric value to
                use when there is a conflict. Default is MetricStatistic.MAX.

        Returns:
            dict[str, list[Union[str, float, int]]]: The metrics data, where the keys
            are the column names and the values are a list containing the values in each
            row.

        Example:
            timestamp energy
              0         0.1
              1         0.2
            would be represented as:
            { "timestamp" : [0, 1], "energy" : [0.1, 0.2] }
            values may be integers, floats, strings or None.
        """
        # A single deadline is shared by the stream listing and by every stream read.
        timeout_time = time.time() + self._poll_timeout_seconds

        parser = LogMetricsParser()

        log_streams = self._get_log_streams_for_job(job_name, timeout_time)
        for log_stream in log_streams:
            self._parse_metrics_from_log_stream(log_stream, timeout_time, parser)

        return parser.get_parsed_metrics(metric_type, statistic)