API Reference¶
GAM¶
Model specification and fit orchestration. fit() returns a GAMResults
frozen dataclass containing all fitted state.
```python
from jaxgam import GAM, GAMResults

model = GAM("y ~ s(x)", family="gaussian")
results = model.fit(data)
results.predict(newdata)
results.summary()
```
GAM(formula: str, family: str | ExponentialFamily = 'gaussian', method: str = 'REML', sp: np.ndarray | list | None = None, **kwargs)
¶
Generalized Additive Model specification.
This class holds model specification parameters and orchestrates the
fit pipeline. Calling fit() returns a GAMResults frozen
dataclass containing all fitted state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `formula` | str | Model formula in R-style Wilkinson notation, e.g. `"y ~ s(x)"`. | required |
| `family` | str or ExponentialFamily | Distribution family. One of `'gaussian'`, `'binomial'`, `'poisson'`, `'gamma'`, `'nb'`, or a family instance. | `'gaussian'` |
| `method` | str | Smoothing parameter estimation method: `'REML'` or `'ML'`. | `'REML'` |
| `sp` | ndarray or list | Fixed smoothing parameters. If provided, skips Newton optimization. | `None` |
| `device` | str | Target device. | required |
| `**kwargs` | | Additional keyword arguments. | `{}` |
Examples:
```python
>>> model = GAM("y ~ s(x)", family="gaussian")
>>> results = model.fit(data)
>>> results.predict(newdata)
array([...])
```
Design doc reference: docs/refactor_gam_api/design.md §3.3
fit(data: pd.DataFrame | dict, weights: np.ndarray | None = None, offset: np.ndarray | None = None) -> GAMResults
¶
Fit the GAM to data.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data` | DataFrame or dict | Data frame containing the variables in the formula. | required |
| `weights` | ndarray | Prior weights, shape `(n,)`. | `None` |
| `offset` | ndarray | Offset vector, shape `(n,)`. | `None` |

Returns:

| Type | Description |
|---|---|
| GAMResults | Frozen dataclass containing all fitted state. |

Design doc reference: docs/refactor_gam_api/design.md §3.3
GAMResults¶
Immutable results object returned by GAM.fit(). All post-estimation
methods (prediction, summary, plotting) live here.
GAMResults(coefficients: np.ndarray, fitted_values: np.ndarray, linear_predictor: np.ndarray, Vp: np.ndarray, scale: float, edf: np.ndarray, edf1: np.ndarray, edf_total: float, deviance: float, null_deviance: float, smoothing_params: np.ndarray, converged: bool, n_iter: int, score: float, family: ExponentialFamily, setup: ModelSetup, coef_map: CoefficientMap, smooth_info: tuple[SmoothInfo, ...], term_names: tuple[str, ...], X: np.ndarray, y: np.ndarray, weights: np.ndarray, offset: np.ndarray | None, theta: float | None, n: int, execution_path: str, lambda_strategy: str, formula: str, method: str, training_data: dict[str, np.ndarray])
dataclass
¶
Results from a fitted GAM.
All attributes are read-only (frozen dataclass). This object is the primary interface for post-estimation: prediction, inference, and visualization.
Design doc reference: docs/refactor_gam_api/design.md §3.4
predict(newdata: pd.DataFrame | dict | None = None, pred_type: str = 'response', se_fit: bool = False, offset: np.ndarray | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]
¶
Predict from a fitted GAM.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `newdata` | DataFrame or dict | New data for prediction. If None, uses the training data. | `None` |
| `pred_type` | str | Type of prediction. | `'response'` |
| `se_fit` | bool | Whether to return standard errors. | `False` |
| `offset` | array-like | Offset for new-data predictions. | `None` |

Returns:

| Type | Description |
|---|---|
| ndarray or tuple[ndarray, ndarray] | Predictions, or a `(predictions, se)` tuple when `se_fit=True`. |
predict_matrix(newdata: pd.DataFrame | dict) -> np.ndarray
¶
Build constrained prediction matrix for new data.
Equivalent to R's predict.gam(type="lpmatrix").
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `newdata` | DataFrame or dict | New data for prediction. | required |

Returns:

| Type | Description |
|---|---|
| ndarray, shape `(n_new, total_coefs)` | Constrained prediction matrix. |
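The lpmatrix enables custom downstream inference: predictions are `Xp @ coefficients`, and pointwise standard errors follow from the Bayesian covariance `Vp`. A NumPy sketch using synthetic stand-ins for the fitted quantities (in practice `Xp`, `beta`, and `Vp` would come from `predict_matrix(newdata)`, `results.coefficients`, and `results.Vp`):

```python
import numpy as np

rng = np.random.default_rng(0)
Xp = rng.normal(size=(5, 3))       # stand-in for results.predict_matrix(newdata)
beta = np.array([0.5, -1.0, 2.0])  # stand-in for results.coefficients
Vp = np.eye(3) * 0.01              # stand-in for results.Vp

fit = Xp @ beta                                     # linear predictor at new points
se = np.sqrt(np.einsum("ij,jk,ik->i", Xp, Vp, Xp))  # pointwise standard errors
lower, upper = fit - 1.96 * se, fit + 1.96 * se     # ~95% pointwise interval
```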
summary() -> GAMSummary
¶
Print and return summary of a fitted GAM.
Computes parametric coefficient significance (z/t tests), smooth term significance (Wood 2013 testStat), and model-level statistics (R-squared, deviance explained, scale estimate).
Returns:

| Type | Description |
|---|---|
| GAMSummary | Summary object with parametric and smooth term tables. The summary is also printed to stdout. |
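Deviance explained, one of the model-level statistics reported, is the standard GLM-style ratio; with hypothetical deviance values:

```python
# Deviance explained = 1 - deviance / null_deviance (hypothetical values)
deviance, null_deviance = 12.3, 48.0
dev_explained = 1.0 - deviance / null_deviance
```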
plot(select: int | list | None = None, pages: int = 0, rug: bool = True, se: bool = True, shade: bool = True, **kwargs) -> tuple[matplotlib.figure.Figure, np.ndarray]
¶
Plot smooth components of a fitted GAM.
Equivalent to R's plot.gam().
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `select` | int, list, or None | Select specific smooth term(s) to plot (0-indexed). | `None` |
| `pages` | int | Number of pages. 0 means automatic layout. | `0` |
| `rug` | bool | Show rug marks at data covariate values. | `True` |
| `se` | bool | Show standard error bands. | `True` |
| `shade` | bool | If True, use shaded SE bands; if False, use dashed lines. | `True` |
| `**kwargs` | | Additional keyword arguments forwarded to the underlying plot calls. | `{}` |

Returns:

| Name | Type | Description |
|---|---|---|
| fig | Figure | The figure. |
| axes | ndarray | Array of Axes objects. |
Families¶
Distribution families for the response variable. Normally specified as a
string (family="gaussian") when constructing a GAM, but the classes
can be used directly for custom link functions.
Gaussian¶
Gaussian(link: str | Link | None = None)
¶
Bases: ExponentialFamily
Gaussian (normal) family with V(mu) = 1.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `link` | str or Link or None | Link function. Default is identity. | `None` |
|
Binomial¶
Binomial(link: str | Link | None = None)
¶
Bases: ExponentialFamily
Binomial family with V(mu) = mu * (1 - mu).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `link` | str or Link or None | Link function. Default is logit. | `None` |
|
variance(mu: np.ndarray) -> np.ndarray
¶
V(mu) = mu * (1 - mu).
deviance_resids(y: np.ndarray, mu: np.ndarray, wt: np.ndarray) -> np.ndarray
¶
Deviance residuals for Binomial.
Unit deviance: 2 * [y * log(y/mu) + (1-y) * log((1-y)/(1-mu))] with edge-case handling for y=0 and y=1.
Matches R's binomial()$dev.resids.
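The unit deviance above can be sketched directly in NumPy (a hypothetical reimplementation for illustration, with the y=0 and y=1 limits handled by dropping the vanishing terms):

```python
import numpy as np

def binomial_unit_deviance(y, mu):
    """2*[y*log(y/mu) + (1-y)*log((1-y)/(1-mu))]; terms vanish at y=0 and y=1."""
    with np.errstate(divide="ignore", invalid="ignore"):
        t1 = np.where(y > 0, y * np.log(y / mu), 0.0)
        t2 = np.where(y < 1, (1 - y) * np.log((1 - y) / (1 - mu)), 0.0)
    return 2.0 * (t1 + t2)
```

At y == mu the unit deviance is exactly zero, which is a convenient sanity check.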
Poisson¶
Poisson(link: str | Link | None = None)
¶
Bases: ExponentialFamily
Poisson family with V(mu) = mu.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `link` | str or Link or None | Link function. Default is log. | `None` |
|
Gamma¶
Gamma(link: str | Link | None = None)
¶
Bases: ExponentialFamily
Gamma family with V(mu) = mu^2.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `link` | str or Link or None | Link function. Default is inverse (1/mu). | `None` |
|
Negative Binomial (Extended Family)¶
The Negative Binomial family models overdispersed count data. It is an extended family with an extra dispersion parameter theta that can be estimated alongside smoothing parameters, or fixed.
- Variance: mu + mu^2 / theta
- As theta -> infinity, NB approaches Poisson
- Theta parameterization: theta > 0 (R's "size" parameter); alpha = 1/theta
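The Poisson limit can be checked numerically with a quick sketch of the variance function:

```python
import numpy as np

def nb_variance(mu, theta):
    """Negative Binomial variance V(mu) = mu + mu^2 / theta."""
    return mu + mu**2 / theta

mu = np.array([1.0, 5.0, 20.0])
overdispersed = nb_variance(mu, 1.0)  # mu + mu^2: strong overdispersion
near_poisson = nb_variance(mu, 1e9)   # approaches the Poisson variance V(mu) = mu
```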
```python
from jaxgam.families import NegativeBinomial

# Estimate theta (default, starting from 1)
GAM("y ~ s(x)", family="nb").fit(data)
GAM("y ~ s(x)", family=NegativeBinomial()).fit(data)

# Estimate theta with a different starting value
GAM("y ~ s(x)", family=NegativeBinomial(theta=3)).fit(data)

# Fix theta at a known value
GAM("y ~ s(x)", family=NegativeBinomial(theta=2, fixed=True)).fit(data)
```
Constructor parameters:
- theta (float, default 1.0): dispersion parameter (must be positive)
- fixed (bool, default False): if True, theta is held constant during fitting
NegativeBinomial(theta: float = 1.0, *, fixed: bool = False, link: str | Link | None = None)
¶
Bases: ExtendedFamily
Negative Binomial family with V(mu) = mu + mu^2/theta.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `theta` | float | Dispersion parameter (must be positive). When `fixed=False`, used as the starting value for estimation. | `1.0` |
| `fixed` | bool | If True, theta is held constant during fitting. | `False` |
| `link` | str or Link or None | Link function. Default is log. | `None` |
Examples:
>>> fam = NegativeBinomial() # estimate theta, start at 1
>>> fam = NegativeBinomial(theta=3) # estimate theta, start at 3
>>> fam = NegativeBinomial(theta=2, fixed=True) # fix theta = 2
variance(mu: np.ndarray) -> np.ndarray
¶
V(mu) = mu + mu^2/theta.
deviance_resids(y: np.ndarray, mu: np.ndarray, wt: np.ndarray) -> np.ndarray
¶
Per-observation deviance residuals.
R: efam.r lines 199-205.
Unit deviance: 2 * wt * [y * log(max(1,y)/mu) - (y+theta) * log1p((y-mu)/(mu+theta))]
The log1p rewrite avoids catastrophic cancellation when
theta is large and (y+theta)/(mu+theta) ≈ 1.
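A quick numerical illustration of the cancellation (hypothetical values, with theta much larger than y and mu, where the true log is approximately (y - mu) / theta = 5e-16):

```python
import numpy as np

y, mu, theta = 3.0, 2.5, 1e15
naive = np.log((y + theta) / (mu + theta))  # the ratio rounds near 1; digits are lost
stable = np.log1p((y - mu) / (mu + theta))  # retains full relative precision
```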
get_theta(transformed: bool = False) -> np.ndarray
¶
Extra parameter vector, shape (n_theta,) = (1,) for NB.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `transformed` | bool | If True, return parameters on the natural (not log) scale. | `False` |
|
put_theta(log_theta: np.ndarray) -> None
¶
Set log(theta) vector. Called by Newton after each accepted step.
ExtendedFamily Base Class¶
Base class for families with extra distributional parameters estimated
via Newton optimization. NegativeBinomial inherits from this.
Future extended families (Tweedie, Beta, etc.) will also subclass it.
ExtendedFamily(link: str | Link | None = None)
¶
Bases: ExponentialFamily
Base class for families with extra parameters estimated via Newton.
Subclasses must implement:
- `get_theta`/`put_theta`: mutable theta state for PIRLS runtime
- `deviance_fn`: pure-function factory `D(eta, log_theta_vec) -> scalar` for the custom_jvp on PIRLS
- `working_weights_fn`: pure-function factory `W(eta, log_theta_vec) -> (n,)` for the custom_jvp on PIRLS
- `saturated_loglik_theta`: explicit-theta saturated log-likelihood for the REML criterion AD trace
The fitting code branches on family.n_theta > 0 (compile-time check
via static family arg) to select the extended custom_jvp path.
R source reference: efam.r (extended family objects)
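The factory contract can be illustrated with a plain-NumPy sketch (hypothetical; the real implementations return JAX-traceable functions). The data (`y`, `wt`) and link are captured in the closure, while `log_theta_vec` stays an explicit argument so AD can differentiate through it:

```python
import numpy as np

def make_deviance_fn(y, wt):
    """Sketch of deviance_fn for the NB family: returns D(eta, log_theta_vec)."""
    def D(eta, log_theta_vec):
        theta = np.exp(log_theta_vec[0])  # natural scale from log-scale argument
        mu = np.exp(eta)                  # log link assumed
        # NB unit deviance (see deviance_resids above), summed to a scalar
        d = 2.0 * wt * (y * np.log(np.maximum(1.0, y) / mu)
                        - (y + theta) * np.log1p((y - mu) / (mu + theta)))
        return d.sum()
    return D
```

At mu == y (for y >= 1) each unit deviance vanishes, a convenient sanity check.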
get_theta(transformed: bool = False) -> np.ndarray
abstractmethod
¶
Extra parameter vector, shape (n_theta,).
Returns log-scale by default. When transformed=True, returns
natural scale (e.g. exp(log_theta) for positive parameters).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `transformed` | bool | If True, return parameters on the natural (not log) scale. | `False` |

Returns:

| Type | Description |
|---|---|
| ndarray, shape `(n_theta,)` | Parameter vector. |
put_theta(log_theta: np.ndarray) -> None
abstractmethod
¶
Set extra parameter vector (log-scale), shape (n_theta,).
Called by the Newton optimizer after each accepted step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `log_theta` | ndarray, shape `(n_theta,)` | New parameter values on log scale. | required |
deviance_fn(y: np.ndarray, wt: np.ndarray)
abstractmethod
¶
Return pure JAX function D(eta, log_theta_vec) -> scalar.
log_theta_vec has shape (n_theta,).
Used by the custom_jvp for IFT theta terms and joint JVPs.
Must capture (y, wt, link) in closure; theta is an explicit
argument for AD tracing.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y` | ndarray, shape `(n,)` | Response values. | required |
| `wt` | ndarray, shape `(n,)` | Prior weights. | required |

Returns:

| Type | Description |
|---|---|
| Callable[[Array, Array], Array] | Pure function `D(eta, log_theta_vec)`. |
working_weights_fn(wt: np.ndarray)
abstractmethod
¶
Return pure JAX function W(eta, log_theta_vec) -> (n,) array.
log_theta_vec has shape (n_theta,).
Used by the custom_jvp for joint dW JVPs. Must capture
(wt, link) in closure; theta is an explicit argument.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `wt` | ndarray, shape `(n,)` | Prior weights. | required |

Returns:

| Type | Description |
|---|---|
| Callable[[Array, Array], Array] | Pure function `W(eta, log_theta_vec)`. |
saturated_loglik_theta(y: np.ndarray, wt: np.ndarray, scale: float, log_theta: np.ndarray, *, max_y: int = 0)
abstractmethod
¶
Saturated log-likelihood with explicit theta for AD trace.
log_theta has shape (n_theta,).
Called inside _diff_score where log_theta is a traced
JAX array. jax.grad differentiates through this w.r.t.
log_theta.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y` | ndarray, shape `(n,)` | Response values. | required |
| `wt` | ndarray, shape `(n,)` | Prior weights. | required |
| `scale` | float | Dispersion parameter. | required |
| `log_theta` | ndarray, shape `(n_theta,)` | Extra parameters on log scale (traced JAX value). | required |
| `max_y` | int | Maximum count in `y`. | `0` |

Returns:

| Type | Description |
|---|---|
| Array (scalar) | Saturated log-likelihood. |
Formula syntax¶
Models are specified with R-style formulas:
```python
# Single smooth
GAM("y ~ s(x)")

# Multiple smooths
GAM("y ~ s(x1) + s(x2)")

# Tensor product
GAM("y ~ te(x1, x2, k=5)")

# Factor-by smooth
GAM("y ~ s(x, by=fac, k=10) + fac")
```
Smooth term arguments¶
| Argument | Description | Default |
|---|---|---|
| `k` | Basis dimension (number of knots). -1 means auto-select: resolves to 10 for 1D TPRS/cubic, 30 for 2D TPRS. | -1 (auto) |
| `bs` | Basis type: `'tp'`, `'ts'`, `'cr'`, `'cs'`, `'cc'`. | `'tp'` |
| `by` | Factor variable for factor-by smooths. | `None` |
Tensor product arguments¶
| Argument | Description | Default |
|---|---|---|
| `k` | Marginal basis dimension (scalar applied to all margins). -1 means auto-select (resolves to 10 for the default cr marginals). | -1 (auto) |
Use te() for full tensor products and ti() for interaction-only terms
(excludes main effects).
Custom registrations¶
JaxGAM ships with built-in smooths, families, and links, but you can register your own at runtime. Custom entries extend the registry without modifying or removing built-in entries.
Registering a custom smooth¶
Your class must subclass jaxgam.smooths.Smooth.
```python
from jaxgam.smooths import smooth_registry, Smooth

class PSplineSmooth(Smooth):
    ...  # implement setup(), basis(), penalty(), etc.

smooth_registry.register("ps", PSplineSmooth)

# Now usable in formulas:
GAM("y ~ s(x, bs='ps')").fit(data)
```
Registering a custom family¶
Your class must subclass jaxgam.families.ExponentialFamily.
```python
from jaxgam.families import family_registry, ExponentialFamily

class NegativeBinomial(ExponentialFamily):
    ...  # implement variance(), deviance_resids(), initialize(), etc.

family_registry.register("nb", NegativeBinomial)

# Now usable by name:
GAM("y ~ s(x)", family="nb").fit(data)
```
Registering a custom link¶
Your class must subclass jaxgam.links.Link.
```python
from jaxgam.links import link_registry, Link

class CauchitLink(Link):
    ...  # implement link(), inverse(), derivative()

link_registry.register("cauchit", CauchitLink)
```
Rules¶
- Keys are case-insensitive: `"PS"` and `"ps"` are the same key.
- You cannot override a built-in or previously registered key. Attempting to do so raises `ValueError`.
- Registrations are global and take effect immediately; any subsequent `GAM` call will see the new entry.
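A minimal sketch of these semantics (hypothetical; not jaxgam's actual implementation):

```python
class Registry:
    """Sketch of the registry rules: case-insensitive keys, no overrides."""

    def __init__(self, builtins):
        self._entries = {k.lower(): v for k, v in builtins.items()}

    def register(self, key, cls):
        key = key.lower()             # "PS" and "ps" are the same key
        if key in self._entries:      # built-ins and prior keys are protected
            raise ValueError(f"key {key!r} already registered")
        self._entries[key] = cls

    def __contains__(self, key):
        return key.lower() in self._entries

    def __len__(self):
        return len(self._entries)

    @property
    def available(self):
        return tuple(sorted(self._entries))
```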
Inspecting a registry¶
```python
from jaxgam.smooths import smooth_registry

smooth_registry.available   # ('cc', 'cr', 'cs', 'te', 'ti', 'tp', 'ts')
"tp" in smooth_registry     # True
len(smooth_registry)        # 7
```
GAMResults attributes¶
The GAMResults object returned by fit() exposes all fitted state as
read-only attributes (frozen dataclass):
| Attribute | Type | Description |
|---|---|---|
| `coefficients` | ndarray `(p,)` | Coefficient vector |
| `fitted_values` | ndarray `(n,)` | Fitted values on response scale |
| `linear_predictor` | ndarray `(n,)` | Linear predictor |
| `Vp` | ndarray `(p, p)` | Bayesian covariance matrix |
| `edf` | ndarray | Per-smooth effective degrees of freedom |
| `edf1` | ndarray | Alternative EDF for significance testing |
| `edf_total` | float | Total effective degrees of freedom |
| `scale` | float | Estimated scale (dispersion) parameter |
| `deviance` | float | Model deviance |
| `null_deviance` | float | Null model deviance |
| `smoothing_params` | ndarray | Estimated smoothing parameters |
| `converged` | bool | Whether the optimizer converged |
| `n_iter` | int | Number of Newton iterations |
| `score` | float | REML/ML value at convergence |
| `X` | ndarray `(n, p)` | Design matrix |
| `y` | ndarray `(n,)` | Response vector |
| `weights` | ndarray `(n,)` | Prior weights |
| `family` | ExponentialFamily | Family object used for fitting |
| `formula` | str | Model formula |
| `theta` | float \| None | Estimated theta for NB (None for standard families) |
| `method` | str | Smoothing parameter method (`"REML"` or `"ML"`) |
| `n` | int | Number of observations |