Evrard's Collapse#

Imports#

# %pip install ../

import os

# Restrict this process to the GPUs with indices 4-8 (five devices, not a
# single one). The variable is set before `jax` is imported below so that it
# is already in place when the GPU backend is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7,8"

# numerics
import jax
import jax.numpy as jnp

# timing
from timeit import default_timer as timer

# plotting
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# fluids
from jf1uids import SimulationConfig
from jf1uids import get_helper_data
from jf1uids import SimulationParams
from jf1uids import time_integration
from jf1uids.fluid_equations.fluid import construct_primitive_state
from jf1uids.option_classes.simulation_config import finalize_config

from jf1uids import get_registered_variables

Initialization#

from jf1uids.option_classes.simulation_config import BACKWARDS, FORWARDS, HLL, HLLC, MINMOD, OSHER, PERIODIC_BOUNDARY, REFLECTIVE_BOUNDARY, BoundarySettings, BoundarySettings1D

print("👷 Setting up simulation...")

# Ratio of specific heats used for the initial pressure and the plots below.
# NOTE(review): gamma is not handed to SimulationParams in this cell, so the
# solver presumably falls back to its own default -- confirm they agree.
gamma = 5/3

# Spatial domain: box edge length and number of cells.
box_size = 4.0
num_cells = 64

# Time stepping controls.
fixed_timestep = False
dt_max = 0.001
num_timesteps = 9

# Reflective walls on both sides of each of the three axes; one independent
# 1D setting is created per axis.
reflective_walls = BoundarySettings(*(
    BoundarySettings1D(left_boundary = REFLECTIVE_BOUNDARY, right_boundary = REFLECTIVE_BOUNDARY)
    for _ in range(3)
))

# Static solver configuration.
config = SimulationConfig(
    # geometry
    dimensionality = 3,
    box_size = box_size,
    num_cells = num_cells,
    boundary_settings = reflective_walls,
    # physics modules
    mhd = False,
    self_gravity = True,
    # numerical scheme
    riemann_solver = HLLC,
    limiter = MINMOD,
    first_order_fallback = False,
    # time integration
    fixed_timestep = fixed_timestep,
    num_timesteps = num_timesteps,
    differentiation_mode = FORWARDS,
    # diagnostics
    runtime_debugging = False,
    progress_bar = True,
)

# Run-time parameters: end time, CFL number and the upper bound on the timestep.
params = SimulationParams(
    t_end = 0.8,
    C_cfl = 0.4,
    dt_max = dt_max,
)

# Objects derived from the configuration that the later cells rely on.
helper_data = get_helper_data(config)
registered_variables = get_registered_variables(config)
👷 Setting up simulation...

Setting the initial state#

from jf1uids.fluid_equations.fluid import construct_primitive_state3D

# Radius and mass of the gas sphere, and the density filling the rest of the box.
R = 1.0
M = 1.0
background_density = 1e-4

# NOTE(review): the spacing is computed with (num_cells - 1); confirm this
# matches the cell width the solver itself uses.
dx = config.box_size / (config.num_cells - 1)

# Density profile rho(r) = M / (2 pi R^2 r), evaluated at every cell radius.
sphere_density = M / (2 * jnp.pi * R ** 2 * helper_data.r)

# Cells at least dx / 2 inside the surface take the profile, all others the
# background value.
rho = jnp.where(helper_data.r <= R - dx / 2, sphere_density, background_density)

# Weight that falls linearly from 1 at r = R - dx / 2 to 0 at r = R + dx / 2.
overlap_weights = (R + dx / 2 - helper_data.r) / dx

# Cells within dx / 2 of the surface get a weighted mix of profile and
# background. The weight has to be applied to the profile itself: `rho`
# already holds the background value for these cells, so scaling `rho` would
# push the shell *below* the background density (towards zero at the outer
# edge) instead of tapering the sphere.
rho = jnp.where(
    (helper_data.r > R - dx / 2) & (helper_data.r < R + dx / 2),
    overlap_weights * sphere_density + (1 - overlap_weights) * background_density,
    rho,
)

# Initialize velocity fields to zero
v_x = jnp.zeros_like(rho)
v_y = jnp.zeros_like(rho)
v_z = jnp.zeros_like(rho)

# initial thermal energy per unit mass = 0.05
e = 0.05
p = (gamma - 1) * rho * e

# Construct the initial primitive state for the 3D simulation.
initial_state = construct_primitive_state(
    config = config,
    registered_variables = registered_variables,
    density = rho,
    velocity_x = v_x,
    velocity_y = v_y,
    velocity_z = v_z,
    gas_pressure = p
)

# Finalize the configuration using the shape of the state array just built.
config = finalize_config(config, initial_state.shape)

Simulation#

# Run the time integration from the initial state with the config and params
# defined above (t_end = 0.8). The result is indexed further down as
# final_state[variable_index, :, :, :].
final_state = time_integration(initial_state, config, params, helper_data, registered_variables)
 |████████████████████████████████████████████████████████████████████████████████████████████████████| 100.1% 

Visualization#

Cut#

from matplotlib.colors import LogNorm

# Index window of +-10 cells around the box centre (not used in this cell).
a = num_cells // 2 - 10
b = num_cells // 2 + 10

# Pressure in the plane through the middle of the last spatial axis.
mid_plane = num_cells // 2
pressure_cut = final_state[registered_variables.pressure_index, :, :, mid_plane]

# Transpose so that the first array axis runs along the horizontal plot axis;
# colours are mapped logarithmically.
plt.imshow(
    pressure_cut.T,
    cmap = "jet",
    origin = "lower",
    extent = [0, box_size, 0, box_size],
    norm = LogNorm(),
)
plt.colorbar()
plt.xlabel("x")
plt.ylabel("y")
Text(0, 0.5, 'y')
../../_images/7a5968d56a337458c366c2c263ac762b7cb6f10413ee28621da0680c63886bad.png
# Scatter plots over the cell radius r: density (log-log), velocity (log x)
# and P / rho^gamma (log x), one point per cell.

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))

radius = helper_data.r.flatten()
density = final_state[registered_variables.density_index].flatten()
pressure = final_state[registered_variables.pressure_index].flatten()

# density profile, both axes logarithmic
ax1.scatter(radius, density, label="Final Density", s = 1)
ax1.set_xscale("log")
ax1.set_yscale("log")
ax1.set_xlim(1e-2, 6e-1)
ax1.set_ylim(1e-2, 1e3)
ax1.set_xlabel("r")
ax1.set_ylabel("Density")

# velocity profile
# v_r is the *negative velocity magnitude*, not a projection onto the radial
# direction: it matches the radial velocity only where the gas moves purely
# inwards.
# NOTE(review): outflowing or tangential motion would still be drawn as
# infall; project the velocity onto the radial unit vector if that matters.
v_r = -jnp.sqrt(final_state[registered_variables.velocity_index.x] ** 2 + final_state[registered_variables.velocity_index.y] ** 2 + final_state[registered_variables.velocity_index.z] ** 2)

ax2.scatter(radius, v_r.flatten(), label="Radial Velocity", s = 1)
ax2.set_xscale("log")
ax2.set_xlim(1e-2, 6e-1)
ax2.set_xlabel("r")
ax2.set_ylabel("Velocity")

# plot P / rho^gamma
# (the x-limits are set once, after switching to the log scale; an earlier
# set_xlim call that was immediately overridden has been dropped)
ax3.scatter(radius, pressure / density ** gamma, label="P / rho^gamma", s = 1)
ax3.set_xlabel("r")
ax3.set_ylabel("P / rho^gamma")
ax3.set_xscale("log")
ax3.set_ylim(0, 0.25)
ax3.set_xlim(1e-2, 6e-1)

fig.suptitle("3D Collapse Test")

plt.tight_layout()
../../_images/718cbb4b0a677b45eddcd75b244bc01d0a9afe1abd628f4d8123e675fcf4cf88.png

Conservation properties#

# Second run from the same initial state: extend the end time to 3.0 and
# switch the configuration to snapshot output with 60 snapshots.
params = params._replace(t_end = 3.0)
config = config._replace(
    num_snapshots = 60,
    return_snapshots = True,
)

snapshots = time_integration(initial_state, config, params, helper_data, registered_variables)
 |████████████████████████████████████████████████████████████████████████████████████████████████████| 100.0% 
# Time series recorded alongside the snapshots.
time = snapshots.time_points
total_mass = snapshots.total_mass
total_energy = snapshots.total_energy
internal_energy = snapshots.internal_energy
kinetic_energy = snapshots.kinetic_energy
gravitational_energy = snapshots.gravitational_energy
t_end = 3.0

# Disabled alternative: relative drift of mass and total energy up to t_end.
# plt.plot(time[time < t_end], (total_mass[time < t_end] - total_mass[0]) / total_mass[0], label="Relative Mass Change")
# plt.plot(time[time < t_end], (total_energy[time < t_end] - total_energy[0]) / total_energy[0], label="Relative Energy Change")

# One curve per energy component, drawn in this order: (values, label, colour).
energy_curves = (
    (total_energy, "Total Energy", "black"),
    (internal_energy, "Internal Energy", "green"),
    (kinetic_energy, "Kinetic Energy", "red"),
    (gravitational_energy, "Gravitational Energy", "blue"),
)
for series, name, colour in energy_curves:
    plt.plot(time, series, label=name, color = colour)

# plot the zero line
# plt.axhline(0, color = "black", linestyle = "--")
# plt.ylim(-3, 3)
plt.legend()
<matplotlib.legend.Legend at 0x7f85885563b0>
../../_images/7c64c9b485cee9338a9d4e6be4defa4e69d93c5de2fe35609edca9c615ad8011.png
import matplotlib.animation as animation


def _mid_plane_density(i):
    """Return the transposed mid-plane density slice of snapshot ``i``.

    The density is looked up through ``registered_variables.density_index``
    (as in the profile plots) rather than a hard-coded index.
    """
    return snapshots.states[i][registered_variables.density_index, :, :, num_cells // 2].T


num_frames = len(snapshots.states)

# Colour limits shared by all frames. With a fresh auto-scaled LogNorm per
# frame the colour mapping changed from frame to frame while the colorbar
# (built from frame 0 only) stayed fixed, so the bar was wrong for every later
# frame.
frame_minima = []
frame_maxima = []
for frame_index in range(num_frames):
    frame = _mid_plane_density(frame_index)
    # LogNorm needs a positive lower limit, so non-positive cells are ignored
    # when searching for the minimum.
    frame_minima.append(float(jnp.min(jnp.where(frame > 0, frame, jnp.inf))))
    frame_maxima.append(float(jnp.max(frame)))
density_norm = LogNorm(vmin=min(frame_minima), vmax=max(frame_maxima))

fig, ax = plt.subplots()

def animate(i):
    """Redraw the axes with the mid-plane density of snapshot ``i``.

    Returns a one-element list holding the new image artist, as expected by
    ``FuncAnimation``.
    """
    ax.clear()
    im = ax.imshow(_mid_plane_density(i), cmap="jet", origin="lower", extent=[0, box_size, 0, box_size], norm=density_norm)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(f"Density at time {time[i]:.2f}")
    return [im]

ani = animation.FuncAnimation(fig, animate, frames=num_frames, interval=100, blit=True)
# The colorbar is created once from the first frame; because every frame uses
# the same normalisation it stays valid for the whole animation.
plt.colorbar(ax.imshow(_mid_plane_density(0), cmap="jet", origin="lower", extent=[0, box_size, 0, box_size], norm=density_norm), ax=ax)
# save to gif
ani.save("3d_collapse.gif")
../../_images/6d40369a5f8d6feac80f2e30d5134089cfb9cb872e6d469bae806c820767ccef.png