Source code for jf1uids.shock_finder.shock_finder

# general
from functools import partial
import jax.numpy as jnp
import jax

# typing
from typing import Tuple, Union
from jf1uids.data_classes.simulation_helper_data import HelperData
from jf1uids.fluid_equations.registered_variables import RegisteredVariables
from jf1uids.option_classes.simulation_config import (
    CARTESIAN,
    FIELD_TYPE,
    SPHERICAL,
    STATE_TYPE,
    SimulationConfig,
)
from jaxtyping import Array, Int, jaxtyped
from beartype import beartype as typechecker
from typing import Union

from jf1uids.option_classes.simulation_params import SimulationParams

# NOTE: currently only works for 1d setups, TODO: generalize


@partial(jax.jit, static_argnames=["config"])
def _calculate_1d_divergence(
    field: FIELD_TYPE, config: SimulationConfig, r: FIELD_TYPE
) -> FIELD_TYPE:
    # calculate the 1d divergence by a simple
    # central difference approximation
    div_field = jnp.zeros_like(field)
    if config.geometry == CARTESIAN:
        div_field = div_field.at[1:-1].set(
            (field[2:] - field[:-2]) / (2 * config.grid_spacing)
        )
    elif config.geometry == SPHERICAL:
        div_field = jnp.zeros_like(field)
        # this is not exactly correct, as our field values are
        # defined at the volumetric not geometric cell centers etc
        # but should be fine for the shock finder
        div_field = div_field.at[1:-1].set(
            (r[2:] ** 2 * field[2:] - r[:-2] ** 2 * field[:-2])
            / (2 * config.grid_spacing * r[1:-1] ** 2)
        )
    else:
        raise NotImplementedError(
            "Only Cartesian and Spherical geometry supported for the shock finder."
        )
    return div_field


[docs] @jax.jit def shock_sensor(pressure: FIELD_TYPE) -> FIELD_TYPE: """ WENO-JS 1D smoothness indicator for shock detection. Args: pressure: the 1d pressure Returns: shock sensors, high where large pressure jumps """ shock_sensors = jnp.zeros_like(pressure) shock_sensors = shock_sensors.at[1:-1].set( 1 / 4 * (pressure[2:] - pressure[:-2]) ** 2 + 13 / 12 * (pressure[2:] - 2 * pressure[1:-1] + pressure[:-2]) ) return shock_sensors
# @jaxtyped(typechecker=typechecker)
[docs] @partial(jax.jit, static_argnames=["registered_variables", "config"]) def shock_criteria( primitive_state: STATE_TYPE, config: SimulationConfig, registered_variables: RegisteredVariables, helper_data: HelperData, ) -> jnp.ndarray: """ Implement the shock criteria from Pfrommer et al, 2017. https://arxiv.org/abs/1604.07399 # NOTE: for now only 1D """ gamma_gas = 5 / 3 gamma_cr = 4 / 3 # get the velocity velocity = primitive_state[registered_variables.velocity_index] # get the cosmic ray pressure P_CRs = primitive_state[registered_variables.cosmic_ray_n_index] ** gamma_cr # i) \nabla \cdot \vec{v} < 0 div_v = _calculate_1d_divergence(velocity, config, helper_data.geometric_centers) converging_flow_criterion = div_v < 0 # ii) \nabla T \cdot \nabla \rho > 0 pseudo_temperature = ( primitive_state[registered_variables.pressure_index] / primitive_state[registered_variables.density_index] ) div_T = jnp.zeros_like(pseudo_temperature) div_T = div_T.at[1:-1].set((pseudo_temperature[2:] - pseudo_temperature[:-2]) / 2) div_rho = jnp.zeros_like(primitive_state[registered_variables.density_index]) div_rho = div_rho.at[1:-1].set( ( primitive_state[registered_variables.density_index][2:] - primitive_state[registered_variables.density_index][:-2] ) / 2 ) no_spurious_shocks = div_T * div_rho > 0 # iii) M1 > Mmin Mmin = 1.3 # NOTE: currently we only consider shocks moving left to right P2 = primitive_state[registered_variables.pressure_index, :-2] P2_CRs = P_CRs[:-2] e2_gas = P2 - P2_CRs # gas energy / volume e2_crs = P2_CRs / (gamma_cr - 1) # cosmic ray energy / volume e2 = e2_gas + e2_crs # total energy / volume rho2 = primitive_state[registered_variables.density_index, :-2] P1 = primitive_state[registered_variables.pressure_index, 2:] P1_CRs = P_CRs[2:] P1_gas = P1 - P1_CRs # gas pressure e1_gas = P1_gas / (gamma_gas - 1) # gas energy / volume e1_crs = P1_CRs / (gamma_cr - 1) # cosmic ray energy / volume e1 = e1_gas + e1_crs # total energy / volume rho1 = primitive_state[registered_variables.density_index, 2:] gamma_eff1 = (gamma_cr * P1_CRs + gamma_gas * P1_gas) / P1 gamma_eff2 = (gamma_cr * P2_CRs + gamma_gas * P2) / P2 gamma1 = P1 / e1 + 1 gamma2 = P2 / e2 + 1 gammat = P2 / P1 C = ((gamma2 + 1) * gammat + gamma2 - 1) * (gamma1 - 1) # advanced Mach number calculation, formula 16 from Dubois et al, 2019 denominator = jnp.where( jnp.abs(C - ((gamma1 + 1) + (gamma1 - 1) * gammat) * (gamma2 - 1)) > 1e-6, (C - ((gamma1 + 1) + (gamma1 - 1) * gammat) * (gamma2 - 1)), 1e-6, ) M1sq = 1 / gamma_eff2 * (gammat - 1) * C / denominator # simple Mach number calculation, crashes # the simulation where x_s = 1, better just evaluate # this where the other criterions hold / add a numerical # safeguard # x_s = rho2 / rho1 # M1sq = (P2 / P1 - 1) * x_s / (gamma_eff1 * (x_s - 1)) mach_number_criterion = jnp.zeros_like(converging_flow_criterion, dtype=jnp.bool_) mach_number_criterion = mach_number_criterion.at[1:-1].set(M1sq > Mmin**2) return converging_flow_criterion & no_spurious_shocks & mach_number_criterion
[docs] @partial(jax.jit, static_argnames=["registered_variables", "config"]) def find_shock_zone( primitive_state: STATE_TYPE, config: SimulationConfig, registered_variables: RegisteredVariables, helper_data: HelperData, ) -> Tuple[ Union[int, Int[Array, ""]], Union[int, Int[Array, ""]], Union[int, Int[Array, ""]] ]: """ Find a numerically broadened shock region based of the strongest shock based on the result of the shock_sensor function and the pressure difference between adjacent cells. Assumes a shock front moving left to right. Args: pressure: 1d pressure velocity: 1d velocity Returns: index of max shock sensor, left boundary of broadened shock, right boundary of broadened shock """ pressure = primitive_state[registered_variables.pressure_index] num_cells = pressure.shape[0] # one can either use the maximum of the shock sensor sensors = shock_sensor(pressure) # or the cell with maximum compression, as in Pfrommer et al 2017 # div_v = _calculate_1d_divergence(primitive_state[registered_variables.velocity_index], config, helper_data.geometric_centers) shock_crit = shock_criteria( primitive_state, config, registered_variables, helper_data ) max_shock_idx = jnp.argmax(jnp.where(shock_crit, sensors, -1)) # max_shock_idx = jnp.argmin(jnp.where(shock_crit, div_v, 1)) # calculate differences in pressure pressure_differences = jnp.zeros_like(pressure) # 0 <- 1 - 0 pressure_differences = pressure_differences.at[1:].set(pressure[1:] - pressure[:-1]) # bound on the change in pressure between adjacent cells compared # to the pressure jump at the max_shock_index bound_diff = 0.1 * jnp.abs(pressure_differences[max_shock_idx]) # left index: closest left index where |pressure_difference| < bound_diff or switched sign # right index: closest right index where |pressure_difference| < bound_diff or switched sign indices = jnp.arange(num_cells) left_indices = jnp.where( (indices < max_shock_idx) & ((jnp.abs(pressure_differences) < bound_diff) | (pressure_differences > 0)), indices, -1, ) right_indices = jnp.where( (indices > max_shock_idx) & ((jnp.abs(pressure_differences) < bound_diff) | (pressure_differences < 0)), indices, num_cells, ) left_idx = jnp.max(left_indices) right_idx = jnp.min(right_indices) return max_shock_idx, left_idx, right_idx