import threading
import multiprocessing as mp
from collections import defaultdict
import numpy as np
import torch
from enum import Enum, auto
from .data_types import ParameterServerStatus
import time
import collections
from typing import Any, Dict, List, Tuple
class ParameterServer:
    """
    Parameter Server for Stale Synchronous Parallel (SSP) training.

    The server holds the global model parameters in shared memory and
    coordinates gradient updates from multiple workers:

    * ``push`` — a worker sends locally computed gradients to the server,
      which aggregates them and updates the parameters. Not implemented in
      this version: every push is rejected.
    * ``pull`` — a worker fetches a copy of the latest parameters together
      with their version number.

    Arguments:

    :param model: PyTorch model whose parameters are served
    :type model: nn.Module
    :param param: Configuration parameters; must expose ``staleness``
        (the staleness bound used for the histogram and the statistics)
    :type param: ConfigParameters
    """

    def __init__(self, model, param):
        self.param = param
        # Detached parameter tensors placed in shared memory so worker
        # processes can read them without serialization.
        self.theta = [p.detach().share_memory_() for p in model.parameters()]
        # Global parameter version, shared across processes.
        self._current_ver = mp.Value("i", 0)
        # Snapshot of the previous parameter values.
        self.prev_theta = [p.clone().detach() for p in self.theta]
        self._lock = threading.Lock()
        # One list of observed staleness values per worker id
        # (populated by the push path; empty until push is implemented).
        self._staleness = defaultdict(list)
        # Histogram of staleness occurrences; valid staleness values are
        # 0..param.staleness, hence length staleness + 1.
        self.hist = [0] * (param.staleness + 1)
        self.total = 0
        # Cumulative wall-clock seconds spent inside push / pull.
        self.count_time_push = 0
        self.count_time_pull = 0

    def pull(self):
        """Return ``(parameters, version)``: cloned copies of the current
        model parameters and the current version number.

        The time spent cloning is accumulated into ``count_time_pull``.

        NOTE(review): unlike :meth:`get_version`, the version is read here
        without holding ``self._lock``, so the clones and the version may be
        slightly inconsistent under concurrent updates — confirm intended.
        """
        start = time.perf_counter()
        result = [p.clone() for p in self.theta], self._current_ver.value
        self.count_time_pull += time.perf_counter() - start
        return result

    def push(self, wid, w_version: int, grads: list[torch.Tensor]) -> "ParameterServerStatus":
        """Accept gradients from worker ``wid`` computed at model version
        ``w_version``.

        Method not implemented: every push is rejected.
        """
        return ParameterServerStatus.REJECTED

    def get_version(self):
        """Return the current version of the model parameters."""
        with self._lock:
            return self._current_ver.value

    def get_time_push(self):
        """Return ``(push_seconds, pull_seconds)`` accumulated so far."""
        return (self.count_time_push, self.count_time_pull)

    def get_hist(self) -> list[int]:
        """Return the raw counts of staleness occurrences for this run.

        The returned list is a copy of length ``param.staleness + 1``.
        """
        return list(self.hist)

    def get_staleness_stats(self):
        """
        Compute per-worker and combined staleness statistics.

        Returns a dict of the form::

            {"per_worker": {wid: {"mean": ..., "median": ..., "std": ...,
                                  "pct_over_bound": ...}, ...},
             "combined":   {"mean": ..., "median": ..., "std": ...,
                            "pct_over_bound": ...}}

        Every value is ``None`` for a worker (or for the combined entry)
        with no recorded samples. ``pct_over_bound`` is the percentage of
        samples strictly greater than ``param.staleness``.
        """
        bound = self.param.staleness
        # Shared template for the "no samples" case.
        empty = {"mean": None, "median": None, "std": None,
                 "pct_over_bound": None}

        def summarize(arr):
            # arr: non-empty float ndarray of staleness samples.
            over = (arr > bound).sum()
            return {
                "mean": float(arr.mean()),
                "median": float(np.median(arr)),
                "std": float(arr.std()),
                "pct_over_bound": float(over) / arr.size * 100.0,
            }

        per_worker = {}
        all_vals = []
        for wid, vals in self._staleness.items():
            arr = np.array(vals, dtype=float)
            if arr.size:
                per_worker[wid] = summarize(arr)
                all_vals.append(arr)
            else:
                per_worker[wid] = dict(empty)
        combined = summarize(np.concatenate(all_vals)) if all_vals else dict(empty)
        return {"per_worker": per_worker, "combined": combined}