Basic Tutorial
[1]:
import matplotlib.pyplot as plt
import corner
import numpy as np
import jax
import jax.numpy as jnp
import diffmahnet
Generate fake data
(N, 5) array of unbound MAH parameters
(N, 2) array of conditioning variables \(M_{\rm obs}\) and \(t_{\rm obs}\)
[2]:
randkey = jax.random.key(0)
keys = jax.random.split(randkey, 6)
ndata = 10_000
# m_obs and t_obs
fake_conditions = jax.random.normal(keys[0], (ndata, 2)) + 1.5
# Make the fake MAH parameters depend on M_obs and t_obs
def gen_uparams(key, condition):
    # Uniform draws in [3, 4): one row of 5 parameters per condition
    fake_mah_uparams = jax.random.uniform(key, (condition.shape[0], 5)) + 3.0
    # Scale by powers of M_obs and t_obs to introduce a dependence
    fake_mah_uparams = fake_mah_uparams * condition[:, 0:1] ** 2
    fake_mah_uparams = fake_mah_uparams * condition[:, 1:2] ** 3
    return fake_mah_uparams
fake_mah_uparams = gen_uparams(keys[1], fake_conditions)
scaler = diffmahnet.Scaler.compute(fake_mah_uparams, fake_conditions)
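`diffmahnet.Scaler.compute` is built from both the parameters and the conditions, which suggests it records per-column normalization statistics so the flow trains on well-conditioned inputs. The library's internals are not shown in this tutorial; the sketch below assumes plain z-score scaling (subtract the mean, divide by the standard deviation) and uses a hypothetical `ZScaler` class, not diffmahnet's `Scaler`.

```python
import numpy as np


class ZScaler:
    """Hypothetical per-column z-score scaler (illustration only)."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    @classmethod
    def compute(cls, x):
        # Per-column statistics over the N samples (axis 0)
        return cls(x.mean(axis=0), x.std(axis=0))

    def forward(self, x):
        # Map raw values to zero mean, unit variance per column
        return (x - self.mean) / self.std

    def inverse(self, y):
        # Undo the scaling to recover raw values
        return y * self.std + self.mean


rng = np.random.default_rng(0)
uparams = rng.uniform(size=(1000, 5)) * 50.0 + 3.0
scaler = ZScaler.compute(uparams)
scaled = scaler.forward(uparams)
roundtrip = scaler.inverse(scaled)
print(scaled.mean(axis=0), scaled.std(axis=0))
```

Keeping an explicit inverse matters because samples drawn in the scaled space have to be mapped back to the original parameter units.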
Create a small flow model (only 102 parameters) that trains quickly
[3]:
flow = diffmahnet.DiffMahFlow(scaler, nn_depth=1, nn_width=2, flow_layers=2)
flow.get_params().size
[3]:
102
Train the model on the fake data generated above
[4]:
res = flow.init_fit(
fake_mah_uparams, fake_conditions, randkey=keys[2])
90%|█████████ | 45/50 [00:06<00:00, 6.72it/s, train=-10.8, val=-11.5 (Max patience reached)]
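The progress bar stops at 45 of 50 epochs with "Max patience reached", i.e. training ended early because the validation loss stopped improving. The exact rule `init_fit` uses is not stated here; the sketch below assumes the common patience rule (stop after `max_patience` consecutive epochs without a new best validation loss) and uses a hypothetical helper, `early_stopping_epoch`, that replays it on a list of losses where lower is better.

```python
def early_stopping_epoch(val_losses, max_patience):
    """Return (stop_epoch, best_epoch) under a simple patience rule.

    Hypothetical helper: training stops once the validation loss has
    failed to improve for `max_patience` consecutive epochs.
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best: remember it and reset the patience counter
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= max_patience:
                return epoch, best_epoch
    # Ran out of epochs without triggering early stopping
    return len(val_losses) - 1, best_epoch


val = [-5.0, -8.0, -9.0, -9.5, -9.4, -9.3, -9.45, -9.2, -9.1]
print(early_stopping_epoch(val, max_patience=3))  # (6, 3)
```

Many training loops also restore the parameters from `best_epoch` when they stop; whether `init_fit` does so is not covered in this tutorial.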
Optionally, save the trained model and reload it later
[5]:
flow.save("fake_model.eqx")
[6]:
same_flow = diffmahnet.DiffMahFlow.load("fake_model.eqx")
jnp.all(same_flow.get_params() == flow.get_params())
[6]:
Array(True, dtype=bool)
Draw predictive samples from the flow model
[7]:
test_conditions = jax.random.normal(keys[3], (ndata * 10, 2)) + 1.5
test_uparams = gen_uparams(keys[4], test_conditions)
# Generate samples, given the new "test" values of m_obs and t_obs
flow_mah_uparams = flow.sample(test_conditions, keys[5])
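`flow.sample` returns one row of parameters per row of `test_conditions` (the plotting cell below concatenates the two along `axis=1`). Conceptually, a conditional normalizing flow draws base noise \(z \sim \mathcal{N}(0, I)\) and pushes it through an invertible transform whose parameters depend on the condition. The toy below assumes a single affine layer with fixed, made-up linear shift and log-scale functions; it illustrates the idea and is not diffmahnet's architecture, which uses neural networks (`nn_depth`, `nn_width`) and several layers (`flow_layers`).

```python
import numpy as np


def toy_conditional_affine_sample(conditions, rng, ndim=5):
    """Draw samples from a hypothetical one-layer conditional affine flow.

    x = shift(c) + exp(log_scale(c)) * z,  with z ~ N(0, I).
    In a trained flow, shift and log_scale come from neural networks;
    here they are fixed linear functions chosen only for illustration.
    """
    n = conditions.shape[0]
    z = rng.standard_normal((n, ndim))  # base-distribution draws
    m_obs = conditions[:, 0:1]
    t_obs = conditions[:, 1:2]
    shift = 2.0 * m_obs + 0.5 * t_obs  # shape (n, 1), broadcasts over ndim
    log_scale = 0.1 * t_obs  # shape (n, 1)
    return shift + np.exp(log_scale) * z


rng = np.random.default_rng(1)
conditions = rng.standard_normal((1000, 2)) + 1.5
samples = toy_conditional_affine_sample(conditions, rng)
print(samples.shape)  # (1000, 5)
```

Because the transform is invertible, the same structure also gives exact log-likelihoods, which is what a flow is trained to maximize.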
[8]:
# Plot the rough agreement between the test and flow prediction distributions
test_conditions_vs_param1 = np.concatenate(
[test_conditions, test_uparams[:, 0:1]], axis=1)
fig = corner.corner(
test_conditions_vs_param1, labels=["M_obs", "t_obs", "uparam1"],
levels=(0.68, 0.95, 0.997), plot_datapoints=False, color="C0", alpha=0.1)
flow_conditions_vs_param1 = np.concatenate(
[test_conditions, flow_mah_uparams[:, 0:1]], axis=1)
corner.corner(
flow_conditions_vs_param1, labels=["M_obs", "t_obs", "uparam1"], fig=fig,
quantiles=[0.16, 0.5, 0.84], fill_contours=True,
levels=(0.68, 0.95, 0.997), plot_datapoints=False, color="C1", alpha=0.1)
fig.axes[1].text(0, 0.5, "Test data", color="C0")
fig.axes[1].text(0, 0.4, "Flow prediction", color="C1")
plt.show()
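The corner plot gives a visual check. A rough numerical companion is to compare percentiles of the test and flow samples column by column. The helper below, `quantile_offsets`, is hypothetical and only compares marginal distributions, not how they depend on \(M_{\rm obs}\) and \(t_{\rm obs}\); binning by condition would be a natural extension.

```python
import numpy as np


def quantile_offsets(reference, predicted, qs=(16, 50, 84)):
    """Hypothetical helper: per-column percentile differences.

    Returns an array of shape (len(qs), ncols) holding
    percentile(predicted) - percentile(reference); values near zero mean
    the marginal distributions agree at those percentiles.
    """
    ref_q = np.percentile(reference, qs, axis=0)
    pred_q = np.percentile(predicted, qs, axis=0)
    return pred_q - ref_q


rng = np.random.default_rng(2)
reference = rng.normal(0.0, 1.0, size=(20_000, 5))
shifted = reference + 0.5  # deliberately biased "prediction"
offsets = quantile_offsets(reference, shifted)
print(np.round(offsets, 3))
```

With the tutorial's arrays, the call would be `quantile_offsets(np.asarray(test_uparams), np.asarray(flow_mah_uparams))`.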
[9]:
# Pass asparams=True to get DiffmahParams instead of unbound parameters
flow.sample(test_conditions, keys[5], asparams=True)
[9]:
DiffmahParams(
    logm0=Array([ 8.202568, 16.948618, 12.490401, ..., 9.444087, 9.286882, 16.998804], dtype=float32),
    logtc=Array([0.33600986, 0.99974513, 0.7944621 , ..., 0.4520527 , 0.49351144, 0.9999833 ], dtype=float32),
    early_index=Array([7.447408, 9.998593, 9.452143, ..., 8.321765, 8.241098, 9.999733], dtype=float32),
    late_index=Array([3.3145735, 4.9981775, 4.192531 , ..., 3.3211868, 3.4189644, 4.999934 ], dtype=float32),
    t_peak=Array([ 7.9103446, 19.938223 , 13.115052 , ..., 9.360291 , 11.562651 , 19.999243 ], dtype=float32))
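Several printed values sit just below round numbers (for example `logtc` near 1 and `late_index` near 5), which is consistent with each unbound parameter being squashed into a finite interval. A common way to do this is a scaled sigmoid. The sketch below assumes that form, with made-up bounds; the helper names `to_bounded` and `to_unbounded` are hypothetical and not taken from the diffmah source.

```python
import numpy as np


def to_bounded(u, lo, hi):
    """Map an unbounded value u to the open interval (lo, hi) via a sigmoid."""
    return lo + (hi - lo) / (1.0 + np.exp(-u))


def to_unbounded(x, lo, hi):
    """Inverse of to_bounded: the logit of the rescaled value."""
    p = (x - lo) / (hi - lo)
    return np.log(p / (1.0 - p))


# Hypothetical bounds, for illustration only (not diffmah's actual values)
lo, hi = 0.0, 5.0
u = np.array([-10.0, 0.0, 2.0, 10.0])
x = to_bounded(u, lo, hi)
print(x)
```

Under this assumption, large unbound values land close to the upper edge, which would match entries such as 0.9999833. Working in the unbound space lets the flow model parameters on the whole real line without ever producing out-of-range values.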
Try improving the fit by adjusting the flow hyperparameters
Increase the model capacity using:
nn_depth
nn_width
flow_layers
Increase the max_patience and/or max_epochs arguments of the init_fit method
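Comparing a few settings of nn_depth, nn_width and flow_layers is easy to automate. A minimal grid-search sketch, assuming you supply a function that trains a flow for a given configuration and returns a validation loss (lower is better); `grid_search` and `dummy_evaluate` are hypothetical names, and the default grid values are arbitrary.

```python
import itertools


def grid_search(evaluate, nn_depth=(1, 2), nn_width=(2, 8), flow_layers=(2, 4)):
    """Hypothetical sweep returning (best_config, best_score); lower is better.

    `evaluate` is a user-supplied callable taking keyword arguments
    nn_depth, nn_width and flow_layers and returning a validation loss.
    """
    best_config, best_score = None, float("inf")
    for depth, width, layers in itertools.product(nn_depth, nn_width, flow_layers):
        config = dict(nn_depth=depth, nn_width=width, flow_layers=layers)
        score = evaluate(**config)
        if score < best_score:
            best_config, best_score = config, score
    return best_config, best_score


# Stand-in evaluator so the sketch runs on its own. A real one might build
# diffmahnet.DiffMahFlow(scaler, **config), fit it, and report a validation
# loss; how to extract that loss from init_fit's result is not covered here.
def dummy_evaluate(nn_depth, nn_width, flow_layers):
    return abs(nn_depth - 2) + abs(nn_width - 8) + abs(flow_layers - 4)


best_config, best_score = grid_search(dummy_evaluate)
print(best_config, best_score)
```

Each configuration is trained independently, so the cost grows with the product of the grid sizes; keeping the same random keys and validation split across configurations makes the scores easier to compare.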