API Reference

diffmahnet

class diffmahnet.DiffMahFlow(scaler, nn_depth=2, nn_width=50, flow_layers=8, randkey=None)

The primary class within diffmahnet. It is used to train, sample from, and emulate the diffmahpop model.

Parameters:
  • scaler (Scaler) – Scaler object to normalize the input data

  • nn_depth (int, optional) – Depth (number of hidden layers) of the neural network, by default 2

  • nn_width (int, optional) – Width (number of units in each hidden layer) of the neural network, by default 50

  • flow_layers (int, optional) – Number of flow layers, by default 8

  • randkey (jax.random.PRNGKey, optional) – Random key for reproducibility, by default None
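
To make the distinction between nn_depth and nn_width concrete, the sketch below lists the linear-layer shapes of a multilayer perceptron with nn_depth hidden layers of nn_width units each. It assumes the equinox.nn.MLP convention (depth counts hidden layers, width_size is the number of units per hidden layer); mlp_layer_shapes and the in_size/out_size values are hypothetical and are not part of diffmahnet.

```python
def mlp_layer_shapes(in_size, out_size, nn_width=50, nn_depth=2):
    """Return (fan_in, fan_out) for each linear layer of a hypothetical MLP.

    Assumes an equinox.nn.MLP-style convention: `nn_depth` hidden layers,
    each with `nn_width` units, giving nn_depth + 1 linear maps in total.
    """
    sizes = [in_size] + [nn_width] * nn_depth + [out_size]
    return list(zip(sizes[:-1], sizes[1:]))


# Default depth and width with illustrative input/output sizes
shapes = mlp_layer_shapes(2, 10, nn_width=50, nn_depth=2)
```

Here shapes is [(2, 50), (50, 50), (50, 10)]: increasing nn_depth adds more (50, 50) layers, while increasing nn_width widens every hidden layer.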

adam_fit(lossfunc, randkey=None, nsteps=100, progress=True, learning_rate=0.0001, thin=1, **kwargs)

Fit the flow using the Adam stochastic gradient descent optimizer

Parameters:
  • lossfunc (callable) – Loss function to minimize; it should have the signature lossfunc(diffmahflow, randkey=key) -> float

  • randkey (jax.random.PRNGKey, optional) – Random key for reproducibility; by default, the current self.randkey is used

  • nsteps (int, optional) – Number of Adam steps to perform, by default 100

  • progress (bool, optional) – Set to False to hide progress bars, by default True

  • learning_rate (float, optional) – Initial Adam learning rate, by default 1e-4

  • thin (int, optional) – Return parameters every thin iterations, by default 1. Set thin=0 to return only the final parameters

Returns:

Loss value at each step of the descent

Return type:

jnp.array[float]
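
adam_fit records the loss at every step and, through thin, keeps a subset of intermediate parameters. The sketch below illustrates that bookkeeping with a standard Adam update on a one-dimensional toy problem. It is not diffmahnet's implementation: adam_minimize and quadratic are hypothetical names, the toy loss stands in for lossfunc, and the exact step at which thinned parameters are recorded is an assumption.

```python
import math


def adam_minimize(grad_and_loss, x0, nsteps=100, learning_rate=1e-4, thin=1,
                  b1=0.9, b2=0.999, eps=1e-8):
    """Hypothetical Adam loop on a scalar parameter.

    Returns (losses, kept): the loss at every step, and the parameter value
    every `thin` steps (or only the final value when thin == 0).
    """
    x = x0
    m = v = 0.0  # first- and second-moment estimates
    losses, kept = [], []
    for step in range(1, nsteps + 1):
        loss, grad = grad_and_loss(x)
        losses.append(loss)
        # Standard Adam moment updates with bias correction
        m = b1 * m + (1 - b1) * grad
        v = b2 * v + (1 - b2) * grad * grad
        m_hat = m / (1 - b1 ** step)
        v_hat = v / (1 - b2 ** step)
        x = x - learning_rate * m_hat / (math.sqrt(v_hat) + eps)
        # Assumed thinning rule: keep the parameter after every `thin`-th step
        if thin > 0 and step % thin == 0:
            kept.append(x)
    if thin == 0:
        kept = [x]  # only the final parameters
    return losses, kept


def quadratic(x):
    """Toy loss (x - 3)^2 and its gradient, standing in for lossfunc."""
    return (x - 3.0) ** 2, 2.0 * (x - 3.0)


losses, kept = adam_minimize(quadratic, x0=0.0, nsteps=200,
                             learning_rate=0.1, thin=50)
```

With nsteps=200 and thin=50, losses has 200 entries (one per step, matching the documented return value) while kept holds 4 parameter snapshots.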

init_fit(xtrain, utrain, randkey=None, learning_rate=0.01, max_patience=10, max_epochs=50)

Train the flow directly on P(mah_params|m_obs,t_obs)
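
init_fit exposes learning_rate, max_patience and max_epochs. Assuming these follow a common early-stopping convention (used, for example, by flowjax's fit_to_data), training stops once the validation loss has failed to improve for max_patience consecutive epochs, or after max_epochs epochs. The hypothetical helper below applies that assumed rule to a precomputed list of validation losses; it does not train anything.

```python
def early_stopping_epoch(val_losses, max_patience=10, max_epochs=50):
    """Return (stop_epoch, best_loss) under a patience-based early-stopping rule.

    Hypothetical helper: `val_losses[i]` is the validation loss after epoch i.
    """
    best = float("inf")
    patience = 0  # consecutive epochs without improvement
    stop_epoch = -1
    for epoch, loss in enumerate(val_losses[:max_epochs]):
        stop_epoch = epoch
        if loss < best:
            best, patience = loss, 0
        else:
            patience += 1
            if patience >= max_patience:
                break
    return stop_epoch, best
```

For example, early_stopping_epoch([5.0, 4.0, 3.0, 3.5, 3.6, 3.7], max_patience=3) returns (5, 3.0): the loss stops improving after epoch 2, and training halts three epochs later.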

classmethod load(filename, randkey=None)

Load a pre-trained model from an eqx file

Parameters:
  • filename (str) – Filename to load the model from. The “.eqx” extension will be added automatically if not present

  • randkey (jax.random.PRNGKey, optional) – Random key for reproducibility

Returns:

The loaded model

Return type:

DiffMahFlow

sample(condition, randkey=None, extra_shape=(), asparams=False, flow_params=None)

Sample diffmah u_params, conditioned on (m_obs, t_obs)

Parameters:
  • condition (jnp.ndarray) – Array of m_obs and t_obs, of shape (n_samples, 2)

  • randkey (jax.random.PRNGKey, optional) – Random key for reproducibility

  • extra_shape (tuple, optional) – Extra shape to repeatedly sample for each condition value

  • asparams (bool, optional) – If True, return a DiffmahParams tuple instead of a u_params array

  • flow_params (jnp.ndarray, optional) – Flow parameters to use for sampling in place of the object's own (for a functional rather than object-oriented programming style)

Returns:

Sampled unbound params, of shape (n_samples, 5, *extra_shape)

Return type:

jnp.ndarray | DiffmahParams
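
The condition array holds one (m_obs, t_obs) pair per row. The sketch below builds it with NumPy (jax.numpy.column_stack behaves the same way) and computes the output shape documented above. The m_obs and t_obs values are placeholders, and the commented-out call assumes a trained DiffMahFlow instance named flow.

```python
import numpy as np

# Placeholder observed masses and observation times, one entry per condition row
m_obs = np.array([11.5, 12.0, 12.5])
t_obs = np.array([13.8, 13.8, 10.0])

# Stack into the documented (n_samples, 2) layout: column 0 = m_obs, column 1 = t_obs
condition = np.column_stack([m_obs, t_obs])

# Draw 4 samples per condition row
extra_shape = (4,)

# Shape documented for the returned u_params array
expected_shape = (condition.shape[0], 5, *extra_shape)

# With JAX installed and a trained flow, the call would look like:
# uparams = flow.sample(jnp.asarray(condition), randkey=jax.random.PRNGKey(0),
#                       extra_shape=extra_shape)
# assert uparams.shape == expected_shape
```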

save(filename)

Save this model object to an eqx file for future use

Parameters:
  • filename (str) – Filename to save the model to. The “.eqx” extension will be added automatically if not present
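
Both save and load append the “.eqx” extension when it is missing. The hypothetical helper below mirrors that documented rule; the commented lines show the intended round trip, assuming flow is an existing DiffMahFlow instance.

```python
def with_eqx_extension(filename: str) -> str:
    """Hypothetical helper: append ".eqx" unless the name already ends with it."""
    return filename if filename.endswith(".eqx") else filename + ".eqx"


# Round trip with the documented methods (requires diffmahnet and a trained flow):
# flow.save("my_flow")                        # writes my_flow.eqx
# restored = DiffMahFlow.load("my_flow.eqx")  # same file; extension already present
```

Because the extension is only added when absent, "my_flow" and "my_flow.eqx" refer to the same file.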

class diffmahnet.Scaler

diffmahnet.gen_time_grids(key, t_obs, t_min=0.1, n_tgrid=20)