Source code for src.data.poly

"""Non-linear “poly-varied” synthetic regression builder."""

from __future__ import annotations
import numpy as np
from .base import AbstractDataBuilder

class PolyVariedBuilder(AbstractDataBuilder):
    """Builder for a synthetic regression task with per-feature polynomial terms.

    The target is generated as ``y = sum_i w[i] * X[:, i] ** degrees[i]`` plus
    optional Gaussian noise, where each feature gets its own integer exponent
    drawn from ``[1, max_degree]``.

    Args:
        num_workers: Forwarded to ``AbstractDataBuilder.__init__``; also used
            as the stride when sharding rows across workers.
        n_samples: Number of rows to generate.
        n_features: Number of input columns.
        max_degree: Largest exponent (inclusive) a feature may be raised to.
        noise: Scale applied to standard-normal noise added to the target.
        shard: If true, ``_slice_for_worker`` hands each worker a strided
            subset of the rows instead of the full dataset.
        seed: Forwarded to ``AbstractDataBuilder.__init__`` and later read
            back as ``self.seed`` to seed the generator.
            NOTE(review): assumes the base class stores ``seed`` and
            ``num_workers`` as attributes of the same name -- confirm.
    """

    def __init__(
        self,
        num_workers: int,
        n_samples: int = 100,
        n_features: int = 110,
        max_degree: int = 4,
        noise: float = 0.0,
        shard: bool = False,
        seed: int | None = None,
    ):
        super().__init__(num_workers, seed)
        self.n_samples = n_samples
        self.n_features = n_features
        self.max_degree = max_degree
        self.noise = noise
        self.shard = shard
        # The dataset is generated eagerly, once, at construction time.
        self._X, self._y, self.degrees = self._build_dataset()

    @staticmethod
    def _create_poly_varied_dataset(
        n_samples: int,
        n_features: int,
        max_degree: int,
        noise: float,
        rng: np.random.RandomState,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Generate features, targets and the per-feature exponents.

        This takes no ``self``; it must be a static method, otherwise the
        bound call in ``_build_dataset`` would pass the instance as
        ``n_samples`` and fail with a ``TypeError``.

        Args:
            n_samples: Number of rows.
            n_features: Number of columns.
            max_degree: Largest exponent (inclusive).
            noise: Multiplier for the standard-normal noise added to ``y``.
            rng: Random generator; all draws come from it.

        Returns:
            ``(X, y, degrees)`` where ``X`` is float32 of shape
            ``(n_samples, n_features)``, ``y`` is float32 of shape
            ``(n_samples,)`` and ``degrees`` is an integer array of shape
            ``(n_features,)``.
        """
        # The order of the draws below determines the output for a given
        # seed; do not reorder them.
        X = rng.uniform(-3, 3, size=(n_samples, n_features))
        w = rng.randn(n_features)
        # ``randint`` excludes the upper bound, hence ``max_degree + 1``.
        degrees = rng.randint(1, max_degree + 1, size=n_features)

        # Raise each column to its own exponent.
        X_pow = np.stack(
            [X[:, i] ** d for i, d in enumerate(degrees)],
            axis=1,
        )
        y = X_pow @ w + noise * rng.randn(n_samples)
        return X.astype(np.float32), y.astype(np.float32), degrees

    def _build_dataset(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Create the dataset from this builder's settings and ``self.seed``."""
        rng = np.random.RandomState(self.seed)
        return self._create_poly_varied_dataset(
            self.n_samples,
            self.n_features,
            self.max_degree,
            self.noise,
            rng,
        )

    def _slice_for_worker(self, X, y, worker_id):
        """Return the portion of ``(X, y)`` assigned to ``worker_id``.

        Without sharding every worker receives the full arrays. With sharding
        worker ``k`` receives rows ``k, k + num_workers, k + 2 * num_workers, ...``.
        """
        if not self.shard:
            return X, y
        step = self.num_workers
        return X[worker_id::step], y[worker_id::step]