Source code for jf1uids.initial_condition_generation.construct_primitive_state
from jf1uids.fluid_equations.registered_variables import RegisteredVariables
from jf1uids.option_classes.simulation_config import FIELD_TYPE, STATE_TYPE, SimulationConfig
from typing import Union
import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import jaxtyped
from functools import partial
from types import NoneType
# @jaxtyped(typechecker=typechecker)
[docs]
@partial(jax.jit, static_argnames=["registered_variables", "config", "sharding"])
def construct_primitive_state(
    config: SimulationConfig,
    registered_variables: RegisteredVariables,
    density: FIELD_TYPE,
    velocity_x: Union[FIELD_TYPE, NoneType] = None,
    velocity_y: Union[FIELD_TYPE, NoneType] = None,
    velocity_z: Union[FIELD_TYPE, NoneType] = None,
    magnetic_field_x: Union[FIELD_TYPE, NoneType] = None,
    magnetic_field_y: Union[FIELD_TYPE, NoneType] = None,
    magnetic_field_z: Union[FIELD_TYPE, NoneType] = None,
    gas_pressure: Union[FIELD_TYPE, NoneType] = None,
    cosmic_ray_pressure: Union[FIELD_TYPE, NoneType] = None,
    sharding=None,
) -> STATE_TYPE:
    r"""Stack the primitive variables into the state array.

    Which components are used depends on ``config.dimensionality``:

    * velocity: 1D uses only ``velocity_x``; 2D uses the x- and
      y-components; 3D uses all three.
    * magnetic field (only when ``config.mhd``): 1D uses only
      ``magnetic_field_x``; 2D and 3D use all three components.

    Components that are not used for the configured dimensionality are
    ignored. The state array is zero-initialised, so any optional field
    left as ``None`` is zero in the returned state (e.g. omitting
    ``velocity_y`` gives a zero y-velocity).

    Args:
        config: The simulation configuration.
        registered_variables: The indices of the variables in the state array.
        density: The density of the fluid; its shape sets the spatial shape
            of the state array.
        velocity_x: The x-component of the velocity of the fluid.
        velocity_y: The y-component of the velocity of the fluid.
        velocity_z: The z-component of the velocity of the fluid.
        magnetic_field_x: The x-component of the magnetic field in B / sqrt(\mu_0).
        magnetic_field_y: The y-component of the magnetic field in B / sqrt(\mu_0).
        magnetic_field_z: The z-component of the magnetic field in B / sqrt(\mu_0).
        gas_pressure: The thermal pressure of the fluid. Required.
        cosmic_ray_pressure: The cosmic ray pressure of the fluid. Only used
            when ``registered_variables.cosmic_ray_n_active`` is set; ``None``
            is treated as zero cosmic ray pressure.
        sharding: Optional sharding applied to the freshly allocated state
            array via ``jax.lax.with_sharding_constraint``.

    Returns:
        The state array of shape ``(registered_variables.num_vars, *density.shape)``.

    Raises:
        TypeError: If ``gas_pressure`` is not provided.
    """
    # None is a static (empty) pytree under jit, so this check runs at
    # trace time and gives a clear message instead of a JAX conversion error.
    if gas_pressure is None:
        raise TypeError("construct_primitive_state: gas_pressure must be provided")

    def _set_if_given(state, index, value):
        # The state starts as zeros, so an omitted (None) field stays zero.
        if value is None:
            return state
        return state.at[index].set(value)

    state = jnp.zeros((registered_variables.num_vars, *density.shape))
    if sharding is not None:
        state = jax.lax.with_sharding_constraint(state, sharding)

    state = state.at[registered_variables.density_index].set(density)

    dimensionality = config.dimensionality

    # Velocity: in 1D the velocity index is used directly; in 2D/3D it
    # exposes per-component .x/.y/.z indices.
    if dimensionality == 1:
        state = _set_if_given(state, registered_variables.velocity_index, velocity_x)
    elif dimensionality in (2, 3):
        velocity_index = registered_variables.velocity_index
        state = _set_if_given(state, velocity_index.x, velocity_x)
        state = _set_if_given(state, velocity_index.y, velocity_y)
        if dimensionality == 3:
            state = _set_if_given(state, velocity_index.z, velocity_z)

    # Magnetic field: in 2D and 3D all three components are written.
    if config.mhd:
        if dimensionality == 1:
            state = _set_if_given(
                state, registered_variables.magnetic_index, magnetic_field_x
            )
        elif dimensionality >= 2:
            magnetic_index = registered_variables.magnetic_index
            state = _set_if_given(state, magnetic_index.x, magnetic_field_x)
            state = _set_if_given(state, magnetic_index.y, magnetic_field_y)
            state = _set_if_given(state, magnetic_index.z, magnetic_field_z)

    state = state.at[registered_variables.pressure_index].set(gas_pressure)

    # With no cosmic ray pressure given, the pressure slot keeps the gas
    # pressure and the cosmic-ray slot stays zero (same as passing zeros).
    if registered_variables.cosmic_ray_n_active and cosmic_ray_pressure is not None:
        # TODO: get from params
        gamma_cr = 4 / 3
        # The pressure slot stores gas + cosmic ray pressure.
        state = state.at[registered_variables.pressure_index].set(
            gas_pressure + cosmic_ray_pressure
        )
        # The cosmic-ray slot stores cosmic_ray_pressure ** (1 / gamma_cr).
        state = state.at[registered_variables.cosmic_ray_n_index].set(
            cosmic_ray_pressure ** (1 / gamma_cr)
        )

    return state