Source code for src.core.parameter_server_saasgd

import time

import torch
from torch import nn

from .parameter_server import ParameterServer, ParameterServerStatus
from ..config import ConfigParameters

class ParameterServerSAASGD(ParameterServer):
    """
    Staleness-aware Async-SGD implementation from the paper:
    "Staleness-aware Async-SGD for Distributed Deep Learning"
    (https://arxiv.org/pdf/1511.05950).
    """

    def __init__(self, model: nn.Module, param: ConfigParameters):
        """
        Initialize the parameter server for staleness-aware Async-SGD.

        :param model: PyTorch model instance.
        :type model: nn.Module
        :param param: Configuration parameters.
        :type param: ConfigParameters
        """
        super().__init__(model, param)
        # n-softsync split: apply one update after collecting
        # c = max(1, num_workers // staleness) gradients.
        self.c = max(1, param.num_workers // param.staleness)
        self.accumulated_grads = [torch.zeros_like(p) for p in self.theta]
        self.accumulated_count = 0
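
    # Worked example (illustrative values, not taken from this project's
    # configuration): with num_workers = 8 and staleness = 2 the split size is
    # c = max(1, 8 // 2) = 4, so the server averages four staleness-scaled
    # gradients before applying a single update and bumping the version.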

    def push(self, wid, w_version: int, grads: list[torch.Tensor]) -> ParameterServerStatus:
        server_start_push = time.perf_counter()
        with self._lock:
            current_ver = self._current_ver.value
            tau = current_ver - w_version
            if tau < 0:
                # A gradient claiming a newer version than the server's is invalid.
                server_end_push = time.perf_counter()
                self.count_time_push += (server_end_push - server_start_push)
                return ParameterServerStatus.REJECTED

            # Record staleness per worker and in the global histogram (if in range)
            self._staleness[wid].append(tau)
            if tau < len(self.hist):
                self.hist[tau] += 1

            # Staleness-dependent learning rate: alpha(tau) = alpha0 / tau, alpha0 if tau == 0
            alpha0 = self.param.lr
            scale = alpha0 / tau if tau != 0 else alpha0
            scaled_grads = [g * scale for g in grads]

            # Accumulate the scaled gradients
            for acc_g, s_g in zip(self.accumulated_grads, scaled_grads):
                acc_g.add_(s_g)
            self.accumulated_count += 1

            # Apply the update once c gradients have been accumulated
            if self.accumulated_count >= self.c:
                # Average the accumulated (already learning-rate-scaled) gradients
                avg_grads = [acc_g / self.accumulated_count for acc_g in self.accumulated_grads]
                # Update parameters: theta <- theta - avg_g
                for theta, avg_g in zip(self.theta, avg_grads):
                    theta.sub_(avg_g)
                # Increment the global parameter version
                self._current_ver.value += 1
                # Reset the accumulators
                self.accumulated_grads = [torch.zeros_like(p) for p in self.theta]
                self.accumulated_count = 0

            server_end_push = time.perf_counter()
            self.count_time_push += (server_end_push - server_start_push)
            return ParameterServerStatus.ACCEPTED
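
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal, standalone
# walk-through of the accumulate-and-apply logic in push(). The learning rate,
# split size, staleness values, and toy gradients below are assumptions chosen
# for demonstration only; they are not taken from this project's configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    alpha0 = 0.1                               # base learning rate (assumed)
    c = 2                                      # apply an update every c pushes (assumed)
    theta = [torch.ones(3)]                    # toy parameter list
    acc = [torch.zeros_like(p) for p in theta]
    count = 0

    # Two simulated pushes as (staleness tau, gradient list) pairs
    for tau, grads in [(0, [torch.full((3,), 0.5)]), (2, [torch.full((3,), 1.0)])]:
        scale = alpha0 / tau if tau != 0 else alpha0   # alpha(tau) = alpha0 / tau
        for a, g in zip(acc, grads):
            a.add_(g * scale)
        count += 1
        if count >= c:                                 # n-softsync update point
            for p, a in zip(theta, acc):
                p.sub_(a / count)                      # theta <- theta - mean(scaled grads)
            acc = [torch.zeros_like(p) for p in theta]
            count = 0

    print(theta[0])  # values near 0.95 after the single softsync update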