# general
import jax
import jax.numpy as jnp
from functools import partial
# type checking
from beartype import beartype as typechecker
from jaxtyping import Array, Float, jaxtyped
from typing import Union
# jf1uids containers
from jf1uids.fluid_equations.registered_variables import RegisteredVariables
from jf1uids.option_classes.simulation_config import STATE_TYPE, SimulationConfig
from jf1uids.data_classes.simulation_helper_data import HelperData
# jf1uids functions
from jf1uids._physics_modules._self_gravity._self_gravity import _compute_gravitational_potential
from jf1uids.fluid_equations.fluid import get_absolute_velocity, total_energy_from_primitives
@jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=['config', 'registered_variables'])
def calculate_internal_energy(state, helper_data, gamma, config, registered_variables):
    """
    Calculate the total internal energy in the domain.

    The internal energy density p / (gamma - 1) is integrated over the
    interior cells only. Ghost cells are sliced off first, in the same way
    as in calculate_total_energy.

    Args:
        state: The primitive state array. It is expected to include ghost
            cells, as in calculate_total_energy.
        helper_data: The helper data (its cell volumes are used in 1D).
        gamma: The adiabatic index.
        config: The simulation configuration.
        registered_variables: The registered variables (pressure index).

    Returns:
        The total internal energy.
    """
    num_ghost_cells = config.num_ghost_cells
    # `-0 or None` evaluates to None, so a configuration without ghost cells
    # keeps the whole axis instead of producing the empty slice [0:0].
    interior = slice(num_ghost_cells, -num_ghost_cells or None)

    p = state[registered_variables.pressure_index]
    internal_energy = p / (gamma - 1)
    # Integrate only over the physical domain; in 1D this also makes the
    # shape match the sliced cell volumes below.
    internal_energy = internal_energy[(interior,) * config.dimensionality]

    if config.dimensionality == 1:
        return jnp.sum(internal_energy * helper_data.cell_volumes[interior])
    return jnp.sum(internal_energy * config.grid_spacing**config.dimensionality)
@jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=['config', 'registered_variables'])
def calculate_kinetic_energy(state, helper_data, config, registered_variables):
    """
    Calculate the total kinetic energy in the domain.

    The kinetic energy density 0.5 * rho * |u|**2 is integrated over the
    interior cells only. Here |u| is the absolute velocity returned by
    get_absolute_velocity. Ghost cells are sliced off first, in the same
    way as in calculate_total_energy.

    Args:
        state: The primitive state array. It is expected to include ghost
            cells, as in calculate_total_energy.
        helper_data: The helper data (its cell volumes are used in 1D).
        config: The simulation configuration.
        registered_variables: The registered variables (density index).

    Returns:
        The total kinetic energy.
    """
    num_ghost_cells = config.num_ghost_cells
    # `-0 or None` evaluates to None, so a configuration without ghost cells
    # keeps the whole axis instead of producing the empty slice [0:0].
    interior = slice(num_ghost_cells, -num_ghost_cells or None)

    rho = state[registered_variables.density_index]
    u = get_absolute_velocity(state, config, registered_variables)
    kinetic_energy = 0.5 * rho * u ** 2
    # Integrate only over the physical domain; in 1D this also makes the
    # shape match the sliced cell volumes below.
    kinetic_energy = kinetic_energy[(interior,) * config.dimensionality]

    if config.dimensionality == 1:
        return jnp.sum(kinetic_energy * helper_data.cell_volumes[interior])
    return jnp.sum(kinetic_energy * config.grid_spacing**config.dimensionality)
@jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=['config', 'registered_variables'])
def calculate_gravitational_energy(state, helper_data, gravitational_constant, config, registered_variables):
    """
    Calculate the total gravitational energy in the domain.

    The gravitational potential is computed from the full density field.
    The energy density 0.5 * rho * potential is then integrated over the
    interior cells only, matching the self-gravity term in
    calculate_total_energy.

    Args:
        state: The primitive state array. It is expected to include ghost
            cells, as in calculate_total_energy.
        helper_data: The helper data (its cell volumes are used in 1D).
        gravitational_constant: The gravitational constant passed to the
            potential solver.
        config: The simulation configuration.
        registered_variables: The registered variables (density index).

    Returns:
        The total gravitational energy.
    """
    num_ghost_cells = config.num_ghost_cells
    # `-0 or None` evaluates to None, so a configuration without ghost cells
    # keeps the whole axis instead of producing the empty slice [0:0].
    interior = slice(num_ghost_cells, -num_ghost_cells or None)

    rho = state[registered_variables.density_index]
    potential = _compute_gravitational_potential(rho, config.grid_spacing, config, gravitational_constant)
    gravitational_energy = 0.5 * rho * potential
    # Integrate only over the physical domain; in 1D this also makes the
    # shape match the sliced cell volumes below.
    gravitational_energy = gravitational_energy[(interior,) * config.dimensionality]

    if config.dimensionality == 1:
        return jnp.sum(gravitational_energy * helper_data.cell_volumes[interior])
    return jnp.sum(gravitational_energy * config.grid_spacing**config.dimensionality)
@jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=['config', 'registered_variables'])
def calculate_total_energy(
    primitive_state: STATE_TYPE,
    helper_data: HelperData,
    gamma: Union[float, Float[Array, ""]],
    gravitational_constant: Union[float, Float[Array, ""]],
    config: SimulationConfig,
    registered_variables: RegisteredVariables
) -> Float[Array, ""]:
    """
    Calculate the total energy in the domain.

    The energy density from total_energy_from_primitives is integrated over
    the interior cells, with ghost cells sliced off. When config.self_gravity
    is set, the gravitational energy density 0.5 * rho * potential is added
    before integrating.

    Args:
        primitive_state: The primitive state array, including ghost cells.
        helper_data: The helper data (its cell volumes are used in 1D).
        gamma: The adiabatic index.
        gravitational_constant: The gravitational constant; only used when
            config.self_gravity is enabled.
        config: The simulation configuration (number of ghost cells,
            dimensionality, grid spacing, self-gravity switch).
        registered_variables: The registered variables (field indices).

    Returns:
        The total energy.
    """
    num_ghost_cells = config.num_ghost_cells
    # `-0 or None` evaluates to None, so a configuration without ghost cells
    # keeps the whole axis instead of producing the empty slice [0:0].
    interior = slice(num_ghost_cells, -num_ghost_cells or None)

    rho = primitive_state[registered_variables.density_index]
    u = get_absolute_velocity(primitive_state, config, registered_variables)
    p = primitive_state[registered_variables.pressure_index]
    energy = total_energy_from_primitives(rho, u, p, gamma)

    if config.self_gravity:
        # The potential is computed from the full density field (ghost cells
        # included); only the interior is integrated below.
        potential = _compute_gravitational_potential(rho, config.grid_spacing, config, gravitational_constant)
        energy = energy + 0.5 * rho * potential

    energy = energy[(interior,) * config.dimensionality]

    if config.dimensionality == 1:
        return jnp.sum(energy * helper_data.cell_volumes[interior])
    return jnp.sum(energy * config.grid_spacing**config.dimensionality)
@jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=['config'])
def calculate_total_mass(
    primitive_state: STATE_TYPE,
    helper_data: HelperData,
    config: SimulationConfig,
) -> Float[Array, ""]:
    """
    Calculate the total mass in the domain.

    The density is integrated over the interior cells, with ghost cells
    sliced off. In 1D each cell is weighted by its volume from helper_data.
    In higher dimensions each cell is weighted by grid_spacing**dimensionality,
    the same cell volume used by the energy functions.

    Args:
        primitive_state: The primitive state array, including ghost cells.
        helper_data: The helper data (its cell volumes are used in 1D).
        config: The simulation configuration.

    Returns:
        The total mass.
    """
    num_ghost_cells = config.num_ghost_cells
    # `-0 or None` evaluates to None, so a configuration without ghost cells
    # keeps the whole axis instead of producing the empty slice [0:0].
    interior = slice(num_ghost_cells, -num_ghost_cells or None)

    # NOTE(review): the density is read from field index 0 because this
    # function receives no registered_variables; presumably index 0 is the
    # density slot - confirm against RegisteredVariables.density_index.
    density = primitive_state[(0,) + (interior,) * config.dimensionality]

    if config.dimensionality == 1:
        return jnp.sum(density * helper_data.cell_volumes[interior])
    # Mass is sum(rho * dV) with dV the volume of one cell. Multiplying by
    # box_size**dimensionality (the volume of the whole box) would overcount
    # by roughly the number of cells.
    return jnp.sum(density) * config.grid_spacing**config.dimensionality