from dataclasses import dataclass
from pathlib import Path
from enum import Enum, auto
import torch
import logging
@dataclass
class ConfigParameters:
"""
Configuration for Stale Synchronous Parallel training for Asynchronous SGD (SSP-ASGD).
:param num_workers: Number of worker processes.
:type num_workers: int
:param staleness: Staleness bound allowed for the workers during training. Represents the maximum number of versions a worker can be behind the latest version.
:type staleness: int
:param lr: Learning rate for the model. Represents the step size for updating the model parameters.
:type lr: float
:param local_steps: Number of steps/updates each worker locally computes before pushing gradients to the server.
:type local_steps: int
:param batch_size: Batch size for each training step and the data loader.
:type batch_size: int
:param device: Device to use for training (e.g., "cuda" or "cpu").
:type device: str
:param log_level: Logging verbosity level.
:type log_level: int
"""
    num_workers: int = 5
    staleness: int = 50
    lr: float = 0.01
    local_steps: int = 1
    batch_size: int = 10
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    log_level: int = logging.INFO
    tol: float = 1e-8
    Amplitude: float = 1.0
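

# Illustrative sketch, not part of the module above: one way a parameter
# server could use `staleness` under the Stale Synchronous Parallel protocol.
# `server_version` and `worker_version` are hypothetical names. Per the
# docstring, a worker may proceed only while its parameter version lags the
# latest version by at most the staleness bound; otherwise it must wait to
# be brought up to date.
def worker_may_proceed(server_version: int, worker_version: int,
                       config: ConfigParameters) -> bool:
    """Return True if the worker is within the allowed staleness bound."""
    return server_version - worker_version <= config.staleness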
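

# Minimal usage sketch: construct the configuration, override a few defaults,
# and wire the configured log level into the standard library logger.
if __name__ == "__main__":
    config = ConfigParameters(num_workers=8, lr=0.005, batch_size=32)
    logging.basicConfig(level=config.log_level)
    logging.info(
        "SSP-ASGD config: %d workers, staleness bound %d, lr %.3g, device %s",
        config.num_workers, config.staleness, config.lr, config.device,
    )
    # A worker 40 versions behind is still within the default bound of 50.
    assert worker_may_proceed(server_version=60, worker_version=20, config=config)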