Source code for jf1uids._physics_modules._cosmic_rays.cr_fluid_equations

import jax.numpy as jnp
import jax
from functools import partial

from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker

from typing import Union

from jf1uids.fluid_equations.registered_variables import RegisteredVariables


# TODO: make 2D and 3D ready
@jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=['registered_variables'])
def total_energy_from_primitives_with_crs(
    primitive_state: Float[Array, "num_vars num_cells"],
    registered_variables: RegisteredVariables,
    gamma_cr: Union[float, Float[Array, ""]] = 4 / 3,
    gamma_gas: Union[float, Float[Array, ""]] = 5 / 3,
) -> Float[Array, "num_cells"]:
    """Compute the total (gas + cosmic ray) energy density from primitives.

    Conventions relied upon for the primitive state:

    * the row at ``registered_variables.pressure_index`` holds the *total*
      pressure (thermal gas + cosmic rays);
    * the row at ``registered_variables.cosmic_ray_n_index`` holds the
      variable ``n`` from which the cosmic ray pressure is obtained as
      ``P_cr = n ** gamma_cr``.

    Args:
        primitive_state: Primitive variables, one row per registered variable.
        registered_variables: Provides the row indices of density, velocity,
            pressure and the cosmic ray variable in ``primitive_state``.
        gamma_cr: Adiabatic index of the cosmic ray fluid. Defaults to 4/3,
            the value that used to be hard-coded here.
        gamma_gas: Adiabatic index of the thermal gas. Defaults to 5/3,
            the value that used to be hard-coded here.

    Returns:
        Per-cell total energy density
        ``P_gas / (gamma_gas - 1) + 0.5 * rho * v**2 + P_cr / (gamma_cr - 1)``
        with ``P_gas = P_total - P_cr``.
    """
    # TODO: have callers pass the adiabatic indices from the simulation params.
    rho_gas = primitive_state[registered_variables.density_index]
    velocity = primitive_state[registered_variables.velocity_index]
    total_pressure = primitive_state[registered_variables.pressure_index]

    # Cosmic ray pressure and energy density from the CR variable n.
    cosmic_ray_pressure = (
        primitive_state[registered_variables.cosmic_ray_n_index] ** gamma_cr
    )
    cosmic_ray_energy = cosmic_ray_pressure / (gamma_cr - 1)

    # The pressure row stores the total pressure; strip the CR part off
    # to obtain the thermal gas pressure.
    gas_pressure = total_pressure - cosmic_ray_pressure

    # Only a single velocity component enters the kinetic term, i.e. this
    # is 1D-only (see TODO above).
    gas_energy = gas_pressure / (gamma_gas - 1) + 0.5 * rho_gas * velocity**2

    return gas_energy + cosmic_ray_energy
@jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=['registered_variables'])
def gas_pressure_from_primitives_with_crs(
    primitive_state: Float[Array, "num_vars num_cells"],
    registered_variables: RegisteredVariables,
    gamma_cr: Union[float, Float[Array, ""]] = 4 / 3,
) -> Float[Array, "num_cells"]:
    """Extract the thermal gas pressure from a primitive state with cosmic rays.

    The primitive state stores the *total* pressure (gas + cosmic rays) at
    ``registered_variables.pressure_index``. The cosmic ray pressure is
    ``P_cr = n ** gamma_cr``, with ``n`` read from
    ``registered_variables.cosmic_ray_n_index``. This function returns
    ``P_total - P_cr``.

    Args:
        primitive_state: Primitive variables, one row per registered variable.
        registered_variables: Provides the row indices of the pressure and
            the cosmic ray variable in ``primitive_state``.
        gamma_cr: Adiabatic index of the cosmic ray fluid. Defaults to 4/3,
            the value that used to be hard-coded here.

    Returns:
        Per-cell thermal gas pressure.
    """
    # TODO: have callers pass gamma_cr from the simulation configs.
    # (The gas adiabatic index is not needed for this quantity.)
    cosmic_ray_pressure = (
        primitive_state[registered_variables.cosmic_ray_n_index] ** gamma_cr
    )
    return primitive_state[registered_variables.pressure_index] - cosmic_ray_pressure
# TODO: make 2D and 3D ready
@jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=['registered_variables'])
def total_pressure_from_conserved_with_crs(
    conserved_state: Float[Array, "num_vars num_cells"],
    registered_variables: RegisteredVariables,
    gamma_cr: Union[float, Float[Array, ""]] = 4 / 3,
    gamma_gas: Union[float, Float[Array, ""]] = 5 / 3,
) -> Float[Array, "num_cells"]:
    """Compute the total (gas + cosmic ray) pressure from conserved variables.

    Conventions relied upon for the conserved state:

    * the row at ``registered_variables.pressure_index`` holds the *total*
      energy density (thermal + kinetic + cosmic ray);
    * the row at ``registered_variables.velocity_index`` holds the momentum
      density ``rho * v``;
    * the row at ``registered_variables.cosmic_ray_n_index`` holds the
      variable ``n`` with ``P_cr = n ** gamma_cr``.

    Args:
        conserved_state: Conserved variables, one row per registered variable.
        registered_variables: Provides the row indices of density, momentum,
            energy and the cosmic ray variable in ``conserved_state``.
        gamma_cr: Adiabatic index of the cosmic ray fluid. Defaults to 4/3,
            the value that used to be hard-coded here.
        gamma_gas: Adiabatic index of the thermal gas. Defaults to 5/3,
            the value that used to be hard-coded here.

    Returns:
        Per-cell total pressure ``P_cr + P_gas``.
    """
    # TODO: have callers pass the adiabatic indices from the simulation configs.
    rho_gas = conserved_state[registered_variables.density_index]
    momentum = conserved_state[registered_variables.velocity_index]
    total_energy = conserved_state[registered_variables.pressure_index]

    # Cosmic ray pressure and energy density from the CR variable n.
    cosmic_ray_pressure = (
        conserved_state[registered_variables.cosmic_ray_n_index] ** gamma_cr
    )
    cosmic_ray_energy = cosmic_ray_pressure / (gamma_cr - 1)

    # Remove the CR energy from the total to get the gas (thermal + kinetic)
    # energy, then subtract the kinetic part. Only one velocity component
    # is used, i.e. this is 1D-only (see TODO above).
    gas_energy = total_energy - cosmic_ray_energy
    velocity = momentum / rho_gas
    gas_pressure = (gas_energy - 0.5 * rho_gas * velocity**2) * (gamma_gas - 1)

    return cosmic_ray_pressure + gas_pressure