# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
[docs]
@dataclass
class CheckpointConfig:
    """Settings describing where checkpoint data is kept.

    Attributes:
        localPath (str): Local path for checkpoint data.
            Defaults to "/opt/jobs/checkpoints".
        s3Uri (str | None): S3 URI for checkpoint data. Defaults to None;
            this class does not substitute a value when it is left unset.
    """

    # Field order is part of the positional constructor signature.
    localPath: str = "/opt/jobs/checkpoints"
    s3Uri: str | None = None
[docs]
@dataclass
class InstanceConfig:
    """Settings for the instance(s) on which the hybrid job runs.

    Attributes:
        instanceType (str): Instance type identifier.
            Defaults to "ml.m5.large".
        volumeSizeInGb (int): Volume size; the field name gives the unit
            as gigabytes. Defaults to 30.
        instanceCount (int): Number of instances. Defaults to 1.
    """

    # Field order is part of the positional constructor signature.
    instanceType: str = "ml.m5.large"
    volumeSizeInGb: int = 30
    instanceCount: int = 1
[docs]
@dataclass
class OutputDataConfig:
    """Settings describing where the hybrid job's output is placed.

    Attributes:
        s3Path (str | None): S3 location for the output. Defaults to None.
        kmsKeyId (str | None): KMS key identifier. Defaults to None.
            NOTE(review): presumably the key used to encrypt the output;
            confirm against the Braket API before relying on it.
    """

    # Field order is part of the positional constructor signature.
    s3Path: str | None = None
    kmsKeyId: str | None = None
[docs]
@dataclass
class StoppingCondition:
    """Limits that determine when the hybrid job is forcefully stopped.

    Attributes:
        maxRuntimeInSeconds (int): Maximum runtime in seconds.
            Defaults to 432000, i.e. five days.
    """

    # 5 days * 24 h * 60 min * 60 s = 432_000 seconds.
    maxRuntimeInSeconds: int = 432_000
[docs]
@dataclass
class DeviceConfig:
    """Configuration holding the device setting.

    Attributes:
        device (str): Device identifier. Required; there is no default.
            NOTE(review): presumably a device ARN or similar identifier
            string; confirm against the callers that build this object.
    """

    device: str
[docs]
class S3DataSourceConfig:
    """Describes an input data source located on S3.

    Attributes:
        config (dict[str, dict]): config passed to the Braket API. It always
            holds a "dataSource" entry and gains a "contentType" entry only
            when a content type was supplied.
    """

    def __init__(self, s3_data: str, content_type: str | None = None):
        """Build the input-data definition used by a Braket Hybrid job.

        Args:
            s3_data (str): Defines the location of s3 data to train on.
            content_type (str | None): MIME type of the input data
                (default: None). Any value other than None, including an
                empty string, is stored under "contentType".
        """
        s3_source = {"s3Uri": s3_data}
        settings = {"dataSource": {"s3DataSource": s3_source}}

        # Only an explicit None omits the key; falsy strings are kept.
        if content_type is not None:
            settings["contentType"] = content_type

        self.config = settings