{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6a70a944",
   "metadata": {},
   "source": [
    "# Basic Tutorial\n",
    "\n",
    "Generate fake data, train a small `DiffMahFlow` on it, save and reload the model, and compare samples drawn from the flow against held-out test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19c5d040",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import corner\n",
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import diffmahnet"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14c40021",
   "metadata": {},
   "source": [
    "## Generate fake data\n",
    "- (N, 5) array of MAH unbound parameters\n",
    "- (N, 2) array of conditional variables $M_{\\rm obs}$ and $t_{\\rm obs}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f3875b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "randkey = jax.random.key(0)\n",
    "keys = jax.random.split(randkey, 6)\n",
    "\n",
    "ndata = 10_000\n",
    "\n",
    "# Conditional variables: column 0 is m_obs, column 1 is t_obs\n",
    "fake_conditions = jax.random.normal(keys[0], (ndata, 2)) + 1.5\n",
    "\n",
    "\n",
    "def gen_uparams(key, condition):\n",
    "    \"\"\"Draw fake unbound MAH parameters that depend on the conditions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    key : jax PRNG key\n",
    "        Key consumed by the uniform base draw.\n",
    "    condition : array of shape (N, 2)\n",
    "        Conditional variables; column 0 is m_obs and column 1 is t_obs.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    array of shape (N, 5)\n",
    "        Uniform draws from [3, 4), scaled by m_obs**2 * t_obs**3.\n",
    "    \"\"\"\n",
    "    base_draw = jax.random.uniform(key, (condition.shape[0], 5)) + 3.0\n",
    "    # Make the MAH parameters depend on m_obs and t_obs. The slices keep a\n",
    "    # trailing axis of length 1 so they broadcast across all 5 parameters.\n",
    "    m_obs, t_obs = condition[:, 0:1], condition[:, 1:2]\n",
    "    return base_draw * m_obs**2 * t_obs**3\n",
    "\n",
    "\n",
    "fake_mah_uparams = gen_uparams(keys[1], fake_conditions)\n",
    "scaler = diffmahnet.Scaler.compute(fake_mah_uparams, fake_conditions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8121f26f",
   "metadata": {},
   "source": [
    "## Create a very quick, small flow model with only 102 parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95a67fad",
   "metadata": {},
   "outputs": [],
   "source": [
    "flow = diffmahnet.DiffMahFlow(scaler, nn_depth=1, nn_width=2, flow_layers=2)\n",
    "flow.get_params().size"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "674f764e",
   "metadata": {},
   "source": [
    "## Train the model on the fake data we generated above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd26c996",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = flow.init_fit(\n",
    "    fake_mah_uparams, fake_conditions, randkey=keys[2])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee98bba6",
   "metadata": {},
   "source": [
    "## Optionally, save the trained model and reload it later"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb6170cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "flow.save(\"fake_model.eqx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "834a6fb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved model and compare its parameters with the original flow\n",
    "same_flow = diffmahnet.DiffMahFlow.load(\"fake_model.eqx\")\n",
    "jnp.all(same_flow.get_params() == flow.get_params())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42f7a9f4",
   "metadata": {},
   "source": [
    "## Make predictive samples from our flow model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56e345aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_conditions = jax.random.normal(keys[3], (ndata * 10, 2)) + 1.5\n",
    "test_uparams = gen_uparams(keys[4], test_conditions)\n",
    "\n",
    "# Generate samples, given the new \"test\" values of m_obs and t_obs\n",
    "flow_mah_uparams = flow.sample(test_conditions, keys[5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afb3f48b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the rough agreement between the test and flow prediction distributions\n",
    "\n",
    "# Settings shared by both overlaid corner plots, kept in one place so the\n",
    "# two calls cannot drift apart\n",
    "corner_kwargs = dict(\n",
    "    labels=[\"M_obs\", \"t_obs\", \"uparam1\"], levels=(0.68, 0.95, 0.997),\n",
    "    plot_datapoints=False, alpha=0.1)\n",
    "\n",
    "test_conditions_vs_param1 = np.concatenate(\n",
    "    [test_conditions, test_uparams[:, 0:1]], axis=1)\n",
    "fig = corner.corner(test_conditions_vs_param1, color=\"C0\", **corner_kwargs)\n",
    "\n",
    "# Overlay the flow samples on the same figure\n",
    "flow_conditions_vs_param1 = np.concatenate(\n",
    "    [test_conditions, flow_mah_uparams[:, 0:1]], axis=1)\n",
    "corner.corner(\n",
    "    flow_conditions_vs_param1, fig=fig, color=\"C1\",\n",
    "    quantiles=[0.16, 0.5, 0.84], fill_contours=True, **corner_kwargs)\n",
    "\n",
    "# Color-coded labels identifying the two overlaid distributions\n",
    "fig.axes[1].text(0, 0.5, \"Test data\", color=\"C0\")\n",
    "fig.axes[1].text(0, 0.4, \"Flow prediction\", color=\"C1\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cd7b55e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note you can also generate actual DiffmahParams using asparams=True\n",
    "flow.sample(\n",
    "    test_conditions, keys[5], asparams=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "584ae4db",
   "metadata": {},
   "source": [
    "## Try improving the fit by adjusting the flow hyperparameters\n",
    "- Increase the size of the neural network using:\n",
    "  - nn_depth\n",
    "  - nn_width\n",
    "  - flow_layers\n",
    "- Increase the max_patience and/or max_epochs of the `init_fit` method"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "main312",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}