{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient Visualization for Radial 1D Stellar Wind Simulation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# numerics\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "# # optional CPU fallback (was needed because of an outdated NVIDIA driver)\n",
    "# jax.config.update('jax_platform_name', 'cpu')\n",
    "# # jax.config.update('jax_disable_jit', True)\n",
    "# # 64-bit precision\n",
    "# jax.config.update(\"jax_enable_x64\", True)\n",
    "\n",
    "# debug nans\n",
    "# jax.config.update(\"jax_debug_nans\", True)\n",
    "\n",
    "# timing\n",
    "from timeit import default_timer as timer\n",
    "\n",
    "# plotting\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.gridspec import GridSpec\n",
    "\n",
    "# fluids\n",
    "from jf1uids import WindParams\n",
    "from jf1uids import SimulationConfig\n",
    "from jf1uids import get_helper_data\n",
    "from jf1uids import SimulationParams\n",
    "from jf1uids import time_integration\n",
    "from jf1uids.fluid_equations.fluid import construct_primitive_state\n",
    "\n",
    "\n",
    "from jf1uids import get_registered_variables\n",
    "from jf1uids.option_classes import WindConfig\n",
    "from jf1uids.option_classes.simulation_config import OPEN_BOUNDARY, REFLECTIVE_BOUNDARY, SPHERICAL\n",
    "\n",
    "\n",
    "# units\n",
    "from jf1uids import CodeUnits\n",
    "from astropy import units as u\n",
    "import astropy.constants as c\n",
    "from astropy.constants import m_p\n",
    "\n",
    "# wind-specific\n",
    "from jf1uids._physics_modules._stellar_wind.weaver import Weaver"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initiating the stellar wind simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "👷 Setting up simulation...\n"
     ]
    }
   ],
   "source": [
    "print(\"👷 Setting up simulation...\")\n",
    "\n",
    "# simulation settings\n",
    "gamma = 5/3\n",
    "\n",
    "# spatial domain\n",
    "geometry = SPHERICAL\n",
    "box_size = 1.0\n",
    "num_cells = 401\n",
    "# cell count of the high-resolution reference run\n",
    "num_cells_high_res = 2001\n",
    "\n",
    "# Boundaries are deliberately not configured here: for spherical geometry the\n",
    "# library reports that it automatically sets a reflective left and an open\n",
    "# right boundary (see the messages printed when the configs are finalized).\n",
    "\n",
    "# activate stellar wind\n",
    "stellar_wind = True\n",
    "\n",
    "# low-resolution configuration; this is the run that is differentiated below\n",
    "config = SimulationConfig(\n",
    "    runtime_debugging = True,\n",
    "    geometry = geometry,\n",
    "    box_size = box_size,\n",
    "    num_cells = num_cells,\n",
    "    wind_config = WindConfig(\n",
    "        stellar_wind = stellar_wind,\n",
    "        num_injection_cells = 10,\n",
    "        trace_wind_density = False,\n",
    "    ),\n",
    ")\n",
    "\n",
    "helper_data = get_helper_data(config)\n",
    "\n",
    "registered_variables = get_registered_variables(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# high-resolution reference configuration: same geometry, box and wind\n",
    "# injection as `config`, only the number of cells differs\n",
    "config_high_res = SimulationConfig(\n",
    "    geometry = geometry,\n",
    "    box_size = box_size,\n",
    "    num_cells = num_cells_high_res,\n",
    "    wind_config = WindConfig(\n",
    "        stellar_wind = stellar_wind,\n",
    "        num_injection_cells = 10,\n",
    "    ),\n",
    ")\n",
    "\n",
    "helper_data_high_res = get_helper_data(config_high_res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting the simulation parameters and initial state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "For spherical geometry, only HLL is currently supported.\n",
      "Automatically setting reflective left and open right boundary for spherical geometry.\n",
      "For stellar wind simulations, we need source term aware timesteps, turning on.\n",
      "For spherical geometry, only HLL is currently supported.\n",
      "Automatically setting reflective left and open right 
boundary for spherical geometry.\n",
      "For stellar wind simulations, we need source term aware timesteps, turning on.\n"
     ]
    }
   ],
   "source": [
    "from jf1uids.option_classes.simulation_config import finalize_config\n",
    "\n",
    "\n",
    "# --- code units ---\n",
    "code_length = 3 * u.parsec\n",
    "code_mass = 1e-3 * u.M_sun\n",
    "code_velocity = 1 * u.km / u.s\n",
    "code_units = CodeUnits(code_length, code_mass, code_velocity)\n",
    "\n",
    "# --- time domain ---\n",
    "C_CFL = 0.8\n",
    "t_final = 2.5 * 1e4 * u.yr\n",
    "t_end = t_final.to(code_units.code_time).value\n",
    "dt_max = 0.1 * t_end\n",
    "\n",
    "# --- stellar wind, converted to code units before it enters the parameters ---\n",
    "M_star = 40 * u.M_sun\n",
    "wind_final_velocity = 2000 * u.km / u.s\n",
    "wind_mass_loss_rate = 2.965e-3 / (1e6 * u.yr) * M_star\n",
    "\n",
    "mass_loss_rate_code = wind_mass_loss_rate.to(code_units.code_mass / code_units.code_time).value\n",
    "final_velocity_code = wind_final_velocity.to(code_units.code_velocity).value\n",
    "wind_params = WindParams(\n",
    "    wind_mass_loss_rate = mass_loss_rate_code,\n",
    "    wind_final_velocity = final_velocity_code,\n",
    ")\n",
    "\n",
    "# --- simulation parameters: both resolutions use exactly the same values ---\n",
    "shared_param_kwargs = dict(\n",
    "    C_cfl = C_CFL,\n",
    "    dt_max = dt_max,\n",
    "    gamma = gamma,\n",
    "    t_end = t_end,\n",
    "    wind_params = wind_params,\n",
    ")\n",
    "params = SimulationParams(**shared_param_kwargs)\n",
    "params_high_res = SimulationParams(**shared_param_kwargs)\n",
    "\n",
    "# --- homogeneous initial state ---\n",
    "rho_0 = 2 * c.m_p / u.cm**3\n",
    "p_0 = 3e4 * u.K / u.cm**3 * c.k_B\n",
    "\n",
    "\n",
    "def _uniform_fields(n_cells):\n",
    "    \"\"\"Return homogeneous (density, velocity, pressure) arrays of length n_cells in code units.\n",
    "\n",
    "    Density and pressure are filled with rho_0 and p_0, the velocity is zero.\n",
    "    \"\"\"\n",
    "    density = jnp.ones(n_cells) * rho_0.to(code_units.code_density).value\n",
    "    velocity = jnp.zeros(n_cells)\n",
    "    pressure = jnp.ones(n_cells) * p_0.to(code_units.code_pressure).value\n",
    "    return density, velocity, pressure\n",
    "\n",
    "\n",
    "# low resolution\n",
    "rho_init, u_init, p_init = _uniform_fields(num_cells)\n",
    "initial_state = construct_primitive_state(\n",
    "    config = config,\n",
    "    registered_variables = registered_variables,\n",
    "    density = rho_init,\n",
    "    velocity_x = u_init,\n",
    "    gas_pressure = p_init,\n",
    ")\n",
    "# the config is finalized against the shape of the state built from it\n",
    "config = finalize_config(config, initial_state.shape)\n",
    "\n",
    "# high resolution\n",
    "rho_init_high_res, u_init_high_res, p_init_high_res = _uniform_fields(config_high_res.num_cells)\n",
    "# NOTE(review): registered_variables was built from the low-resolution config;\n",
    "# confirm that both configs register the same variables.\n",
    "initial_state_high_res = construct_primitive_state(\n",
    "    config = config_high_res,\n",
    "    registered_variables = registered_variables,\n",
    "    density = rho_init_high_res,\n",
    "    velocity_x = u_init_high_res,\n",
    "    gas_pressure = p_init_high_res,\n",
    ")\n",
    "config_high_res = finalize_config(config_high_res, initial_state_high_res.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation and Gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dv = 0.1 km / s\n"
     ]
    }
   ],
   "source": [
    "def integrator(velocity):\n",
    "    \"\"\"Return the final state of the low-resolution run for a given terminal wind velocity.\n",
    "\n",
    "    Every parameter is copied from `params`; only the wind's\n",
    "    `wind_final_velocity` (in code units) is replaced by `velocity`, so the\n",
    "    result can be differentiated with respect to that one scalar.\n",
    "    \"\"\"\n",
    "    perturbed_wind = WindParams(\n",
    "        wind_mass_loss_rate=params.wind_params.wind_mass_loss_rate,\n",
    "        wind_final_velocity=velocity,\n",
    "    )\n",
    "    perturbed_params = SimulationParams(\n",
    "        C_cfl=params.C_cfl,\n",
    "        dt_max=params.dt_max,\n",
    "        gamma=params.gamma,\n",
    "        t_end=params.t_end,\n",
    "        wind_params=perturbed_wind,\n",
    "    )\n",
    "    return time_integration(initial_state, config, perturbed_params, helper_data, registered_variables)\n",
    "\n",
    "\n",
    "# reference runs at both resolutions\n",
    "final_state = time_integration(initial_state, config, params, helper_data, registered_variables)\n",
    "final_state_high_res = time_integration(initial_state_high_res, config_high_res, params_high_res, helper_data_high_res, registered_variables)\n",
    "\n",
    "# derivative of the final state w.r.t. the terminal wind velocity via jax.jacfwd\n",
    "v_inf = params.wind_params.wind_final_velocity\n",
    "vel_sens = jax.jacfwd(integrator)(v_inf)\n",
    "\n",
    "# central finite difference with step dv (code units) as a cross-check\n",
    "dv = 0.1\n",
    "print(f\"dv = {(dv * code_units.code_velocity).to(u.km/u.s)}\")\n",
    "state_plus = integrator(v_inf + dv)\n",
    "state_minus = integrator(v_inf - dv)\n",
    "vel_sens_fd = (state_plus - state_minus) / (2 * dv)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name":
"stdout", "output_type": "stream", "text": [ "👷 generating plots\n", "0.00852260137538079 code_length / code_velocity\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_2013104/3031789449.py:149: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
    "def plot_weaver_comparison(axs, final_state, params, helper_data, code_units, rho_0, p_0):\n",
    "    \"\"\"Plot simulated density, pressure and velocity against the Weaver solution.\n",
    "\n",
    "    Each panel (density on axs[0], pressure on axs[1], velocity on axs[2])\n",
    "    shows the profile taken from `final_state`, the Weaver profile evaluated\n",
    "    at `params.t_end`, and the high-resolution run.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    axs : sequence of three matplotlib Axes\n",
    "    final_state : primitive state, indexed with `registered_variables`\n",
    "    params : simulation parameters; `t_end` and `wind_params` are read\n",
    "    helper_data : provides `geometric_centers`, the radii belonging to `final_state`\n",
    "    code_units : used to attach astropy units to the code-unit quantities\n",
    "    rho_0, p_0 : density and pressure passed to `Weaver` (in this notebook the\n",
    "        values of the homogeneous initial state)\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    The high-resolution curves are taken from the notebook globals\n",
    "    `final_state_high_res`, `helper_data_high_res` and `config_high_res`;\n",
    "    `registered_variables` is read from the notebook namespace as well.\n",
    "    \"\"\"\n",
    "    print(\"👷 generating plots\")\n",
    "\n",
    "    # simulated profiles with physical units attached\n",
    "    r = helper_data.geometric_centers * code_units.code_length\n",
    "    rho = final_state[registered_variables.density_index] * code_units.code_density\n",
    "    vel = final_state[registered_variables.velocity_index] * code_units.code_velocity\n",
    "    p = final_state[registered_variables.pressure_index] * code_units.code_pressure\n",
    "\n",
    "    # high-resolution reference profiles\n",
    "    r_high_res = helper_data_high_res.geometric_centers * code_units.code_length\n",
    "    rho_high_res = final_state_high_res[registered_variables.density_index] * code_units.code_density\n",
    "    vel_high_res = final_state_high_res[registered_variables.velocity_index] * code_units.code_velocity\n",
    "    p_high_res = final_state_high_res[registered_variables.pressure_index] * code_units.code_pressure\n",
    "    high_res_label = \"jf1uids, N = {}\".format(config_high_res.num_cells)\n",
    "\n",
    "    # get weaver solution\n",
    "    weaver = Weaver(\n",
    "        params.wind_params.wind_final_velocity * code_units.code_velocity,\n",
    "        params.wind_params.wind_mass_loss_rate * code_units.code_mass / code_units.code_time,\n",
    "        rho_0,\n",
    "        p_0\n",
    "    )\n",
    "    current_time = params.t_end * code_units.code_time\n",
    "    print(current_time)\n",
    "\n",
    "    # radial range over which the Weaver profiles are evaluated\n",
    "    r_min = 0.01 * u.parsec\n",
    "    r_max = 3.5 * u.parsec\n",
    "\n",
    "    # density\n",
    "    r_density_weaver, density_weaver = weaver.get_density_profile(r_min, r_max, current_time)\n",
    "    r_density_weaver = r_density_weaver.to(u.parsec)\n",
    "    density_weaver = (density_weaver / m_p).to(u.cm**-3)\n",
    "\n",
    "    # velocity\n",
    "    r_velocity_weaver, velocity_weaver = weaver.get_velocity_profile(r_min, r_max, current_time)\n",
    "    r_velocity_weaver = r_velocity_weaver.to(u.parsec)\n",
    "    velocity_weaver = velocity_weaver.to(u.km / u.s)\n",
    "\n",
    "    # pressure\n",
    "    r_pressure_weaver, pressure_weaver = weaver.get_pressure_profile(r_min, r_max, current_time)\n",
    "    r_pressure_weaver = r_pressure_weaver.to(u.parsec)\n",
    "    pressure_weaver = (pressure_weaver / c.k_B).to(u.cm**-3 * u.K)\n",
    "\n",
    "    # density panel\n",
    "    axs[0].set_yscale(\"log\")\n",
    "    axs[0].plot(r.to(u.parsec), (rho / m_p).to(u.cm**-3), label=\"jf1uids\")\n",
    "    axs[0].plot(r_density_weaver, density_weaver, \"--\", label=\"Weaver solution\")\n",
    "    axs[0].plot(r_high_res.to(u.parsec), (rho_high_res / m_p).to(u.cm**-3), \"-.\", label=high_res_label)\n",
    "    axs[0].set_title(\"density\")\n",
    "    axs[0].set_ylabel(r\"$\\rho$ in m$_p$ cm$^{-3}$\")\n",
    "    axs[0].set_xlim(0, 3)\n",
    "    axs[0].legend(loc=\"upper left\")\n",
    "\n",
    "    # pressure panel\n",
    "    axs[1].set_yscale(\"log\")\n",
    "    axs[1].plot(r.to(u.parsec), (p / c.k_B).to(u.K / u.cm**3), label=\"jf1uids\")\n",
    "    axs[1].plot(r_pressure_weaver, pressure_weaver, \"--\", label=\"Weaver solution\")\n",
    "    axs[1].plot(r_high_res.to(u.parsec), (p_high_res / c.k_B).to(u.K / u.cm**3), \"-.\", label=high_res_label)\n",
    "    axs[1].set_title(\"pressure\")\n",
    "    axs[1].set_ylabel(r\"$p$/k$_b$ in K cm$^{-3}$\")\n",
    "    axs[1].set_xlim(0, 3)\n",
    "    axs[1].legend(loc=\"upper left\")\n",
    "\n",
    "    # velocity panel\n",
    "    axs[2].set_yscale(\"log\")\n",
    "    axs[2].plot(r.to(u.parsec), vel.to(u.km / u.s), label=\"jf1uids\")\n",
    "    axs[2].plot(r_velocity_weaver, velocity_weaver, \"--\", label=\"Weaver solution\")\n",
    "    axs[2].plot(r_high_res.to(u.parsec), vel_high_res.to(u.km / u.s), \"-.\", label=high_res_label)\n",
    "    axs[2].set_title(\"velocity\")\n",
    "    # y range 1 to 1e4 km/s\n",
    "    axs[2].set_ylim(1, 1e4)\n",
    "    axs[2].set_xlim(0, 3)\n",
    "    axs[2].set_ylabel(\"v in km/s\")\n",
    "    # the legend goes to the upper right here, unlike the other two panels\n",
    "    axs[2].legend(loc=\"upper right\")\n",
    "\n",
    "    # hide the x ticks of these panels (in this notebook the row below\n",
    "    # carries the radial axis labels)\n",
    "    for ax in (axs[0], axs[1], axs[2]):\n",
    "        ax.set_xticks([])\n",
    "\n",
    "\n",
    "def sensitivity_plot(axs, vel_sens, vel_sens_fd):\n",
    "    \"\"\"Plot the sensitivity of the final state to the terminal wind velocity.\n",
    "\n",
    "    For density (axs[0]), pressure (axs[1]) and velocity (axs[2]) the\n",
    "    derivative `vel_sens` is drawn with the default line style and the\n",
    "    finite-difference estimate `vel_sens_fd` as a dashed line, both on a\n",
    "    symlog y axis over the radius in parsec.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    axs : sequence of three matplotlib Axes\n",
    "    vel_sens : derivative of the final state with respect to the wind\n",
    "        velocity, indexed with `registered_variables`\n",
    "    vel_sens_fd : finite-difference counterpart, indexed the same way\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    `helper_data`, `code_units` and `registered_variables` are read from the\n",
    "    notebook namespace.\n",
    "    \"\"\"\n",
    "    r = (helper_data.geometric_centers * code_units.code_length).to(u.parsec)\n",
    "\n",
    "    # one entry per panel:\n",
    "    # (state index, derivative symbol, legend location, MaxNLocator bin count)\n",
    "    panels = [\n",
    "        (registered_variables.density_index, r\"d$\\rho$/dv$_\\infty$\", \"upper left\", 3),\n",
    "        (registered_variables.pressure_index, r\"dp/dv$_\\infty$\", \"lower right\", 6),\n",
    "        (registered_variables.velocity_index, r\"dv/dv$_\\infty$\", \"upper right\", 3),\n",
    "    ]\n",
    "\n",
    "    for ax, (index, symbol, legend_loc, max_bins) in zip(axs, panels):\n",
    "        ax.plot(r, vel_sens[index], label=symbol + \" autodiff\")\n",
    "        ax.plot(r, vel_sens_fd[index], \"--\", label=symbol + \" finite diff.\")\n",
    "        ax.set_ylabel(symbol)\n",
    "        ax.legend(loc = legend_loc)\n",
    "        ax.tick_params(axis='y')\n",
    "        ax.set_yscale('symlog')\n",
    "        ax.set_xlim(0, 3)\n",
    "        ax.set_xlabel(\"r in pc\")\n",
    "        ax.yaxis.set_label_coords(-0.15, 0.5)\n",
    "        # set after set_yscale so the symlog scale keeps this locator\n",
    "        ax.yaxis.set_major_locator(plt.MaxNLocator(max_bins))\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(14, 4.5))\n",
    "\n",
    "# the spacing between the panels is set explicitly on the GridSpec\n",
    "gs = GridSpec(2, 3, height_ratios=[3, 2], figure=fig, hspace=0.1, wspace=0.3)\n",
    "\n",
    "axs_upper = [fig.add_subplot(gs[0, i]) for i in range(3)]\n",
    "axs_lower = [fig.add_subplot(gs[1, i]) for i in range(3)]\n",
    "\n",
    "plot_weaver_comparison(axs_upper, final_state, params, helper_data, code_units, rho_0, p_0)\n",
    "sensitivity_plot(axs_lower, vel_sens, vel_sens_fd)\n",
    "\n",
    "# plt.tight_layout() is intentionally not called: a GridSpec with its own\n",
    "# hspace/wspace makes these Axes incompatible with tight_layout, so the call\n",
    "# only produced a UserWarning. bbox_inches=\"tight\" trims the saved figure.\n",
    "plt.savefig(\"../figures/gradients_through_stellar_wind.pdf\", bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "f1uids",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}