API Reference
diffmahnet
- class diffmahnet.DiffMahFlow(scaler, nn_depth=2, nn_width=50, flow_layers=8, randkey=None)
The primary class within diffmahnet. This class is used to train an emulator of the diffmahpop model and to sample from it.
- Parameters:
scaler (Scaler) – Scaler object to normalize the input data
nn_depth (int, optional) – Depth (number of hidden layers) of the neural network, by default 2
nn_width (int, optional) – Width (number of units per hidden layer) of the neural network, by default 50
flow_layers (int, optional) – Number of flow layers, by default 8
randkey (jax.random.PRNGKey, optional) – Random key for reproducibility, by default None
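Under the conventional reading of these two arguments, nn_depth counts hidden layers and nn_width counts units per hidden layer. The plain-Python sketch below lists the resulting linear-layer shapes, following the layout that Equinox's eqx.nn.MLP uses for its depth and width_size arguments. The helper names, the input/output sizes, and the assumption that diffmahnet builds its networks this way are all illustrative.

```python
def mlp_layer_shapes(in_size, out_size, nn_depth=2, nn_width=50):
    """Return (fan_in, fan_out) for each linear layer of a hypothetical MLP.

    nn_depth counts hidden layers; nn_width counts units per hidden layer.
    """
    if nn_depth == 0:
        # No hidden layers: a single linear map from input to output.
        return [(in_size, out_size)]
    shapes = [(in_size, nn_width)]                     # input -> first hidden layer
    shapes += [(nn_width, nn_width)] * (nn_depth - 1)  # hidden -> hidden
    shapes.append((nn_width, out_size))                # last hidden -> output
    return shapes


def count_parameters(shapes):
    """Total number of weights plus biases across all linear layers."""
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in shapes)


shapes = mlp_layer_shapes(in_size=2, out_size=5)
print(shapes)                    # [(2, 50), (50, 50), (50, 5)]
print(count_parameters(shapes))  # 2955
```

With the defaults and hypothetical input and output sizes of 2 and 5, the network has three linear layers and 2955 weights and biases; raising nn_width grows the hidden-to-hidden layers quadratically, while raising nn_depth adds layers linearly.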
- adam_fit(lossfunc, randkey=None, nsteps=100, progress=True, learning_rate=0.0001, thin=1, **kwargs)
Fit the flow using Adam stochastic gradient descent
- Parameters:
lossfunc (callable) – Loss function to minimize, should have signature lossfunc(diffmahflow, randkey=key) -> float
randkey (jax.random.PRNGKey, optional) – Random key; by default, the current self.randkey is used
nsteps (int, optional) – Number of Adam steps to perform, by default 100
progress (bool, optional) – Set to False to hide progress bars, by default True
learning_rate (float, optional) – Initial Adam learning rate, by default 1e-4
thin (int, optional) – Return parameters every thin iterations, by default 1. Set thin=0 to return only the final parameters
- Returns:
Loss value at each step of the descent
- Return type:
jnp.array[float]
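The documented behaviour of adam_fit (one loss per step, parameters kept every thin steps, thin=0 keeping only the final parameters) can be sketched in plain Python. This is a hypothetical stand-in, not the library's implementation: grad_and_loss is an assumed callable returning the loss and its gradient for a list of floats, whereas the real method takes lossfunc(diffmahflow, randkey=key) and handles gradients and random keys itself. The beta1, beta2 and eps values are the commonly used Adam defaults.

```python
import math


def adam_fit_sketch(grad_and_loss, params, nsteps=100, learning_rate=1e-4, thin=1,
                    beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical Adam loop mirroring the documented adam_fit outputs.

    grad_and_loss(params) must return (loss, grads) for a list of floats.
    Returns (losses, param_history): one loss per step, and parameters
    every `thin` steps (or only the final parameters when thin=0).
    """
    m = [0.0] * len(params)  # first-moment (mean) estimates
    v = [0.0] * len(params)  # second-moment (uncentred variance) estimates
    losses, param_history = [], []
    for step in range(1, nsteps + 1):
        loss, grads = grad_and_loss(params)
        losses.append(loss)  # loss evaluated before this step's update
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            m[i] = beta1 * m[i] + (1 - beta1) * g
            v[i] = beta2 * v[i] + (1 - beta2) * g * g
            m_hat = m[i] / (1 - beta1 ** step)  # bias-corrected moments
            v_hat = v[i] / (1 - beta2 ** step)
            new_params.append(p - learning_rate * m_hat / (math.sqrt(v_hat) + eps))
        params = new_params
        if thin > 0 and step % thin == 0:
            param_history.append(list(params))
    if thin == 0:
        param_history.append(list(params))
    return losses, param_history


def quadratic(params):
    """Toy loss sum((p - 3)^2) with its analytic gradient."""
    loss = sum((p - 3.0) ** 2 for p in params)
    grads = [2.0 * (p - 3.0) for p in params]
    return loss, grads


losses, history = adam_fit_sketch(quadratic, [0.0, 10.0], nsteps=500,
                                  learning_rate=0.1, thin=100)
```

Here 500 losses and 5 parameter snapshots are recorded. How the real method exposes the thinned parameters is not described above; the sketch simply returns them alongside the losses.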
- init_fit(xtrain, utrain, randkey=None, learning_rate=0.01, max_patience=10, max_epochs=50)
Train the flow directly on P(mah_params|m_obs,t_obs)
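init_fit is documented only by its signature. Its max_patience and max_epochs arguments suggest patience-based early stopping; the sketch below shows that pattern under that assumption, with a hypothetical epoch_fn that trains for one epoch and returns a validation loss. It is not the library's code.

```python
import math


def train_with_patience(epoch_fn, max_patience=10, max_epochs=50):
    """Hypothetical early-stopping loop.

    epoch_fn(epoch) trains for one epoch and returns a validation loss.
    Training stops once max_patience consecutive epochs fail to improve
    on the best loss, or after max_epochs epochs, whichever comes first.
    """
    best_loss = math.inf
    epochs_without_improvement = 0
    history = []
    for epoch in range(max_epochs):
        loss = epoch_fn(epoch)
        history.append(loss)
        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= max_patience:
                break
    return history, best_loss


val_losses = [5.0, 4.0, 3.0, 3.5, 3.6, 3.7, 2.0]
history, best = train_with_patience(lambda e: val_losses[e], max_patience=3)
```

Here training stops after the sixth epoch with a best loss of 3.0, so the later improvement to 2.0 is never reached; a larger max_patience trades extra epochs for protection against such plateaus.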
- classmethod load(filename, randkey=None)
Load a pre-trained model from an .eqx file
- Parameters:
filename (str) – Filename to load the model from. The “.eqx” extension will be added automatically if not present
randkey (jax.random.PRNGKey, optional) – Random key for reproducibility
- Returns:
The loaded model
- Return type:
DiffMahFlow
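The extension rule described for filename amounts to the check below; with_eqx_extension is a hypothetical helper written for illustration, not part of diffmahnet.

```python
def with_eqx_extension(filename: str) -> str:
    """Append ".eqx" unless the filename already ends with it (hypothetical helper)."""
    return filename if filename.endswith(".eqx") else filename + ".eqx"
```

Under this rule, DiffMahFlow.load("trained_flow") and DiffMahFlow.load("trained_flow.eqx") would read the same file (the filename is hypothetical).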
- sample(condition, randkey=None, extra_shape=(), asparams=False, flow_params=None)
Sample diffmah u_params, conditioned on (m_obs, t_obs)
- Parameters:
condition (jnp.ndarray) – Array of m_obs and t_obs, of shape (n_samples, 2)
randkey (jax.random.PRNGKey, optional) – Random key for reproducibility
extra_shape (tuple, optional) – Extra sample shape; samples of this shape are drawn for each condition value
asparams (bool, optional) – If True, return a DiffmahParams tuple instead of a u_params array
flow_params (jnp.ndarray, optional) – Flow parameters to use for sampling in place of the stored ones (for a functional rather than object-oriented style)
- Returns:
Sampled unbounded parameters (u_params), of shape (n_samples, 5, *extra_shape)
- Return type:
jnp.ndarray | DiffmahParams
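The shape conventions above can be checked with a small plain-Python sketch: make_condition pairs observed masses and times into the (n_samples, 2) layout expected by condition, and expected_sample_shape applies the documented output shape. Both helpers are hypothetical, and the input values are illustrative only.

```python
def make_condition(m_obs, t_obs):
    """Pair observed masses and times into rows of [m_obs, t_obs] (shape (n_samples, 2))."""
    if len(m_obs) != len(t_obs):
        raise ValueError("m_obs and t_obs must have the same length")
    return [[float(m), float(t)] for m, t in zip(m_obs, t_obs)]


def expected_sample_shape(condition, extra_shape=()):
    """Documented output shape of sample(): (n_samples, 5, *extra_shape)."""
    return (len(condition), 5, *tuple(extra_shape))


condition = make_condition([12.0, 12.5, 13.0], [13.8, 13.8, 10.0])
print(expected_sample_shape(condition))                      # (3, 5)
print(expected_sample_shape(condition, extra_shape=(100,)))  # (3, 5, 100)
```

With a trained flow, the corresponding call would look like flow.sample(jnp.array(condition), randkey=jax.random.PRNGKey(0), extra_shape=(100,)), which per the documented shape would return an array of shape (3, 5, 100).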