src.core.parameter_server module
- class src.core.parameter_server.ParameterServer(model, param)[source]
Bases:
object
Parameter Server for Stale Synchronous Parallel (SSP) training. The server manages the global model parameters and coordinates gradient updates from multiple workers. Each worker computes gradients locally and sends them to the server via a push operation; the server aggregates the gradients and updates the model parameters. Each worker can retrieve the latest model parameters via a pull operation.
Parameters:
  - model (nn.Module) – PyTorch model instance
  - param (ConfigParameters) – configuration parameters
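To make the push/pull protocol concrete, here is a minimal, self-contained sketch of an SSP-style server. It is purely illustrative and is not the `src.core.parameter_server.ParameterServer` implementation: the class name `ToyParameterServer`, the plain-float parameters, the learning rate, and the string return values are all assumptions made for the example.

```python
class ToyParameterServer:
    """Illustrative SSP server sketch; not the real library implementation."""

    def __init__(self, params, staleness_bound=2):
        self.params = list(params)      # global model parameters (floats here)
        self.version = 0                # global model version ("clock")
        self.staleness_bound = staleness_bound
        self.staleness_log = {}         # wid -> list of observed staleness values

    def pull(self):
        # A worker fetches the latest parameters and the current version.
        return list(self.params), self.version

    def push(self, wid, w_version, grads, lr=0.1):
        # Staleness = how many versions behind the worker's gradients are.
        staleness = self.version - w_version
        self.staleness_log.setdefault(wid, []).append(staleness)
        if staleness > self.staleness_bound:
            return "REJECTED"           # update too stale under the SSP bound
        # Aggregate: apply the gradient step and advance the clock.
        self.params = [p - lr * g for p, g in zip(self.params, grads)]
        self.version += 1
        return "ACCEPTED"


server = ToyParameterServer([1.0, -2.0])
params, v0 = server.pull()
print(server.push(wid=0, w_version=v0, grads=[0.5, -0.5]))  # ACCEPTED
```

The key SSP property shown here is that a push carries the model version the gradients were computed against, so the server can bound how stale an accepted update may be.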
- get_staleness_stats()[source]
Returns a dict of staleness statistics:

{"per_worker": {wid: {"mean": …, "median": …, "std": …, "pct_over_bound": …}, …},
 "combined": {"mean": …, "median": …, "std": …, "pct_over_bound": …}}
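A hedged sketch of how statistics in this shape could be assembled from logged per-worker staleness values; the function name `staleness_stats`, its input format, and the use of population standard deviation are assumptions, not the library's actual implementation.

```python
import statistics


def staleness_stats(staleness_log, bound):
    """Build per-worker and combined staleness summaries.

    staleness_log: hypothetical mapping wid -> list of observed staleness values.
    bound: the SSP staleness bound used for pct_over_bound.
    """
    def summarize(values):
        return {
            "mean": statistics.mean(values),
            "median": statistics.median(values),
            "std": statistics.pstdev(values),
            # Percentage of observations exceeding the staleness bound.
            "pct_over_bound": 100.0 * sum(v > bound for v in values) / len(values),
        }

    per_worker = {wid: summarize(vals) for wid, vals in staleness_log.items()}
    combined = summarize([v for vals in staleness_log.values() for v in vals])
    return {"per_worker": per_worker, "combined": combined}


stats = staleness_stats({0: [0, 1, 3], 1: [0, 0, 2]}, bound=2)
print(stats["combined"]["median"])  # → 0.5
```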
- push(wid, w_version: int, grads: list[Tensor]) → ParameterServerStatus[source]
Pushes the gradients computed by worker wid against model version w_version to the server, and returns a ParameterServerStatus.