Source code for src.data.linear

"""Synthetic over-parameterised linear regression dataset builder."""

from __future__ import annotations
import numpy as np
from .base import AbstractDataBuilder

class LinearRegressionBuilder(AbstractDataBuilder):
    """Build a synthetic linear regression dataset ``y = X @ w + noise``.

    The dataset is generated once, at construction time, and cached on the
    instance as ``_X`` / ``_y``. With the default sizes there are more
    features (110) than samples (100), i.e. the problem is
    over-parameterised.

    Args:
        num_workers: Number of workers the data may be split across;
            forwarded to the base class.
        n_samples: Number of rows to generate.
        n_features: Number of columns (regression inputs) to generate.
        noise: Scale applied to standard-normal noise added to the targets;
            ``0.0`` yields noiseless targets.
        shard: If true, ``_slice_for_worker`` hands each worker a disjoint
            strided subset of rows; otherwise every worker gets all rows.
        seed: Seed for the random generator; forwarded to the base class.
    """

    def __init__(
        self,
        num_workers: int,
        n_samples: int = 100,
        n_features: int = 110,
        noise: float = 0.0,
        shard: bool = False,
        seed: int | None = None,
    ):
        # NOTE(review): the base class is assumed to store ``num_workers``
        # and ``seed`` as attributes of the same name (both are read back
        # below) -- confirm against AbstractDataBuilder.
        super().__init__(num_workers, seed)
        self.n_samples = n_samples
        self.n_features = n_features
        self.noise = noise
        self.shard = shard
        self._X, self._y = self._build_dataset()

    @staticmethod
    def _create_linear_dataset(
        n_samples: int,
        n_features: int,
        noise: float,
        rng: np.random.RandomState,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Draw a random design matrix and targets from a random linear model.

        Declared as a static method because it takes no ``self``; without
        the decorator, calling it through an instance would pass the
        instance as ``n_samples`` and fail with a ``TypeError``.

        Args:
            n_samples: Number of rows in ``X`` and entries in ``y``.
            n_features: Number of columns in ``X``.
            noise: Multiplier for the standard-normal noise added to ``y``.
            rng: Random generator used for every draw.

        Returns:
            ``(X, y)`` as float32 arrays of shape ``(n_samples, n_features)``
            and ``(n_samples,)``.
        """
        # The draw order (X, then w, then noise) determines the output for a
        # given seed; keep it stable.
        X = rng.uniform(-5, 5, size=(n_samples, n_features)).astype(np.float32)
        w = rng.randn(n_features)
        y = (X @ w + noise * rng.randn(n_samples)).astype(np.float32)
        return X, y

    def _build_dataset(self) -> tuple[np.ndarray, np.ndarray]:
        """Generate the full dataset from a generator seeded with ``self.seed``."""
        rng = np.random.RandomState(self.seed)
        return self._create_linear_dataset(
            self.n_samples, self.n_features, self.noise, rng
        )

    def _slice_for_worker(self, X, y, worker_id):
        """Return the portion of ``(X, y)`` assigned to ``worker_id``.

        Without sharding every worker receives the arrays unchanged. With
        sharding, worker ``i`` receives rows ``i, i + num_workers, ...``.
        """
        if not self.shard:
            return X, y
        step = self.num_workers
        return X[worker_id::step], y[worker_id::step]