API Reference

GAM

Model specification and fit orchestration. fit() returns a GAMResults frozen dataclass containing all fitted state.

from jaxgam import GAM, GAMResults

model = GAM("y ~ s(x)", family="gaussian")
results = model.fit(data)
results.predict(newdata)
results.summary()

GAM(formula: str, family: str | ExponentialFamily = 'gaussian', method: str = 'REML', sp: np.ndarray | list | None = None, **kwargs)

Generalized Additive Model specification.

This class holds model specification parameters and orchestrates the fit pipeline. Calling fit() returns a GAMResults frozen dataclass containing all fitted state.

Parameters:

Name Type Description Default
formula str

Model formula in R-style Wilkinson notation, e.g. "y ~ s(x)".

required
family str or ExponentialFamily

Distribution family. One of 'gaussian', 'binomial', 'poisson', 'gamma', or an ExponentialFamily instance.

'gaussian'
method str

Smoothing parameter estimation method: 'REML' or 'ML'.

'REML'
sp ndarray or list

Fixed smoothing parameters. If provided, skips Newton optimization.

None
device str

Target device: 'cpu', 'gpu', or None (auto-detect). GPU requires jax[cuda12] (NVIDIA) or jax-metal (Apple).

required
**kwargs

Additional arguments. Supported scope guards: backend, optimizer, select, gamma, knots.

{}

Examples:

>>> model = GAM("y ~ s(x)", family="gaussian")
>>> results = model.fit(data)
>>> results.predict(newdata)
array([...])

Design doc reference: docs/refactor_gam_api/design.md §3.3

fit(data: pd.DataFrame | dict, weights: np.ndarray | None = None, offset: np.ndarray | None = None) -> GAMResults

Fit the GAM to data.

Parameters:

Name Type Description Default
data DataFrame or dict

Data frame containing the variables in the formula.

required
weights ndarray

Prior weights, shape (n,).

None
offset ndarray

Offset vector, shape (n,).

None

Returns:

Type Description
GAMResults

Frozen dataclass containing all fitted state.

Design doc reference: docs/refactor_gam_api/design.md §3.3

GAMResults

Immutable results object returned by GAM.fit(). All post-estimation methods (prediction, summary, plotting) live here.

from jaxgam import GAMResults

GAMResults(coefficients: np.ndarray, fitted_values: np.ndarray, linear_predictor: np.ndarray, Vp: np.ndarray, scale: float, edf: np.ndarray, edf1: np.ndarray, edf_total: float, deviance: float, null_deviance: float, smoothing_params: np.ndarray, converged: bool, n_iter: int, score: float, family: ExponentialFamily, setup: ModelSetup, coef_map: CoefficientMap, smooth_info: tuple[SmoothInfo, ...], term_names: tuple[str, ...], X: np.ndarray, y: np.ndarray, weights: np.ndarray, offset: np.ndarray | None, theta: float | None, n: int, execution_path: str, lambda_strategy: str, formula: str, method: str, training_data: dict[str, np.ndarray]) dataclass

Results from a fitted GAM.

All attributes are read-only (frozen dataclass). This object is the primary interface for post-estimation: prediction, inference, and visualization.
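As a minimal sketch of what "frozen" means in practice, the snippet below uses a small stand-in dataclass (MiniResults is hypothetical, not the real GAMResults). It shows that attribute assignment raises, and that the standard-library dataclasses.replace builds a modified copy instead.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class MiniResults:
    """Hypothetical stand-in for a frozen results object."""

    deviance: float
    converged: bool


res = MiniResults(deviance=12.5, converged=True)

try:
    res.deviance = 0.0  # frozen dataclasses reject attribute assignment
    mutated = True
except FrozenInstanceError:
    mutated = False

# replace() returns a new instance and leaves the original untouched
res2 = replace(res, deviance=10.0)
```

Whether dataclasses.replace is a supported workflow for GAMResults is not stated in this reference; the sketch only illustrates the general behaviour of frozen dataclasses.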

Design doc reference: docs/refactor_gam_api/design.md §3.4

predict(newdata: pd.DataFrame | dict | None = None, pred_type: str = 'response', se_fit: bool = False, offset: np.ndarray | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]

Predict from a fitted GAM.

Parameters:

Name Type Description Default
newdata DataFrame or dict

New data for prediction. If None, uses the training data.

None
pred_type str

Type of prediction: 'response' or 'link'.

'response'
se_fit bool

Whether to return standard errors.

False
offset array-like

Offset for new data predictions.

None

Returns:

Type Description
ndarray or tuple[ndarray, ndarray]

Predictions, or (predictions, standard_errors) if se_fit=True.
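One common use of the standard errors is an approximate confidence interval built on the link scale and then mapped to the response scale. The sketch below assumes a logit link and uses made-up numbers in place of what predict(pred_type="link", se_fit=True) would return; response_interval and inv_logit are illustrative helpers, not part of jaxgam.

```python
import math


def inv_logit(eta):
    """Inverse of the logit link."""
    return 1.0 / (1.0 + math.exp(-eta))


def response_interval(eta, se, z=1.96):
    """Return (lower, estimate, upper) on the response scale.

    The interval is formed on the link scale and then transformed,
    which keeps the bounds inside (0, 1).
    """
    return inv_logit(eta - z * se), inv_logit(eta), inv_logit(eta + z * se)


# Made-up link-scale predictions and standard errors (assumed inputs)
eta_hat = [0.0, 1.2]
se_hat = [0.3, 0.5]
intervals = [response_interval(e, s) for e, s in zip(eta_hat, se_hat)]
```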

predict_matrix(newdata: pd.DataFrame | dict) -> np.ndarray

Build constrained prediction matrix for new data.

Equivalent to R's predict.gam(type="lpmatrix").

Parameters:

Name Type Description Default
newdata DataFrame or dict

New data for prediction.

required

Returns:

Type Description
ndarray, shape (n_new, total_coefs)

Constrained prediction matrix.
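Assuming the matrix follows the usual lpmatrix convention from mgcv, link-scale predictions are the matrix times the coefficient vector, and pointwise standard errors come from the diagonal of Xp Vp Xp^T. The sketch below illustrates this with plain lists and small made-up inputs; link_predictions is a hypothetical helper, and real code would use NumPy arrays.

```python
import math


def matvec(A, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def link_predictions(Xp, beta, Vp):
    """Return (eta, se): eta = Xp @ beta, se_i = sqrt(x_i^T Vp x_i)."""
    eta = matvec(Xp, beta)
    se = []
    for row in Xp:
        vp_row = matvec(Vp, row)  # Vp @ x_i
        se.append(math.sqrt(sum(a * b for a, b in zip(row, vp_row))))
    return eta, se


# Made-up stand-ins for predict_matrix(newdata), coefficients, and Vp
Xp = [[1.0, 0.5], [1.0, -1.0]]
beta = [2.0, 4.0]
Vp = [[0.04, 0.0], [0.0, 0.01]]
eta, se = link_predictions(Xp, beta, Vp)
```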

summary() -> GAMSummary

Print and return summary of a fitted GAM.

Computes parametric coefficient significance (z/t tests), smooth term significance (Wood 2013 testStat), and model-level statistics (R-squared, deviance explained, scale estimate).

Returns:

Type Description
GAMSummary

Summary object with parametric and smooth term tables. The summary is also printed to stdout.
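For orientation, deviance explained is conventionally defined (as in mgcv) as one minus the ratio of model deviance to null deviance. The helper below is a sketch of that convention, assuming jaxgam reports the same quantity; deviance_explained is not a jaxgam function.

```python
def deviance_explained(deviance, null_deviance):
    """Proportion of null deviance explained, under the mgcv convention."""
    if null_deviance <= 0:
        raise ValueError("null_deviance must be positive")
    return 1.0 - deviance / null_deviance


# With a fitted model this could be fed results.deviance and
# results.null_deviance; the numbers here are made up.
explained = deviance_explained(25.0, 100.0)
```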

plot(select: int | list | None = None, pages: int = 0, rug: bool = True, se: bool = True, shade: bool = True, **kwargs) -> tuple[matplotlib.figure.Figure, np.ndarray]

Plot smooth components of a fitted GAM.

Equivalent to R's plot.gam().

Parameters:

Name Type Description Default
select int, list, or None

Select specific smooth term(s) to plot (0-indexed).

None
pages int

Number of pages. 0 means automatic layout.

0
rug bool

Show rug marks at data covariate values.

True
se bool

Show standard error bands.

True
shade bool

If True, use shaded SE bands; if False, use dashed lines.

True
**kwargs

Additional arguments passed to plot_gam().

{}

Returns:

Name Type Description
fig Figure

The figure.

axes ndarray

Array of Axes objects.


Families

Distribution families for the response variable. Normally specified as a string (family="gaussian") when constructing a GAM, but the classes can be used directly for custom link functions.

from jaxgam.families import Gaussian, Binomial, Poisson, Gamma, NegativeBinomial

Gaussian

Gaussian(link: str | Link | None = None)

Bases: ExponentialFamily

Gaussian (normal) family with V(mu) = 1.

Parameters:

Name Type Description Default
link str or Link or None

Link function. Default is identity.

None

variance(mu: np.ndarray) -> np.ndarray

V(mu) = 1 for all mu.

deviance_resids(y: np.ndarray, mu: np.ndarray, wt: np.ndarray) -> np.ndarray

Deviance residuals: sign(y - mu) * sqrt(wt * (y - mu)^2).

The unit deviance for Gaussian is (y - mu)^2.
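A scalar sketch of the formula above, using only the standard library (the library itself operates on arrays; gaussian_dev_resid is an illustrative name):

```python
import math


def gaussian_dev_resid(y, mu, wt=1.0):
    """sign(y - mu) * sqrt(wt * (y - mu)^2) for a single observation."""
    r = y - mu
    return math.copysign(math.sqrt(wt * r * r), r)


resid = gaussian_dev_resid(1.0, 3.0, wt=4.0)
```

Squaring and summing these residuals gives the weighted residual sum of squares, which is the Gaussian deviance.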

Binomial

Binomial(link: str | Link | None = None)

Bases: ExponentialFamily

Binomial family with V(mu) = mu * (1 - mu).

Parameters:

Name Type Description Default
link str or Link or None

Link function. Default is logit.

None

variance(mu: np.ndarray) -> np.ndarray

V(mu) = mu * (1 - mu).

deviance_resids(y: np.ndarray, mu: np.ndarray, wt: np.ndarray) -> np.ndarray

Deviance residuals for Binomial.

Unit deviance: 2 * [y * log(y/mu) + (1-y) * log((1-y)/(1-mu))] with edge-case handling for y=0 and y=1.

Matches R's binomial()$dev.resids.
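The edge-case handling amounts to taking the limit 0 * log(0) = 0. A scalar sketch, assuming 0 < mu < 1 and with binomial_unit_deviance as an illustrative name:

```python
import math


def binomial_unit_deviance(y, mu):
    """2 * [y*log(y/mu) + (1-y)*log((1-y)/(1-mu))] with 0*log(0) = 0."""

    def a_log_a_over_b(a, b):
        # The a == 0 branch avoids evaluating log(0)
        return 0.0 if a == 0 else a * math.log(a / b)

    return 2.0 * (a_log_a_over_b(y, mu) + a_log_a_over_b(1.0 - y, 1.0 - mu))


d_zero = binomial_unit_deviance(0.0, 0.5)
d_one = binomial_unit_deviance(1.0, 0.5)
```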

Poisson

Poisson(link: str | Link | None = None)

Bases: ExponentialFamily

Poisson family with V(mu) = mu.

Parameters:

Name Type Description Default
link str or Link or None

Link function. Default is log.

None

variance(mu: np.ndarray) -> np.ndarray

V(mu) = mu.

deviance_resids(y: np.ndarray, mu: np.ndarray, wt: np.ndarray) -> np.ndarray

Deviance residuals for Poisson.

Unit deviance: 2 * [y * log(y/mu) - (y - mu)] with y=0 handled as a special case (term = 0).

Matches R's poisson()$dev.resids.
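A scalar sketch of this unit deviance, with the y = 0 special case made explicit (poisson_unit_deviance is an illustrative name and mu is assumed positive):

```python
import math


def poisson_unit_deviance(y, mu):
    """2 * [y*log(y/mu) - (y - mu)], with the log term set to 0 at y = 0."""
    log_term = 0.0 if y == 0 else y * math.log(y / mu)
    return 2.0 * (log_term - (y - mu))


d_zero = poisson_unit_deviance(0.0, 1.5)  # reduces to 2 * mu
```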

Gamma

Gamma(link: str | Link | None = None)

Bases: ExponentialFamily

Gamma family with V(mu) = mu^2.

Parameters:

Name Type Description Default
link str or Link or None

Link function. Default is inverse (1/mu).

None

variance(mu: np.ndarray) -> np.ndarray

V(mu) = mu^2.

deviance_resids(y: np.ndarray, mu: np.ndarray, wt: np.ndarray) -> np.ndarray

Deviance residuals for Gamma.

Unit deviance: 2 * [-log(y/mu) + (y - mu)/mu] = -2 * [log(y/mu) - (y - mu)/mu]

Matches R's Gamma()$dev.resids.
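A scalar sketch of the Gamma unit deviance, assuming y > 0 and mu > 0 (gamma_unit_deviance is an illustrative name):

```python
import math


def gamma_unit_deviance(y, mu):
    """2 * [(y - mu)/mu - log(y/mu)] for a single observation."""
    return 2.0 * ((y - mu) / mu - math.log(y / mu))


d = gamma_unit_deviance(2.0, 1.0)
```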

Negative Binomial (Extended Family)

The Negative Binomial family models overdispersed count data. It is an extended family with an extra dispersion parameter theta that can be estimated alongside smoothing parameters, or fixed.

  • Variance: mu + mu^2 / theta
  • As theta -> infinity: NB approaches Poisson
  • Theta parameterization: theta > 0 (R's "size" parameter); alpha = 1/theta

from jaxgam.families import NegativeBinomial

# Estimate theta (default, starting from 1)
GAM("y ~ s(x)", family="nb").fit(data)
GAM("y ~ s(x)", family=NegativeBinomial()).fit(data)

# Estimate theta with a different starting value
GAM("y ~ s(x)", family=NegativeBinomial(theta=3)).fit(data)

# Fix theta at a known value
GAM("y ~ s(x)", family=NegativeBinomial(theta=2, fixed=True)).fit(data)

Constructor parameters:

  • theta (float, default 1.0): dispersion parameter (must be positive)
  • fixed (bool, default False): if True, theta is held constant during fitting

NegativeBinomial(theta: float = 1.0, *, fixed: bool = False, link: str | Link | None = None)

Bases: ExtendedFamily

Negative Binomial family with V(mu) = mu + mu^2/theta.

Parameters:

Name Type Description Default
theta float

Dispersion parameter (must be positive). When fixed=False (default), this is the starting value for estimation. When fixed=True, theta is held constant during fitting. Default is 1.0.

1.0
fixed bool

If False (default), theta is estimated during fitting (n_theta = 1). If True, theta is held constant (n_theta = 0).

False
link str or Link or None

Link function. Default is log. Supported: "log", "identity", "sqrt".

None

Examples:

>>> fam = NegativeBinomial()                    # estimate theta, start at 1
>>> fam = NegativeBinomial(theta=3)             # estimate theta, start at 3
>>> fam = NegativeBinomial(theta=2, fixed=True) # fix theta = 2

variance(mu: np.ndarray) -> np.ndarray

V(mu) = mu + mu^2/theta.
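A scalar sketch of the variance function, which also illustrates the Poisson limit as theta grows (nb_variance is an illustrative name):

```python
def nb_variance(mu, theta):
    """Negative Binomial variance: mu + mu^2 / theta."""
    if theta <= 0:
        raise ValueError("theta must be positive")
    return mu + mu * mu / theta


overdispersed = nb_variance(2.0, 1.0)    # well above the Poisson variance of 2
near_poisson = nb_variance(2.0, 1e12)    # extra term is negligible
```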

deviance_resids(y: np.ndarray, mu: np.ndarray, wt: np.ndarray) -> np.ndarray

Per-observation deviance residuals.

R: efam.r lines 199-205.

Unit deviance: 2 * wt * [y * log(max(1,y)/mu) - (y+theta) * log1p((y-mu)/(mu+theta))]

The log1p rewrite avoids catastrophic cancellation when theta is large and (y+theta)/(mu+theta) ≈ 1.
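The rewrite relies on the identity (y+theta)/(mu+theta) = 1 + (y-mu)/(mu+theta). The scalar sketch below implements the formula above next to the direct log form so the two can be compared; both function names are illustrative, and mu and theta are assumed positive.

```python
import math


def nb_unit_deviance(y, mu, theta, wt=1.0):
    """NB deviance contribution using the log1p form."""
    log_term = y * math.log(max(1.0, y) / mu)
    tail = (y + theta) * math.log1p((y - mu) / (mu + theta))
    return 2.0 * wt * (log_term - tail)


def nb_unit_deviance_naive(y, mu, theta, wt=1.0):
    """Same quantity with a plain log of the ratio (less stable for large theta)."""
    log_term = y * math.log(max(1.0, y) / mu)
    tail = (y + theta) * math.log((y + theta) / (mu + theta))
    return 2.0 * wt * (log_term - tail)


stable = nb_unit_deviance(3.0, 2.0, 5.0)
naive = nb_unit_deviance_naive(3.0, 2.0, 5.0)
```

For very large theta the result approaches the Poisson unit deviance, consistent with the limit noted earlier in this section.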

get_theta(transformed: bool = False) -> np.ndarray

Extra parameter vector, shape (n_theta,) = (1,) for NB.

Parameters:

Name Type Description Default
transformed bool

If True, return exp(log_theta) (natural scale).

False

put_theta(log_theta: np.ndarray) -> None

Set the log(theta) vector. Called by the Newton optimizer after each accepted step.

ExtendedFamily Base Class

Base class for families with extra distributional parameters estimated via Newton optimization. NegativeBinomial inherits from this. Future extended families (Tweedie, Beta, etc.) will also subclass it.

ExtendedFamily(link: str | Link | None = None)

Bases: ExponentialFamily

Base class for families with extra parameters estimated via Newton.

Subclasses must implement:

  • get_theta / put_theta: mutable theta state for PIRLS runtime
  • deviance_fn: pure-function factory D(eta, log_theta_vec) -> scalar for the custom_jvp on PIRLS
  • working_weights_fn: pure-function factory W(eta, log_theta_vec) -> (n,) for the custom_jvp on PIRLS
  • saturated_loglik_theta: explicit-theta saturated log-likelihood for the REML criterion AD trace

The fitting code branches on family.n_theta > 0 (compile-time check via static family arg) to select the extended custom_jvp path.

R source reference: efam.r (extended family objects)

get_theta(transformed: bool = False) -> np.ndarray abstractmethod

Extra parameter vector, shape (n_theta,).

Returns log-scale by default. When transformed=True, returns natural scale (e.g. exp(log_theta) for positive parameters).

Parameters:

Name Type Description Default
transformed bool

If True, return parameters on the natural (not log) scale.

False

Returns:

Type Description
ndarray, shape (n_theta,)

Parameter vector.

put_theta(log_theta: np.ndarray) -> None abstractmethod

Set extra parameter vector (log-scale), shape (n_theta,).

Called by the Newton optimizer after each accepted step.

Parameters:

Name Type Description Default
log_theta ndarray, shape (n_theta,)

New parameter values on log scale.

required
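The log-scale convention shared by get_theta and put_theta can be sketched with a small stand-alone class. ThetaState is hypothetical and does not subclass ExtendedFamily; it uses plain lists where the real methods use ndarrays.

```python
import math


class ThetaState:
    """Hypothetical holder mirroring the get_theta / put_theta convention."""

    def __init__(self, theta=1.0):
        if theta <= 0:
            raise ValueError("theta must be positive")
        self._log_theta = [math.log(theta)]  # stored on the log scale

    def get_theta(self, transformed=False):
        """Return log(theta) by default, theta itself if transformed=True."""
        if transformed:
            return [math.exp(v) for v in self._log_theta]
        return list(self._log_theta)

    def put_theta(self, log_theta):
        """Overwrite the stored values; input is on the log scale."""
        self._log_theta = [float(v) for v in log_theta]


state = ThetaState(theta=2.0)
state.put_theta([math.log(3.0)])  # e.g. after an accepted optimizer step
```

Storing the parameter on the log scale keeps theta positive for any real-valued update the optimizer proposes.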

deviance_fn(y: np.ndarray, wt: np.ndarray) abstractmethod

Return pure JAX function D(eta, log_theta_vec) -> scalar.

log_theta_vec has shape (n_theta,).

Used by the custom_jvp for IFT theta terms and joint JVPs. Must capture (y, wt, link) in closure; theta is an explicit argument for AD tracing.

Parameters:

Name Type Description Default
y ndarray, shape (n,)

Response values.

required
wt ndarray, shape (n,)

Prior weights.

required

Returns:

Type Description
Callable[[Array, Array], Array]

Pure function (eta, log_theta_vec) -> scalar_deviance.
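The factory pattern described here (data captured in a closure, theta passed explicitly) can be sketched as follows. This version assumes a Negative Binomial deviance with a log link and uses plain Python in place of JAX so that it runs without dependencies; a real implementation would use jax.numpy so the returned function can be traced. make_deviance_fn is an illustrative name.

```python
import math


def make_deviance_fn(y, wt):
    """Return D(eta, log_theta_vec) -> scalar, closing over y and wt."""
    y = tuple(y)
    wt = tuple(wt)

    def deviance(eta, log_theta_vec):
        theta = math.exp(log_theta_vec[0])  # explicit argument, not captured state
        total = 0.0
        for yi, wi, ei in zip(y, wt, eta):
            mu = math.exp(ei)  # inverse of the assumed log link
            log_term = yi * math.log(max(1.0, yi) / mu)
            tail = (yi + theta) * math.log1p((yi - mu) / (mu + theta))
            total += 2.0 * wi * (log_term - tail)
        return total

    return deviance


y_obs = [1.0, 2.0, 5.0]
D = make_deviance_fn(y_obs, [1.0, 1.0, 1.0])
dev_at_saturated = D([math.log(v) for v in y_obs], [0.0])
dev_at_flat = D([0.0, 0.0, 0.0], [0.0])
```

Because the returned function reads nothing but its arguments and the captured data, the same callable can be evaluated at any (eta, log_theta_vec) pair without side effects.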

working_weights_fn(wt: np.ndarray) abstractmethod

Return pure JAX function W(eta, log_theta_vec) -> (n,) array.

log_theta_vec has shape (n_theta,).

Used by the custom_jvp for joint dW JVPs. Must capture (wt, link) in closure; theta is an explicit argument.

Parameters:

Name Type Description Default
wt ndarray, shape (n,)

Prior weights.

required

Returns:

Type Description
Callable[[Array, Array], Array]

Pure function (eta, log_theta_vec) -> working_weights.

saturated_loglik_theta(y: np.ndarray, wt: np.ndarray, scale: float, log_theta: np.ndarray, *, max_y: int = 0) abstractmethod

Saturated log-likelihood with explicit theta for AD trace.

log_theta has shape (n_theta,).

Called inside _diff_score where log_theta is a traced JAX array. jax.grad differentiates through this w.r.t. log_theta.

Parameters:

Name Type Description Default
y ndarray, shape (n,)

Response values.

required
wt ndarray, shape (n,)

Prior weights.

required
scale float

Dispersion parameter.

required
log_theta ndarray, shape (n_theta,)

Extra parameters on log scale (traced JAX value).

required
max_y int

Maximum count in y. Controls the lax.scan loop bound. Must be a compile-time constant.

0

Returns:

Type Description
Array, scalar

Saturated log-likelihood.


Formula syntax

Models are specified with R-style formulas:

# Single smooth
GAM("y ~ s(x)")

# Multiple smooths
GAM("y ~ s(x1) + s(x2)")

# Tensor product
GAM("y ~ te(x1, x2, k=5)")

# Factor-by smooth
GAM("y ~ s(x, by=fac, k=10) + fac")

Smooth term arguments

Argument Description Default
k Basis dimension (number of knots). -1 means auto-select: resolves to 10 for 1D TPRS/cubic, 30 for 2D TPRS. -1 (auto)
bs Basis type: 'tp', 'ts', 'cr', 'cs', 'cc' 'tp'
by Factor variable for factor-by smooths None

Tensor product arguments

Argument Description Default
k Marginal basis dimension (scalar applied to all margins). -1 means auto-select (resolves to 10 for the default cr marginals). -1 (auto)

Use te() for full tensor products and ti() for interaction-only terms (excludes main effects).


Custom registrations

JaxGAM ships with built-in smooths, families, and links, but you can register your own at runtime. Custom entries extend the registry without modifying or removing built-in entries.

Registering a custom smooth

Your class must subclass jaxgam.smooths.Smooth.

from jaxgam.smooths import smooth_registry, Smooth

class PSplineSmooth(Smooth):
    ...  # implement setup(), basis(), penalty(), etc.

smooth_registry.register("ps", PSplineSmooth)

# Now usable in formulas:
GAM("y ~ s(x, bs='ps')").fit(data)

Registering a custom family

Your class must subclass jaxgam.families.ExponentialFamily.

from jaxgam.families import family_registry, ExponentialFamily

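# Illustrative family name and key. The key must not collide with an
# existing entry: "nb", for example, is already built in.
class InverseGaussian(ExponentialFamily):
    ...  # implement variance(), deviance_resids(), initialize(), etc.

family_registry.register("invgauss", InverseGaussian)

# Now usable by name:
GAM("y ~ s(x)", family="invgauss").fit(data)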

Registering a custom link

Your class must subclass jaxgam.links.Link.

from jaxgam.links import link_registry, Link

class CauchitLink(Link):
    ...  # implement link(), inverse(), derivative()

link_registry.register("cauchit", CauchitLink)

Rules

  • Keys are case-insensitive: "PS" and "ps" are the same key.
  • You cannot override a built-in or previously registered key. Attempting to do so raises ValueError.
  • Registrations are global and take effect immediately — any subsequent GAM call will see the new entry.

Inspecting a registry

from jaxgam.smooths import smooth_registry

smooth_registry.available      # ('cc', 'cr', 'cs', 'te', 'ti', 'tp', 'ts')
"tp" in smooth_registry        # True
len(smooth_registry)           # 7
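The registration rules and the inspection interface can be mimicked with a toy registry. The Registry class below is a hypothetical sketch of the documented behaviour (case-insensitive keys, no overriding, available / in / len), not jaxgam's actual implementation.

```python
class Registry:
    """Toy registry: case-insensitive keys, existing keys cannot be replaced."""

    def __init__(self, builtins=None):
        self._entries = {}
        for key, value in (builtins or {}).items():
            self.register(key, value)

    def register(self, key, value):
        norm = key.lower()  # "PS" and "ps" map to the same key
        if norm in self._entries:
            raise ValueError(f"key {norm!r} is already registered")
        self._entries[norm] = value

    def __contains__(self, key):
        return key.lower() in self._entries

    def __len__(self):
        return len(self._entries)

    @property
    def available(self):
        return tuple(sorted(self._entries))


reg = Registry({"tp": object, "cr": object})
reg.register("PS", object)

try:
    reg.register("ps", object)  # same key as "PS", so this is rejected
    duplicate_allowed = True
except ValueError:
    duplicate_allowed = False
```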

GAMResults attributes

The GAMResults object returned by fit() exposes all fitted state as read-only attributes (frozen dataclass):

Attribute Type Description
coefficients ndarray (p,) Coefficient vector
fitted_values ndarray (n,) Fitted values on response scale
linear_predictor ndarray (n,) Linear predictor
Vp ndarray (p, p) Bayesian covariance matrix
edf ndarray Per-smooth effective degrees of freedom
edf1 ndarray Alternative EDF for significance testing
edf_total float Total effective degrees of freedom
scale float Estimated scale (dispersion) parameter
deviance float Model deviance
null_deviance float Null model deviance
smoothing_params ndarray Estimated smoothing parameters
converged bool Whether the optimizer converged
n_iter int Number of Newton iterations
score float REML/ML value at convergence
X ndarray (n, p) Design matrix
y ndarray (n,) Response vector
weights ndarray (n,) Prior weights
family ExponentialFamily Family object used for fitting
formula str Model formula
theta float | None Estimated theta for NB (None for standard families)
method str Smoothing parameter method ("REML" or "ML")
n int Number of observations