Source code for src.experiments.saasgd

"""
Test SA-ASGD against standard SGD on linear regression with synthetic overparameterized data.

From the base repository directory:  
`python -m src.experiments.saasgd`

The implemented SA-ASGD algorithm is taken from: "Staleness-aware Asynchronous SGD for Distributed Deep Learning" (https://arxiv.org/pdf/1511.05950).
"""
from __future__ import annotations
import time, pathlib, pickle, random, sys
import numpy as np
import torch
import torch.nn as nn
import logging
import torch.nn.functional as F
import matplotlib.pyplot as plt
from numpy.linalg import svd
from scipy.stats import ttest_rel
import os
import scipy.stats as stats_mod

from .. import *

[docs] def main(): """ Main function for running the experiments comparing ASAP-SGD algorithm and standard SGD. In the script a comparative experiment between ASAP-SGD and standard SGD is performed by: 1. Generate a synthetic linear regression datasets (with specified overparameterization). 2. Trains models using both SGD and ASAP-SGD across multiple random seeds (200 by default). 3. Evaluates trained models across multiple metrics: test loss, weight properties (L2 norm, sparsity, kurtosis) and convergence statistics. 4. Compares results via statistical tests (paired t-tests) 5. Visualize results: loss distributions, staleness patterns, weight characteristics. Strong reproducibility is ensured by using a fixed master seed. Checkpoints are created to save losses, weight properties, and staleness distributions for both ASAP-SGD and SGD training. """ # AMOUNT OF SEEDS YOU WANT TO COMPUTE NOW RUNS_REGULAR_SGD = 200 # Set always min to 1 for both methods (if want to retrieve/use the stored values) RUNS_ASGD = 1 # USER WILL HAVE TO CHOOSE THE AMOUNT OF OVERPARAMETRIZATION args = parse_args() # every run uses n_samples=100 n_samples = 100 # compute features = level% of samples n_features = int(n_samples * args.overparam / 100) # base checkpoint tree BASE_CKPT = pathlib.Path(__file__).parent / "ckpt" # e.g. ckpt/overparam_150/SGD and ckpt/overparam_150/ASAP_SGD cfg_dir = BASE_CKPT / f"overparam_{args.overparam}" SGD_DIR = cfg_dir / "SGD" SAASGD_DIR = cfg_dir / "SAASGD" for d in (SGD_DIR, SAASGD_DIR): d.mkdir(parents=True, exist_ok=True) # Set up logging logging.basicConfig(level=logging.INFO) # Fix the master seed so you always get the same “sub‑seeds” random.seed(1234) # Draw 100 integers in [0, 2^8) seeds = [random.randrange(2**8) for _ in range(200)] # If you change the amount of seeds, the first n will still always be the same ! # FILES FOR CHECKPOINTING sgd_losses_f = 'sgd_losses.pkl' asgd_losses_f = 'ASGD_losses.pkl' asgd_stats_f = 'ASGD_stats.pkl' staleness_distr_f = 'ASGD_staleness_distr.pkl' SGD_weight_properties_f = 'sgd_weight_properties.pkl' ASGD_weight_properties_f = 'ASGD_weight_properties.pkl' true_weight_properties_f = 'true_weight_properties.pkl' # For each checkpoint file sgd_losses_file = os.path.join(SGD_DIR, sgd_losses_f) asgd_losses_file = os.path.join(SAASGD_DIR, asgd_losses_f) asgd_stats_file = os.path.join(SAASGD_DIR, asgd_stats_f) staleness_distr_file = os.path.join(SAASGD_DIR, staleness_distr_f) SGD_weight_properties_file = os.path.join(SGD_DIR, SGD_weight_properties_f) ASGD_weight_properties_file = os.path.join(SAASGD_DIR, ASGD_weight_properties_f) true_weight_properties_file = os.path.join(SGD_DIR, true_weight_properties_f) if RUNS_REGULAR_SGD > 0: #RETRIEVE LOSSES losses_file = sgd_losses_file if os.path.exists(losses_file): with open(losses_file, 'rb') as f: SGD_losses = pickle.load(f) logging.info(f"Resuming: {len(SGD_losses)}/{len(seeds)} seeds done") else: SGD_losses = [] logging.info("Starting fresh, no existing losses file found") # RETRIEVE/INIT WEIGHT PROPERTIES if os.path.exists(SGD_weight_properties_file): with open(SGD_weight_properties_file, 'rb') as f: SGD_weight_properties = pickle.load(f) else: if len(SGD_losses) == 0: SGD_weight_properties = [] else: # In the case that you start tracking after some runs already have been computed SGD_weight_properties = [None] * len(SGD_losses) logging.info("Starting fresh on weigth metrics") # RETRIEVE/INIT TRUE WEIGHT PROPERTIES if os.path.exists(true_weight_properties_file): with open(true_weight_properties_file, 'rb') as f: true_weights = pickle.load(f) else: if len(SGD_losses) == 0: true_weights = [] else: # In the case that you start tracking after some runs already have been computed true_weights = [None] * len(SGD_losses) logging.info("Starting fresh on weigth metrics") # Pick up where you left off start_idx = len(SGD_losses) for idx in range(start_idx, len(seeds)): seed = seeds[idx] if RUNS_REGULAR_SGD == 0: print("Performed the specified amount of runs for regular SGD") break RUNS_REGULAR_SGD = RUNS_REGULAR_SGD - 1 # full splits => Always the same when using the same seed X_tr_lin, y_tr_lin, X_val_lin, y_val_lin, X_te_lin, y_te_lin, true_w = load_linear_data(n_samples=n_samples, n_features=n_features, noise=0.0,val_size=0.01,test_size=0.2, random_state= seed) X_comb = np.vstack([X_tr_lin, X_val_lin]) y_comb = np.concatenate([y_tr_lin, y_val_lin]) n_trainval = X_comb.shape[0] batch_size = max(1, int(0.1 * n_trainval)) # 3) Compute 95% of max stable step size η₉₅ _, S_comb, _ = svd(X_comb, full_matrices=False) eta_max = 2.0 / (S_comb[0]**2) eta_95 = 0.95 * eta_max start = time.perf_counter() sgd_model = sgd_training(X_comb, y_comb, num_epochs = 10000, criterion = nn.MSELoss(), batch_size = batch_size, lr = eta_95, tol=1e-8) end = time.perf_counter() sgd_time = end-start # Compute weight metrics on true weight vector true_m_gd = {'l2':l2_norm(true_w),'sparsity':sparsity_ratio(true_w),'kurtosis':weight_kurtosis(true_w)} true_weights.append(true_m_gd) # collect each parameter, detach from graph, move to CPU numpy, flatten weight_vectors = [] for param in sgd_model.parameters(): weight_vectors.append(param.detach().cpu().numpy().reshape(-1)) w = np.concatenate(weight_vectors) # Compute your three metrics m_gd = {'l2':l2_norm(w),'sparsity':sparsity_ratio(w),'kurtosis':weight_kurtosis(w)} SGD_weight_properties.append(m_gd) SGD_loss = evaluate_model("SGD", sgd_model, X_te_lin, y_te_lin) SGD_losses.append(SGD_loss) print("Time Comparison for run:" + str(idx) + f":SGD {sgd_time:2f} sec") # SAVE LOSSES with open(sgd_losses_file, 'wb') as f: pickle.dump(SGD_losses, f) with open(sgd_losses_file, 'rb') as f: SGD_losses = pickle.load(f) print("Retrieved regular SGD losses") avg_SGD_loss = sum(SGD_losses)/len(SGD_losses) print("Average SGD loss =" + str(avg_SGD_loss)) # SAVE WEIGHT METRICS/PROPERTIES with open(SGD_weight_properties_file, 'wb') as f: pickle.dump(SGD_weight_properties, f) # SAVE TRUE WEIGHT METRICS/PROPERTIES with open(true_weight_properties_file, 'wb') as f: pickle.dump(true_weights, f) if RUNS_ASGD > 0: # INIT/RETRIEVE LOSSES losses_file = asgd_losses_file if os.path.exists(losses_file): with open(losses_file, 'rb') as f: ASGD_losses = pickle.load(f) logging.info(f"Resuming: {len(ASGD_losses)}/{len(seeds)} seeds done") else: ASGD_losses = [] logging.info("Starting fresh, no existing losses file found") # INIT/RETRIEVE WORKER STATS if os.path.exists(asgd_stats_file): with open(asgd_stats_file, 'rb') as f: ASGD_stats = pickle.load(f) logging.info(f"Resuming stats: {len(ASGD_stats)}/{len(seeds)} done") else: ASGD_stats = [] logging.info("Starting fresh on stats") #INIT/RETRIEVE STALENESS DISTR if os.path.exists(staleness_distr_file): with open(staleness_distr_file, 'rb') as f: ASGD_staleness_distributions = pickle.load(f) logging.info(f"Resuming staleness distr: {len(ASGD_staleness_distributions)}/{len(seeds)} done") else: if len(ASGD_losses) == 0: ASGD_staleness_distributions = [] else: # In the case that you start tracking these distributions after some runs already have been computed ASGD_staleness_distributions = [None] * len(ASGD_losses) logging.info("Starting fresh on staleness distr") # INIT/RETRIEVE WEIGHT METRICS/PROPERTIES if os.path.exists(ASGD_weight_properties_file): with open(ASGD_weight_properties_file, 'rb') as f: ASGD_weight_properties = pickle.load(f) logging.info(f"Resuming weight properties: {len(ASGD_weight_properties)}/{len(seeds)} done") else: if len(ASGD_losses) == 0: ASGD_weight_properties = [] else: # In the case that you start tracking these distributions after some runs already have been computed ASGD_weight_properties = [None] * len(ASGD_losses) logging.info("Starting fresh on ASGD weight properties") # Pick up where you left off start_idx = len(ASGD_losses) for idx in range(start_idx, len(seeds)): seed = seeds[idx] if RUNS_ASGD == 0: print("Performed the specified amount of runs for ASGD") break RUNS_ASGD = RUNS_ASGD - 1 # full splits => Always the same when using the same seed X_tr_lin, y_tr_lin, X_val_lin, y_val_lin, X_te_lin, y_te_lin, true_weight = load_linear_data(n_samples= n_samples, n_features= n_features, noise=0.0, val_size=0.01,test_size=0.2, random_state=seed) X_comb = np.vstack([X_tr_lin, X_val_lin]) y_comb = np.concatenate([y_tr_lin, y_val_lin]) n_trainval = X_comb.shape[0] batch_size = max(1, int(0.1 * n_trainval)) # 3) Compute 95% of max stable step size η₉₅ _, S_comb, _ = svd(X_comb, full_matrices=False) eta_max = 2.0 / (S_comb[0]**2) eta_95 = 0.95 * eta_max # Dataset builder function dataset_builder = FullDataLoaderBuilder(X_comb, y_comb) # Model class model = LinearNetModel # Set up the configuration for the SSP training params_ssp = ConfigParameters( num_workers = 10, staleness = 50, lr = eta_95, # DEPENDING ON ALGO THIS HAS TO BE CHANGED ! local_steps = 10000, batch_size = batch_size, device = "cuda" if torch.cuda.is_available() else "cpu", log_level = logging.DEBUG, tol = 1e-8, # The tol for workers is currently set at tol = 1e-8 Amplitude = 1 # The max amplitude IN ASAP ) # Run the SSP training and measure the time taken start = time.perf_counter() asgd_params, dim, stats, staleness_distr = run_training(dataset_builder, model, params_ssp, parameter_server=ParameterServerSAASGD) end = time.perf_counter() asgd_time = end - start ASGD_stats.append(stats) # Compute staleness distribution freq = np.array(staleness_distr) / sum(staleness_distr) # normalize to probabilities ASGD_staleness_distributions.append(freq) # Evaluate the trained model on the test set asgd_model = build_model(asgd_params, model, dim) flat_parts = [] for param in asgd_model.parameters(): flat_parts.append(param.detach().cpu().numpy().reshape(-1)) w_asgd = np.concatenate(flat_parts) # Compute weight metrics/properties m_asgd = {'l2':l2_norm(w_asgd),'sparsity': sparsity_ratio(w_asgd),'kurtosis': weight_kurtosis(w_asgd)} ASGD_weight_properties.append(m_asgd) ASGD_loss = evaluate_model("ASGD", asgd_model, X_te_lin, y_te_lin) ASGD_losses.append(ASGD_loss) print("Time Comparison for run:" + str(idx) + f": ASGD {asgd_time:2f} sec") # SAVE THE LOSSES with open(asgd_losses_file, 'wb') as f: pickle.dump(ASGD_losses, f) with open(asgd_losses_file, 'rb') as f: ASGD_losses = pickle.load(f) print("Retrieved ASGD losses") avg_ASGD_loss = sum(ASGD_losses)/len(ASGD_losses) print("Average ASGD loss =" + str(avg_ASGD_loss)) #SAVE THE WORKER STATS with open(asgd_stats_file, 'wb') as f: pickle.dump(ASGD_stats, f) # SAVE THE STALENESS DISTRIBUTIONS with open(staleness_distr_file, 'wb') as f: pickle.dump(ASGD_staleness_distributions, f) # SAVE THE WEIGHT METRICS/PROPERTIES with open(ASGD_weight_properties_file, 'wb') as f: pickle.dump(ASGD_weight_properties, f) # COMPARE LOSSES FOR THE SEEDS THAT HAVE BEEN USED IN BOTH METHODS UNTIL NOW # Align lengths (in case one list is longer because of incomplete runs) n = min(len(SGD_losses), len(ASGD_losses)) sgd_losses = SGD_losses[:n] asgd_losses = ASGD_losses[:n] # Compute difference: SGD_loss - ASGD_loss diffs = np.array(sgd_losses) - np.array(asgd_losses) # COMPUTE PAIRED T-TEST if n > 1: t_stat, p_value = stats_mod.ttest_rel(sgd_losses, asgd_losses, nan_policy='omit') print(f"Paired t-test over {n} runs:") print(f" t-statistic = {t_stat:.4f}") print(f" p-value = {p_value:.4e}") # Summary statistics mean_diff = np.mean(diffs) median_diff = np.median(diffs) std_diff = np.std(diffs) print(f"Computed over {n} seeds:") print(f"Mean difference (SGD - ASGD): {mean_diff:.4e}") print(f"Median difference: {median_diff:.4e}") print(f"Std of difference: {std_diff:.4e}") # Plot histogram of differences plt.figure() plt.hist(diffs, bins=20, edgecolor='black') plt.axvline(mean_diff, color='red', linestyle='dashed', linewidth=1, label=f"Mean: {mean_diff:.2e}") plt.axvline(median_diff, color='blue', linestyle='dotted', linewidth=1, label=f"Median: {median_diff:.2e}") plt.xlabel("SGD_loss - ASGD_loss") plt.ylabel("Frequency") plt.title("Distribution of Loss Differences (SGD vs. ASGD)") plt.legend() plt.tight_layout() plt.show() # VISUALIZE THE STALENESS DISTRIBUTION OF THE LAST 3 RUNS #–– Extract the last three runs last3 = ASGD_staleness_distributions[-3:] # list of length 3, each shape (S+1,) taus = np.arange(last3[0].shape[0]) # 0 … max staleness fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) for ax, freq, run_idx in zip( axes, last3, range(len(ASGD_staleness_distributions)-3, len(ASGD_staleness_distributions)) ): ax.bar(taus, freq, edgecolor='k', alpha=0.7) ax.set_title(f"Run {run_idx}") ax.set_xlabel("τ") axes[0].set_ylabel("P(τ)") fig.suptitle("Last 3 Runs: Staleness Distributions") plt.tight_layout() plt.show() # COMPARE THE WEIGHT METRICS/PROPERTIES # 1) Make a mask of valid runs M = min(len(SGD_weight_properties), len(ASGD_weight_properties), len(true_weights)) mask = np.array([ (SGD_weight_properties[i] is not None) and (ASGD_weight_properties[i] is not None) and (true_weights[i] is not None) for i in range(M) ]) keys = ('l2','sparsity','kurtosis') # build the arrays of shape (N,3) sgd_arr = np.vstack([ [SGD_weight_properties[i][k] for k in keys] for i in range(M) if mask[i] ]) asgd_arr = np.vstack([ [ASGD_weight_properties[i][k] for k in keys] for i in range(M) if mask[i] ]) true_arr = np.vstack([ [true_weights[i][k] for k in keys] for i in range(M) if mask[i] ]) N = sgd_arr.shape[0] # 3) Paired differences diffs = sgd_arr - asgd_arr # shape (N,3) # Descriptive summaries and confidence intervals for j,key in enumerate(keys): d = diffs[:,j] m, s = d.mean(), d.std(ddof=1) ci_low, ci_high = stats_mod.t.interval( 0.95, df=N-1, loc=m, scale=s/np.sqrt(N)) print(f"{key}: mean diff = {m:.4f}, 95% CI = [{ci_low:.4f}, {ci_high:.4f}]") # Paired hypothesis testing and Effect-size (Cohen’s d for paired data) for j,key in enumerate(keys): d = diffs[:,j] d_mean, d_std = d.mean(), d.std(ddof=1) cohens_d = d_mean / d_std t_stat, p_t = stats_mod.ttest_rel(sgd_arr[:,j], asgd_arr[:,j]) p_w = stats_mod.wilcoxon(d).pvalue print(f"{key}: Cohen’s d = {cohens_d:.3f}") print(f"{key}: paired t-test p = {p_t:.3e}, wilcoxon p = {p_w:.3e}") # Correlation with generalization gap sgd_sel = np.array(SGD_losses[:M])[mask] asgd_sel= np.array(ASGD_losses[:M])[mask] loss_diff = sgd_sel - asgd_sel for j,key in enumerate(keys): r, p = stats_mod.pearsonr(diffs[:,j], loss_diff) print(f"Corr(loss_diff, {key}_diff): r = {r:.3f}, p = {p:.3e}") # Boxplot fig, axes = plt.subplots(1,3,figsize=(12,4)) for j,key in enumerate(keys): axes[j].boxplot([sgd_arr[:,j], asgd_arr[:,j]], labels=['SGD','ASGD']) axes[j].set_title(key) plt.tight_layout(); plt.show() for j,key in enumerate(keys): plt.figure() plt.scatter(sgd_arr[:,j], asgd_arr[:,j], alpha=0.7) lim = max(sgd_arr[:,j].max(), asgd_arr[:,j].max()) plt.plot([0,lim],[0,lim], linestyle='--') plt.xlabel('SGD'); plt.ylabel('ASGD'); plt.title(key) plt.tight_layout(); plt.show() delta_sgd = np.abs(sgd_arr - true_arr) # how far each run’s SGD metrics sit from its ground truth delta_asgd = np.abs(asgd_arr - true_arr) # — now compute distance-to-teacher for each method — # average signed difference in *distance* to teacher: for j,key in enumerate(keys): # negative means ASGD is *closer* (on average) to the teacher than SGD mean_dist_diff = delta_sgd[:,j].mean() - delta_asgd[:,j].mean() print(f"{key}: mean(|SGD-teacher| - |ASGD-teacher|) = {mean_dist_diff:.4f}") # you can also do a paired test on these distances: for j,key in enumerate(keys): d = delta_sgd[:,j] - delta_asgd[:,j] t_stat, pval = stats_mod.ttest_rel(delta_sgd[:,j], delta_asgd[:,j]) print(f"{key}: paired t-test on dist-to-teacher p = {pval:.3e}") # — and finally, overlay the teacher’s *average* metric in your boxplots — teacher_means = true_arr.mean(axis=0) fig, axes = plt.subplots(1,3,figsize=(12,4)) for j,key in enumerate(keys): axes[j].boxplot([sgd_arr[:,j], asgd_arr[:,j]], labels=['SGD','ASGD']) # horizontal line at the *average* teacher metric axes[j].axhline(teacher_means[j], color='C2', linestyle='--', label='teacher') axes[j].set_title(key) axes[j].legend() plt.tight_layout() plt.show()
if __name__ == "__main__": try: main() except Exception as e: logging.error(f"An error occurred: {e}") sys.exit(1)