src.core.parameter_server module

class src.core.parameter_server.ParameterServer(model, param)[source]

Bases: object

Parameter server for Stale Synchronous Parallel (SSP) training. The server holds the global model parameters and coordinates gradient updates from multiple workers. Each worker computes gradients locally and sends them to the server with a push operation; the server aggregates the gradients and updates the model parameters. A worker retrieves the latest model parameters with a pull operation.

Parameters:
        model (nn.Module) – PyTorch model instance
        param (ConfigParameters) – Configuration parameters
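The push/pull protocol described above can be sketched with a minimal in-memory server. All names below are illustrative (plain lists stand in for tensors, and the update rule is plain SGD), not the actual implementation:

```python
from dataclasses import dataclass


@dataclass
class MiniParameterServer:
    """Illustrative SSP parameter server holding a flat list of parameters."""
    params: list
    lr: float = 0.1
    version: int = 0  # incremented on every applied update

    def pull(self):
        # Workers fetch the current version together with the parameters.
        return self.version, list(self.params)

    def push(self, wid, w_version, grads):
        # Staleness = number of updates applied since this worker last pulled.
        staleness = self.version - w_version
        # Apply the gradient step (plain SGD for illustration).
        self.params = [p - self.lr * g for p, g in zip(self.params, grads)]
        self.version += 1
        return staleness


server = MiniParameterServer(params=[1.0, -2.0])
v, w = server.pull()                               # worker 0 pulls version 0
server.push(wid=1, w_version=0, grads=[0.0, 0.0])  # another worker updates first
s = server.push(wid=0, w_version=v, grads=[0.5, -0.5])
print(s)  # staleness of worker 0's update: 1
```

The version counter is what lets the server quantify staleness: a push carries the version the worker pulled, and the difference to the current version is the number of updates the worker missed.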

get_hist() → list[int][source]

Return the raw counts of staleness occurrences for this run.
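The raw counts can be pictured as a list indexed by staleness value; a sketch of that bookkeeping (the observed values here are made up, and the real implementation may count differently):

```python
from collections import Counter

# Staleness observed at each push during a run (illustrative values).
observed = [0, 1, 0, 2, 1, 0]

counts = Counter(observed)
# Raw counts indexed by staleness value: hist[s] = number of pushes
# that arrived with staleness s.
hist = [counts.get(s, 0) for s in range(max(observed) + 1)]
print(hist)  # [3, 2, 1]
```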

get_staleness_stats()[source]

Returns a dict of staleness statistics:

{
    "per_worker": {
        wid: {"mean": ..., "median": ..., "std": ..., "pct_over_bound": ...},
        ...
    },
    "combined": {"mean": ..., "median": ..., "std": ..., "pct_over_bound": ...},
}
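The shape of the returned dict can be illustrated by computing the same statistics with the standard library. The staleness samples, the bound, and the helper below are made up for the example; the actual method may define "std" or "pct_over_bound" slightly differently:

```python
import statistics


def staleness_stats(samples, bound):
    # Fraction (in percent) of pushes whose staleness exceeded the bound.
    pct_over = 100.0 * sum(s > bound for s in samples) / len(samples)
    return {
        "mean": statistics.mean(samples),
        "median": statistics.median(samples),
        "std": statistics.pstdev(samples),
        "pct_over_bound": pct_over,
    }


# Hypothetical per-worker staleness samples keyed by worker id.
per_worker = {0: [0, 1, 2], 1: [1, 3]}
result = {
    "per_worker": {wid: staleness_stats(s, bound=2) for wid, s in per_worker.items()},
    "combined": staleness_stats(
        [s for v in per_worker.values() for s in v], bound=2
    ),
}
print(result["combined"]["median"])  # 1
```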

get_time_push()[source]

Return the time spent in push and pull operations.

get_version()[source]

Return the current version of the model parameters.

pull()[source]

Return the latest model parameters to the calling worker.

push(wid, w_version: int, grads: list[Tensor]) → ParameterServerStatus[source]

Apply the gradients pushed by worker wid, computed against model version w_version, and return a ParameterServerStatus.
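A worker's use of these two calls might look as follows. The server stand-in, the status enum, and the bound check are placeholders: the actual variants of ParameterServerStatus are not shown on this page, so the sketch only assumes that a push can be accepted or rejected when the worker's version is too stale:

```python
from enum import Enum, auto


class Status(Enum):
    # Placeholder for ParameterServerStatus; actual variants may differ.
    OK = auto()
    TOO_STALE = auto()


class FakeServer:
    """Stand-in server so the worker-side calls below are runnable."""

    def __init__(self, bound=2):
        self.version = 0
        self.bound = bound  # hypothetical SSP staleness bound

    def pull(self):
        # Version plus parameters (a single dummy parameter here).
        return self.version, [0.0]

    def push(self, wid, w_version, grads):
        if self.version - w_version > self.bound:
            return Status.TOO_STALE  # worker should re-pull before retrying
        self.version += 1
        return Status.OK


server = FakeServer()
version, params = server.pull()          # fetch current parameters + version
# ... compute gradients locally against `params` ...
status = server.push(wid=0, w_version=version, grads=[0.1])
print(status)  # Status.OK
```

Checking the returned status matters under SSP: a rejected push tells the worker its local view fell outside the staleness bound and it must pull fresh parameters before continuing.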