# Source code for braket.jobs.metrics_data.cwl_insights_metrics_fetcher

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import annotations

import time
from logging import Logger, getLogger
from typing import Any, Optional, Union

from braket.aws.aws_session import AwsSession
from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType
from braket.jobs.metrics_data.exceptions import MetricsRetrievalError
from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser


class CwlInsightsMetricsFetcher:
    """Retrieves Braket Hybrid Job algorithm metrics via CloudWatch Logs Insights.

    Runs a Logs Insights query against the Braket jobs log group, polls
    synchronously for the query to finish, and parses any metric lines found
    in the matching log messages.
    """

    LOG_GROUP_NAME = "/aws/braket/jobs"
    QUERY_DEFAULT_JOB_DURATION = 3 * 60 * 60

    def __init__(
        self,
        aws_session: AwsSession,
        poll_timeout_seconds: float = 10,
        poll_interval_seconds: float = 1,
        logger: Logger = getLogger(__name__),
    ):
        """Initializes a `CwlInsightsMetricsFetcher`.

        Args:
            aws_session (AwsSession): AwsSession to connect to AWS with.
            poll_timeout_seconds (float): The polling timeout for retrieving the metrics,
                in seconds. Default: 10 seconds.
            poll_interval_seconds (float): The interval of time, in seconds, between polling
                for results. Default: 1 second.
            logger (Logger): Logger object with which to write logs, such as quantum task
                statuses while waiting for a quantum task to be in a terminal state.
                Default is `getLogger(__name__)`
        """
        self._poll_timeout_seconds = poll_timeout_seconds
        self._poll_interval_seconds = poll_interval_seconds
        self._logger = logger
        self._logs_client = aws_session.logs_client

    @staticmethod
    def _get_element_from_log_line(
        element_name: str, log_line: list[dict[str, Any]]
    ) -> Optional[str]:
        """Returns the value of a named field from one Logs Insights result row.

        Args:
            element_name (str): The field name to look up (e.g. "@message").
            log_line (list[dict[str, Any]]): One result row, as a list of
                {"field": ..., "value": ...} dicts.

        Returns:
            Optional[str]: The matching field's value, or None when the row
            has no field with that name.
        """
        for entry in log_line:
            if entry["field"] == element_name:
                return entry["value"]
        return None

    def _get_metrics_results_sync(self, query_id: str) -> list[Any]:
        """Polls a CloudWatch Insights query until it completes, then returns its results.

        Polling stops after `self._poll_timeout_seconds`; on timeout an empty
        list is returned and a warning is logged.

        Args:
            query_id (str): CloudWatch Insights query ID.

        Raises:
            MetricsRetrievalError: Raised if the query is Failed or Cancelled.

        Returns:
            list[Any]: The results from CloudWatch insights 'GetQueryResults' operation.
        """
        deadline = time.time() + self._poll_timeout_seconds
        while time.time() < deadline:
            response = self._logs_client.get_query_results(queryId=query_id)
            status = response["status"]
            if status == "Complete":
                return response["results"]
            if status in ("Failed", "Cancelled"):
                raise MetricsRetrievalError(f"Query {query_id} failed with status {status}.")
            # Query is still running; wait before asking again.
            time.sleep(self._poll_interval_seconds)
        self._logger.warning(f"Timed out waiting for query {query_id}.")
        return []

    def _parse_log_line(self, result_entry: list[dict[str, Any]], parser: LogMetricsParser) -> None:
        """Feeds one Logs Insights result row into the metrics parser.

        Extracts the "@message" and "@timestamp" fields from the row and, when a
        non-empty message is present, hands both to the parser. Other row
        metadata is ignored.

        Args:
            result_entry (list[dict[str, Any]]): A single structured result row
                from CloudWatch Insights (field/value pairs).
            parser (LogMetricsParser): The CWL metrics parser accumulating results.
        """
        message = self._get_element_from_log_line("@message", result_entry)
        if not message:
            return
        timestamp = self._get_element_from_log_line("@timestamp", result_entry)
        parser.parse_log_message(timestamp, message)

    def _parse_log_query_results(
        self, results: list[Any], metric_type: MetricType, statistic: MetricStatistic
    ) -> dict[str, list[Union[str, float, int]]]:
        """Parses all CloudWatch Insights result rows and returns the found metrics.

        Args:
            results (list[Any]): The structured rows returned by CloudWatch Insights.
            metric_type (MetricType): The type of metrics to get.
            statistic (MetricStatistic): The statistic to determine which metric
                value to use when there is a conflict.

        Returns:
            dict[str, list[Union[str, float, int]]]: The metrics data.
        """
        metrics_parser = LogMetricsParser()
        for row in results:
            self._parse_log_line(row, metrics_parser)
        return metrics_parser.get_parsed_metrics(metric_type, statistic)

    def get_metrics_for_job(
        self,
        job_name: str,
        metric_type: MetricType = MetricType.TIMESTAMP,
        statistic: MetricStatistic = MetricStatistic.MAX,
        job_start_time: int | None = None,
        job_end_time: int | None = None,
        stream_prefix: str | None = None,
    ) -> dict[str, list[Union[str, float, int]]]:
        """Synchronously retrieves all the algorithm metrics logged by a given Hybrid Job.

        Args:
            job_name (str): The name of the Hybrid Job. The name must be exact to ensure
                only the relevant metrics are retrieved.
            metric_type (MetricType): The type of metrics to get.
                Default is MetricType.TIMESTAMP.
            statistic (MetricStatistic): The statistic to determine which metric value to
                use when there is a conflict. Default is MetricStatistic.MAX.
            job_start_time (int | None): The time when the hybrid job started.
                Default: 3 hours before job_end_time.
            job_end_time (int | None): If the hybrid job is complete, this should be the
                time at which the hybrid job finished. Default: current time.
            stream_prefix (str | None): If a logs prefix is provided, it will be used
                instead of the job name.

        Returns:
            dict[str, list[Union[str, float, int]]]: The metrics data, where the keys
            are the column names and the values are a list containing the values in
            each row. The values may be integers, floats, strings or None.

            Example: the table
                timestamp energy
                0         0.1
                1         0.2
            is represented as
                {"timestamp": [0, 1], "energy": [0.1, 0.2]}
        """
        end_time = job_end_time or int(time.time())
        start_time = job_start_time or end_time - self.QUERY_DEFAULT_JOB_DURATION
        # Escape "/" so the prefix can be embedded in the regex literal below.
        prefix = (stream_prefix or job_name).replace("/", "\\/")
        query_string = (
            f"fields @timestamp, @message "
            f"| filter @logStream like /^{prefix}\\// "
            f"| filter @message like /Metrics - /"
        )
        start_response = self._logs_client.start_query(
            logGroupName=self.LOG_GROUP_NAME,
            startTime=start_time,
            endTime=end_time,
            queryString=query_string,
            limit=10000,
        )
        query_results = self._get_metrics_results_sync(start_response["queryId"])
        return self._parse_log_query_results(query_results, metric_type, statistic)