Source code for braket.jobs.logs

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import collections
import heapq
import itertools
import os
import sys
from collections.abc import Generator

##############################################################################
#
# Support for reading logs
#
##############################################################################
from typing import ClassVar, Optional

from botocore.exceptions import ClientError

from braket.aws.aws_session import AwsSession


class ColorWrap:
    """A callable that prints text, color-coded by instance number when possible.

    Text is wrapped in one of five ANSI colors, chosen by cycling through them by instance
    number. This happens when colorization is forced, when standard output is a terminal,
    or when the ``JPY_PARENT_PID`` environment variable is set (as in a Jupyter notebook
    cell). Otherwise the text is printed unchanged.
    """

    # ANSI SGR foreground color codes, cycled through by instance number. For what color
    # each number represents, see
    # https://misc.flogisoft.com/bash/tip_colors_and_formatting#colors
    _stream_colors: ClassVar = [34, 35, 32, 36, 33]

    def __init__(self, force: bool = False):
        """Initialize a `ColorWrap`.

        Args:
            force (bool): If True, the output is colorized wherever it is written.
                Default: False.
        """
        # Coerce to a real bool; otherwise the environment lookup could leave the
        # Jupyter parent PID string stored in this attribute.
        self.colorize = bool(
            force or sys.stdout.isatty() or os.environ.get("JPY_PARENT_PID", None)
        )

    def __call__(self, index: int, s: str) -> None:
        """Prints the string, colorized or not, depending on the environment.

        Args:
            index (int): The instance number.
            s (str): The string to print.
        """
        if self.colorize:
            self._color_wrap(index, s)
        else:
            print(s)

    def _color_wrap(self, index: int, s: str) -> None:
        """Prints the string in a color determined by the index.

        Args:
            index (int): The instance number; colors repeat every five instances.
            s (str): The string to print (color-wrapped).
        """
        color = self._stream_colors[index % len(self._stream_colors)]
        # Set the foreground color, print the text, then reset all attributes.
        print(f"\x1b[{color}m{s}\x1b[0m")
# Position records the timestamp of the last log event read from a stream (`timestamp`)
# and how many events carrying exactly that timestamp have already been read (`skip`).
# The next read starts at `timestamp` and skips the first `skip` events. This way, events
# that share a timestamp are neither lost nor printed twice.
Position = collections.namedtuple("Position", ["timestamp", "skip"])
def multi_stream_iter(
    aws_session: AwsSession, log_group: str, streams: list[str], positions: dict[str, Position]
) -> Generator[tuple[int, dict], None, None]:
    """Iterates over the available events coming from a set of log streams.

    The streams all belong to a single log group. Their events are interleaved so that they
    are yielded in timestamp order. Events with equal timestamps are yielded in stream
    order, lower stream number first.

    Args:
        aws_session (AwsSession): The AwsSession for interfacing with CloudWatch.
        log_group (str): The name of the log group.
        streams (list[str]): A list of the log stream names. The stream number is the
            position of the stream in this list.
        positions (dict[str, Position]): A dict mapping each stream name to the
            (timestamp, skip) pair that represents the last record read from that stream.

    Yields:
        Generator[tuple[int, dict], None, None]: A tuple of (stream number, CloudWatch log
        event).
    """
    # Tag every event with its stream number. `itertools.repeat(i)` captures `i` right
    # away; a nested generator expression would late-bind it to the last stream number.
    tagged_streams = [
        zip(
            itertools.repeat(i),
            log_stream(
                aws_session,
                log_group,
                name,
                positions[name].timestamp,
                positions[name].skip,
            ),
        )
        for i, name in enumerate(streams)
    ]
    # heapq.merge holds one pending event per stream and always yields the earliest one.
    # On equal timestamps it prefers the stream listed first, like a stable sort would.
    yield from heapq.merge(*tagged_streams, key=lambda item: item[1]["timestamp"])
def log_stream(
    aws_session: AwsSession, log_group: str, stream_name: str, start_time: int = 0, skip: int = 0
) -> Generator[dict, None, None]:
    """A generator for log items in a single stream.

    This yields all the items that are available at the current moment. Pages are
    requested until the service signals the end of the stream by returning the same
    forward token that was sent with an empty page.

    Args:
        aws_session (AwsSession): The AwsSession for interfacing with CloudWatch.
        log_group (str): The name of the log group.
        stream_name (str): The name of the specific stream.
        start_time (int): The time stamp value to start reading the logs from.
            Default: 0.
        skip (int): The number of log entries to skip at the start. Default: 0
            (This is for when there are multiple entries at the same timestamp.)

    Yields:
        Generator[dict, None, None]: A CloudWatch log event with the following key-value
        pairs:
            'timestamp' (int): The time of the event.
            'message' (str): The log event data.
            'ingestionTime' (int): The time the event was ingested.
    """
    next_token = None
    while True:
        response = aws_session.get_log_events(
            log_group,
            stream_name,
            start_time,
            start_from_head=True,
            next_token=next_token,
        )
        page = response["events"]
        forward_token = response["nextForwardToken"]
        # Drop events already consumed by an earlier read; `skip` may span several pages.
        skipped = min(skip, len(page))
        skip -= skipped
        yield from page[skipped:]
        # GetLogEvents may return an empty page before the end of the stream, so an empty
        # page alone is not a stop signal. The end is reached only when the service hands
        # back the same token that was sent.
        if not page and forward_token == next_token:
            return
        next_token = forward_token
def flush_log_streams(
    aws_session: AwsSession,
    log_group: str,
    stream_prefix: str,
    stream_names: list[str],
    positions: dict[str, Position],
    stream_count: int,
    has_streams: bool,
    color_wrap: ColorWrap,
    state: list[str],
    queue_position: Optional[str] = None,
) -> bool:
    """Flushes log streams to stdout.

    Args:
        aws_session (AwsSession): The AwsSession for interfacing with CloudWatch.
        log_group (str): The name of the log group.
        stream_prefix (str): The prefix for log streams to flush.
        stream_names (list[str]): A list of the log stream names. The position of the
            stream in this list is the stream number. If incomplete, the function will
            check for remaining streams and mutate this list to add stream names when
            available, up to the `stream_count` limit.
        positions (dict[str, Position]): A dict mapping stream names to (timestamp, skip)
            pairs which represent the last record read from each stream. The function
            updates this dict in place to represent the new last record read from each
            stream.
        stream_count (int): The number of streams expected.
        has_streams (bool): Whether a previous call has already found at least one log
            stream (a separating newline is printed only the first time streams appear).
            This value is possibly updated and returned at the end of execution.
        color_wrap (ColorWrap): An instance of ColorWrap to potentially color-wrap print
            statements from different streams.
        state (list[str]): The previous (`state[0]`) and current (`state[1]`) state of
            the job.
        queue_position (Optional[str]): The current queue position. This is not passed in
            if the job is run with `quiet=True`.

    Raises:
        ClientError: Any client error raised while listing the log streams, other than a
            ResourceNotFoundException.

    Returns:
        bool: The updated `has_streams` value, which is True once any log stream has been
        found and flushed.
    """
    if len(stream_names) < stream_count:
        # Log streams are created whenever a container starts writing to stdout/err, so
        # this list may be dynamic until we have a stream for every instance.
        _discover_log_streams(
            aws_session, log_group, stream_prefix, stream_names, positions, stream_count
        )

    if not stream_names:
        _print_waiting_status(state, queue_position)
        return has_streams

    if not has_streams:
        print()
        has_streams = True
    for idx, event in multi_stream_iter(aws_session, log_group, stream_names, positions):
        color_wrap(idx, event["message"])
        name = stream_names[idx]
        ts, count = positions[name]
        if event["timestamp"] == ts:
            # Another event at the same timestamp: skip one more next time.
            positions[name] = Position(timestamp=ts, skip=count + 1)
        else:
            positions[name] = Position(timestamp=event["timestamp"], skip=1)
    return has_streams


def _discover_log_streams(
    aws_session: AwsSession,
    log_group: str,
    stream_prefix: str,
    stream_names: list[str],
    positions: dict[str, Position],
    stream_count: int,
) -> None:
    """Appends newly found stream names to `stream_names` and seeds their positions."""
    try:
        streams = aws_session.describe_log_streams(
            log_group,
            stream_prefix,
            limit=stream_count,
        )
    except ClientError as e:
        # On the very first training job run on an account, there's no log group until
        # the container starts logging, so ignore any errors thrown about that until
        # logging begins.
        err = e.response.get("Error", {})
        if err.get("Code") != "ResourceNotFoundException":
            raise
        return

    known = set(stream_names)
    new_streams = [
        s["logStreamName"] for s in streams["logStreams"] if s["logStreamName"] not in known
    ]
    # Extend in place; rebinding `stream_names = [...]` would not update the caller's list.
    stream_names.extend(new_streams)
    for name in stream_names:
        positions.setdefault(name, Position(timestamp=0, skip=0))


def _print_waiting_status(state: list[str], queue_position: Optional[str]) -> None:
    """Prints a progress indicator while no log stream exists yet."""
    if queue_position is not None and state[1] == "QUEUED":
        print(f"Job queue position: {queue_position}", end="\n", flush=True)
    elif state[0] != state[1] and state[1] == "RUNNING" and queue_position is not None:
        print("Running:", end="\n", flush=True)
    else:
        print(".", end="", flush=True)