{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "625d51b7",
   "metadata": {},
   "source": [
    "# Loading pre-trained models\n",
    "\n",
    "Load the pre-trained central and satellite models provided by `diffmahnet`, sample diffmah parameters conditioned on $(\\log M_{\\rm obs}, t_{\\rm obs})$, and plot the resulting mass accretion histories (MAHs)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "376b9349",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import jax\n",
    "from jax import numpy as jnp\n",
    "\n",
    "import diffmahnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0153a552",
   "metadata": {},
   "outputs": [],
   "source": [
    "# View available pre-trained models\n",
    "available_names = diffmahnet.pretrained_model_names\n",
    "available_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b902f5a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "centrals_model = diffmahnet.load_pretrained_model(\"cenflow_v2_0.eqx\")\n",
    "satellites_model = diffmahnet.load_pretrained_model(\"satflow_v2_0.eqx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3788ca0b",
   "metadata": {},
   "source": [
    "## Generate diffmah parameters\n",
    "- Condition on a few values of $u = (\\log M_{\\rm obs}, t_{\\rm obs})$\n",
    "- For each conditional value, generate a sample of 1000 MAHs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "451d3529",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_sample = 1000  # number of MAHs drawn per (m, t) grid point\n",
    "m_grid = jnp.array([11.0, 12.5, 14.0])\n",
    "t_grid = jnp.array([2.0, 7.5, 13.0])\n",
    "# Flatten the 3x3 grid and repeat each point n_sample times, so halos sharing\n",
    "# a conditioning value occupy a contiguous block of n_sample entries\n",
    "m_vals, t_vals = [jnp.repeat(x.flatten(), n_sample)\n",
    "                  for x in jnp.meshgrid(m_grid, t_grid)]\n",
    "print(m_vals)\n",
    "print(t_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc7539e7",
   "metadata": {},
   "source": [
    "## Create functions roughly equivalent to `mc_diffmah_*pop()` for our trained models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71af076c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note a few differences from mc_diffmah_cenpop:\n",
    "# - Only returns a single set of DiffmahParams\n",
    "# - Does not depend on lgt0 or t_peak\n",
    "mc_diffmahnet_cenpop = centrals_model.make_mc_diffmahnet()\n",
    "mc_diffmahnet_satpop = satellites_model.make_mc_diffmahnet()\n",
    "\n",
    "# Split the seed key so each population is sampled with its own key\n",
    "randkey = jax.random.key(0)\n",
    "keys = jax.random.split(randkey, 2)\n",
    "cenflow_diffmahparams = mc_diffmahnet_cenpop(\n",
    "    centrals_model.get_params(), m_vals, t_vals, keys[0])\n",
    "satflow_diffmahparams = mc_diffmahnet_satpop(\n",
    "    satellites_model.get_params(), m_vals, t_vals, keys[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26574fc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate mass accretion histories from the predicted diffmah parameters\n",
    "lgt0 = np.log10(13.8)  # third argument of log_mah_kern, shared by both populations\n",
    "# One row of 100 times per halo, running from 0.5 up to that halo's t_obs\n",
    "tgrid = jnp.linspace(0.5, t_vals, 100).T\n",
    "cen_mah = diffmahnet.log_mah_kern(cenflow_diffmahparams, tgrid, lgt0)\n",
    "sat_mah = diffmahnet.log_mah_kern(satflow_diffmahparams, tgrid, lgt0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c15ce3cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the MAH of roughly 5 halos per set of {M_obs, t_obs}\n",
    "# (every 200th halo when n_sample = 1000)\n",
    "n_shown_per_set = 5\n",
    "stride = max(1, n_sample // n_shown_per_set)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# Empty lines act as legend handles so each population is listed once\n",
    "ax.plot([], [], label=\"centrals\", color=\"C0\")\n",
    "ax.plot([], [], label=\"satellites\", color=\"C1\")\n",
    "ax.plot(tgrid[::stride].T, cen_mah[::stride].T, color=\"C0\", alpha=0.5)\n",
    "ax.plot(tgrid[::stride].T, sat_mah[::stride].T, color=\"C1\", alpha=0.5)\n",
    "ax.legend(frameon=False)\n",
    "# Raw strings keep the LaTeX backslashes; without the r prefix, \\r in \\rm\n",
    "# is parsed as a carriage return\n",
    "ax.set_xlabel(r\"$\\rm t \\; [Gyr]$\")\n",
    "ax.set_ylabel(r\"$\\rm \\log(M_h(t)/M_\\odot)$\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "main312",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}