MMM.approximate_fit

MMM.approximate_fit(X, y=None, progressbar=None, random_seed=None, *, fit_kwargs=None, sample_kwargs=None)

Fit the model with variational inference (VI) and return an xr.DataTree.

This runs variational inference with pymc.fit, then draws approximate posterior samples from the fitted approximation with Approximation.sample. The result is an xr.DataTree with the same structure as the one returned by .fit, so it can be used with the rest of the API.

Parameters:
X : array_like | array, shape (n_obs, n_features)

The training input samples. Array-like if scikit-learn is available, otherwise an array.

y : array_like | array, shape (n_obs,)

The target values (real numbers). Array-like if scikit-learn is available, otherwise an array.

progressbar : bool, optional

Whether to display a progress bar during fitting and sampling. Defaults to True.

random_seed : Optional[RandomState]

Random seed used to initialize the stochastic procedures, for reproducibility.

fit_kwargs : dict, optional

Extra keyword arguments forwarded to pymc.fit (e.g., {"n": 10_000, "method": "advi"}).

sample_kwargs : dict, optional

Extra keyword arguments forwarded to Approximation.sample (e.g., {"draws": 1_000}).

Returns:
xr.DataTree

DataTree containing the posterior samples drawn from the fitted variational approximation.