Basic Tutorial

[1]:
import matplotlib.pyplot as plt
import corner
import numpy as np
import jax
import jax.numpy as jnp

import diffmahnet

Generate fake data

  • (N, 5) array of unbounded MAH parameters ("uparams")

  • (N, 2) array of conditioning variables \(M_{\rm obs}\) and \(t_{\rm obs}\)

[2]:
# Number of fake training samples.
ndata = 10_000

# A single root PRNG key, split into six subkeys that the cells below index
# into for their random draws.
randkey = jax.random.key(0)
keys = jax.random.split(randkey, num=6)

# Fake conditioning variables: column 0 is m_obs, column 1 is t_obs, each a
# unit normal shifted to mean 1.5.
fake_conditions = 1.5 + jax.random.normal(keys[0], shape=(ndata, 2))

# Make the fake MAH uparams depend on the conditioning variables (m_obs, t_obs)
def gen_uparams(key, condition):
    """Generate fake MAH uparams whose values depend on the conditions.

    Each of the 5 parameters is drawn uniformly from [3, 4) and then scaled
    by ``m_obs**2 * t_obs**3``, giving the flow a conditional dependence to
    learn.

    Parameters
    ----------
    key : jax PRNG key
        Key used for the uniform draw.
    condition : array of shape (N, 2)
        Conditioning variables; column 0 is m_obs and column 1 is t_obs.

    Returns
    -------
    array of shape (N, 5)
        Fake unbounded MAH parameters.
    """
    n_samples = condition.shape[0]
    uparams = jax.random.uniform(key, (n_samples, 5)) + 3.0
    # The 0:1 / 1:2 slices keep a trailing length-1 axis so each condition
    # column broadcasts across all 5 parameter columns.
    uparams = uparams * condition[:, 0:1] ** 2
    uparams = uparams * condition[:, 1:2] ** 3
    return uparams

# Training targets: fake uparams generated from the fake conditions.
fake_mah_uparams = gen_uparams(keys[1], fake_conditions)
# Scaler built from the training uparams and conditions (presumably storing
# normalization statistics -- confirm in diffmahnet); it is passed to the
# flow constructor in the next cell.
scaler = diffmahnet.Scaler.compute(fake_mah_uparams, fake_conditions)

Create a small, quick-to-train flow model with only 102 parameters

[3]:
# Deliberately tiny model (nn_depth=1, nn_width=2, flow_layers=2) so the
# tutorial trains fast. The displayed value is the size of the parameter
# array returned by get_params(), i.e. the model's parameter count.
flow = diffmahnet.DiffMahFlow(scaler, nn_depth=1, nn_width=2, flow_layers=2)
flow.get_params().size
[3]:
102

Train the model on the fake data we generated above

[4]:
# Fit the flow to the fake (uparams, conditions) pairs, seeding training
# randomness with keys[2]. The progress output below shows the run stopping
# before the last epoch ("Max patience reached"). `res` holds whatever
# init_fit returns; it is not used further in this notebook.
res = flow.init_fit(
    fake_mah_uparams, fake_conditions, randkey=keys[2])
 90%|█████████ | 45/50 [00:06<00:00,  6.72it/s, train=-10.8, val=-11.5 (Max patience reached)]

Optionally, save the trained model and reload it later

[5]:
# Write the trained flow to disk; the same file name is reloaded below.
flow.save("fake_model.eqx")
[6]:
# Reload the saved flow and check that every parameter round-tripped exactly
# (displays Array(True) on success).
same_flow = diffmahnet.DiffMahFlow.load("fake_model.eqx")
jnp.all(same_flow.get_params() == flow.get_params())
[6]:
Array(True, dtype=bool)

Draw predictive samples from our flow model

[7]:
# Held-out "test" set, 10x larger than the training set: new conditions plus
# the true uparams from the same fake-data generator used for training.
test_conditions = jax.random.normal(keys[3], (ndata * 10, 2)) + 1.5
test_uparams = gen_uparams(keys[4], test_conditions)

# Generate samples, given the new "test" values of m_obs and t_obs.
# The plot below concatenates these column-wise with test_conditions, so this
# is expected to return one sample row per test condition.
flow_mah_uparams = flow.sample(test_conditions, keys[5])
[8]:
# Plot the rough agreement between the test and flow prediction distributions:
# the test set is drawn in C0, and the flow samples are overlaid in C1 with
# filled contours and 16/50/84% quantile lines. Each plot shows both
# conditions and the first uparam.
test_conditions_vs_param1 = np.concatenate(
    [test_conditions, test_uparams[:, 0:1]], axis=1)
fig = corner.corner(
    test_conditions_vs_param1, labels=["M_obs", "t_obs", "uparam1"],
    levels=(0.68, 0.95, 0.997), plot_datapoints=False, color="C0", alpha=0.1)
flow_conditions_vs_param1 = np.concatenate(
    [test_conditions, flow_mah_uparams[:, 0:1]], axis=1)
corner.corner(
    flow_conditions_vs_param1, labels=["M_obs", "t_obs", "uparam1"], fig=fig,
    quantiles=[0.16, 0.5, 0.84], fill_contours=True,
    levels=(0.68, 0.95, 0.997), plot_datapoints=False, color="C1", alpha=0.1)
# Hand-made legend placed in fig.axes[1] (presumably an unused
# upper-triangle panel of the 3x3 corner grid). Use axes coordinates
# (transAxes) so the placement does not depend on that panel's data limits.
legend_ax = fig.axes[1]
legend_ax.text(
    0, 0.5, "Test data", color="C0", transform=legend_ax.transAxes)
legend_ax.text(
    0, 0.4, "Flow prediction", color="C1", transform=legend_ax.transAxes)
plt.show()
../_images/notebooks_intro_13_0.png
[9]:
# Note you can also generate actual DiffmahParams using asparams=True.
# NOTE(review): this reuses keys[5] from the sampling cell above, so assuming
# sampling is deterministic in the key, these should be the same draws as
# flow_mah_uparams, just returned as a DiffmahParams object (see output).
flow.sample(
    test_conditions, keys[5], asparams=True)
[9]:
DiffmahParams(logm0=Array([ 8.202568, 16.948618, 12.490401, ...,  9.444087,  9.286882,
       16.998804], dtype=float32), logtc=Array([0.33600986, 0.99974513, 0.7944621 , ..., 0.4520527 , 0.49351144,
       0.9999833 ], dtype=float32), early_index=Array([7.447408, 9.998593, 9.452143, ..., 8.321765, 8.241098, 9.999733],      dtype=float32), late_index=Array([3.3145735, 4.9981775, 4.192531 , ..., 3.3211868, 3.4189644,
       4.999934 ], dtype=float32), t_peak=Array([ 7.9103446, 19.938223 , 13.115052 , ...,  9.360291 , 11.562651 ,
       19.999243 ], dtype=float32))

Try improving the fit by adjusting the flow hyperparameters

  • Increase the capacity of the flow model using:

    • nn_depth

    • nn_width

    • flow_layers

  • Increase the max_patience and/or max_epochs of the init_fit method