src.core.train_runner module
- class src.core.train_runner.PSManager(address=None, authkey=None, serializer='pickle', ctx=None, *, shutdown_timeout=1.0)
Bases: BaseManager
- ParameterServer(*args, **kwds)
- get_hist(*args, **kwds)
- get_staleness_stats(*args, **kwds)
- get_time_push(*args, **kwds)
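The `(*args, **kwds)` signatures listed above are what `multiprocessing.managers.BaseManager.register()` produces: each registered type id becomes a method on the manager subclass. The following is a minimal sketch of that pattern, not the project's actual registration code. `ToyParameterServer`, `PSManagerSketch`, and `demo` are hypothetical names introduced only for illustration.

```python
from multiprocessing.managers import BaseManager


class ToyParameterServer:
    """Hypothetical stand-in for the real parameter server class."""

    def __init__(self, value: float = 0.0):
        self.value = value
        self.hist = [value]

    def push(self, delta: float) -> None:
        # Apply an update and record the new value.
        self.value += delta
        self.hist.append(self.value)

    def get_hist(self) -> list:
        # Return a copy so callers cannot mutate the server's state.
        return list(self.hist)


class PSManagerSketch(BaseManager):
    """Hypothetical manager subclass; type ids are attached via register()."""


# register() adds a method named after the type id to the manager class.
# That generated method accepts (*args, **kwds) and returns a proxy.
PSManagerSketch.register("ParameterServer", ToyParameterServer)


def demo() -> list:
    # Entering the context starts the manager's server process.
    with PSManagerSketch() as manager:
        ps = manager.ParameterServer(0.0)  # proxy to the shared object
        ps.push(0.5)                       # executed in the server process
        return ps.get_hist()
```

Because `demo()` starts a server process, it should be called from under an `if __name__ == "__main__":` guard on platforms that use the spawn start method.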
- src.core.train_runner.run_training(dataset_builder: Callable[[int, int, int], Tuple[DataLoader, int]], model: Callable[[int], Module], param: ConfigParameters = ConfigParameters(num_workers=5, staleness=50, lr=0.01, local_steps=1, batch_size=10, device='cpu', log_level=20, tol=1e-08, Amplitude=1), parameter_server: Callable = <class 'src.core.parameter_server.ParameterServer'>, asgd_worker: Callable = <function worker>) → list[Tensor]
Helper function that runs Stale Synchronous Parallel (SSP) training with the provided dataset builder, model, and configuration parameters.
- Parameters:
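dataset_builder (Callable[[int, int, int], Tuple[DataLoader, int]]) – Function used to build the dataset. It takes three integers and returns a DataLoader together with an integer.
model (Callable[[int], Module]) – Model class (or factory) to be trained. It is called with one integer and returns a torch.nn.Module.
param (ConfigParameters) – SSP configuration parameters.
parameter_server (Callable) – Parameter server class to use. Defaults to src.core.parameter_server.ParameterServer.
asgd_worker (Callable) – Worker function to use. Defaults to worker.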
- Returns:
The final model parameters after training.
- Return type:
list[torch.Tensor]
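To illustrate what a `staleness` setting means in SSP, here is a minimal sketch of a staleness gate using plain integer clocks. It assumes the convention that a worker may start another step only while its lead over the slowest worker is smaller than `staleness`; the rule implemented in this package may differ. `may_start_step` and `advance_until_blocked` are hypothetical helpers, not part of `src.core.train_runner`.

```python
def may_start_step(worker_clock: int, clocks: list[int], staleness: int) -> bool:
    """Return True if the worker's lead over the slowest worker is below the bound."""
    return worker_clock - min(clocks) < staleness


def advance_until_blocked(clocks: list[int], worker: int, staleness: int,
                          max_steps: int = 1000) -> int:
    """Advance one worker's clock until the gate blocks it; return steps taken.

    max_steps caps the loop, since a lone worker is never blocked.
    """
    steps = 0
    for _ in range(max_steps):
        if not may_start_step(clocks[worker], clocks, staleness):
            break
        clocks[worker] += 1
        steps += 1
    return steps


clocks = [0, 0, 0]  # one clock per worker
fast_steps = advance_until_blocked(clocks, worker=0, staleness=2)
# Worker 0 is now two steps ahead of the others and must wait for them.
```

Under this convention the fastest worker is never more than `staleness` steps ahead. Once the slower workers advance their clocks, the blocked worker can continue.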