2D Kelvin-Helmholtz Instability#

Imports#

# %pip install ../

# ==== GPU selection ====
from autocvd import autocvd
autocvd(num_gpus = 1)
# =======================

# numerics
import jax
import jax.numpy as jnp

# timing
from timeit import default_timer as timer

# plotting
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# jf1uids
from jf1uids import SimulationConfig
from jf1uids import get_helper_data
from jf1uids import SimulationParams
from jf1uids import time_integration
from jf1uids import construct_primitive_state
from jf1uids import get_registered_variables
from jf1uids.option_classes.simulation_config import finalize_config

Initializing the Kelvin-Helmholtz Instability#

from jf1uids.option_classes.simulation_config import BACKWARDS, DOUBLE_MINMOD, FORWARDS, HLL, HLLC, HYBRID_HLLC, MINMOD, OSHER, PERIODIC_BOUNDARY, BoundarySettings, BoundarySettings1D

print("👷 Setting up simulation...")

# --- scalar settings ---------------------------------------------------
# gas gamma; note that it is not handed to SimulationConfig or
# SimulationParams anywhere in this cell
gamma = 5/3

# extent and resolution of the spatial domain
box_size = 1.0
num_cells = 1024

# time-stepping switches; only fixed_timestep and num_timesteps are
# forwarded to the config below, scale_time and dt_max are not used here
fixed_timestep = False
scale_time = False
dt_max = 0.1
num_timesteps = 2000

# --- boundaries ----------------------------------------------------------
# periodic on both ends of both axes
periodic_boundaries = BoundarySettings(
    x = BoundarySettings1D(PERIODIC_BOUNDARY, PERIODIC_BOUNDARY),
    y = BoundarySettings1D(PERIODIC_BOUNDARY, PERIODIC_BOUNDARY),
)

# --- simulation config -----------------------------------------------------
config = SimulationConfig(
    # geometry
    dimensionality = 2,
    box_size = box_size,
    num_cells = num_cells,
    boundary_settings = periodic_boundaries,
    # numerical scheme
    riemann_solver = HYBRID_HLLC,
    limiter = DOUBLE_MINMOD,
    first_order_fallback = False,
    # time stepping
    fixed_timestep = fixed_timestep,
    num_timesteps = num_timesteps,
    differentiation_mode = FORWARDS,
    # output and diagnostics
    return_snapshots = True,
    num_snapshots = 100,
    progress_bar = True,
    runtime_debugging = True,
)

# run parameters: end time and CFL number
params = SimulationParams(
    C_cfl = 0.4,
    t_end = 2.0,
)

# objects derived from the config, both consumed by time_integration later
helper_data = get_helper_data(config)
registered_variables = get_registered_variables(config)
👷 Setting up simulation...

Setting the initial state#

from jax.random import PRNGKey, uniform

# PRNG key kept available for optional random perturbations. The setup
# below is fully deterministic and never consumes it.
key = PRNGKey(0)

# Cell-centred coordinates on the periodic domain [0, box_size).
# Sampling with linspace(0, 1, N) would place points on both 0 and 1, which
# are the same location under the periodic boundaries configured above, and
# would silently ignore box_size.
# NOTE(review): assumes uniform cells of width box_size / num_cells with no
# ghost cells counted in num_cells -- confirm against helper_data.
num_cells = config.num_cells
cell_width = box_size / num_cells
x = (jnp.arange(num_cells) + 0.5) * cell_width
y = (jnp.arange(num_cells) + 0.5) * cell_width
X, Y = jnp.meshgrid(x, y, indexing="ij")

# Background flow: density 1, moving with u_x = +0.5.
rho = jnp.ones_like(X)
u_x = 0.5 * jnp.ones_like(X)

# Small single-mode transverse velocity perturbation, exactly one wavelength
# across the box, to seed the instability.
u_y = 0.01 * jnp.sin(2 * jnp.pi * X / box_size)

# Central band (middle half of the box in y): density 2, counter-streaming
# with u_x = -0.5. The two band edges are the shear layers.
mask = (Y > 0.25 * box_size) & (Y < 0.75 * box_size)
u_x = jnp.where(mask, -0.5, u_x)
rho = jnp.where(mask, 2.0, rho)

# Uniform pressure everywhere.
p = jnp.full_like(X, 2.5)

# Assemble the primitive state from the individual fields.
initial_state = construct_primitive_state(
    config = config,
    registered_variables = registered_variables,
    density = rho,
    velocity_x = u_x,
    velocity_y = u_y,
    gas_pressure = p
)

# Finalize the config once the shape of the state array is known.
config = finalize_config(config, initial_state.shape)

Simulation#

# Run the solver on the initial state with the config, parameters, helper
# data and registered variables prepared above. The returned object exposes
# the stored snapshots as `result.states`, which the visualization reads.
result = time_integration(initial_state, config, params, helper_data, registered_variables)
 |████████████████████████████████████████████████████████████████████| 100.0%  

Visualization#

Cut#

from matplotlib.colors import LogNorm

# Last stored snapshot; its first axis selects the variable.
final_state = result.states[-1]

# Positions of the plotted variables along the first axis of the state.
# NOTE(review): the state was built from density, velocity_x, velocity_y and
# gas_pressure, so the ordering 0..3 is presumed here -- confirm against
# registered_variables. The pressure panel previously plotted index 4 while
# its colour norm was computed from index 3; both now use the same index.
DENSITY_INDEX = 0
VELOCITY_X_INDEX = 1
PRESSURE_INDEX = 3

density = final_state[DENSITY_INDEX, :, :]
velocity_x = final_state[VELOCITY_X_INDEX, :, :]
pressure = final_state[PRESSURE_INDEX, :, :]

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# equal aspect ratio so the square box is drawn as a square
for ax in (ax1, ax2, ax3):
    ax.set_aspect('equal', 'box')

# logarithmic colour scales spanning the full range of each plotted field
norm_rho = LogNorm(vmin = float(jnp.min(density)), vmax = float(jnp.max(density)), clip = True)
norm_p = LogNorm(vmin = float(jnp.min(pressure)), vmax = float(jnp.max(pressure)), clip = True)

# The fields are transposed so that the first array axis runs along the
# horizontal plot axis; the extent maps pixels to box coordinates.
extent = [0, box_size, 0, box_size]

# first axis: density
ax1.imshow(density.T, norm = norm_rho, cmap = "jet", origin = "lower", extent = extent)
ax1.set_title("Density")

# second axis: a single velocity component (index 1) on a linear scale
ax2.imshow(velocity_x.T, cmap = "jet", origin = "lower", extent = extent)
ax2.set_title("Velocity")

# third axis: pressure, drawn with the norm computed from the same field
ax3.imshow(pressure.T, norm = norm_p, cmap = "jet", origin = "lower", extent = extent)
ax3.set_title("Pressure")
Text(0.5, 1.0, 'Pressure')
../../_images/64c6d7cb9afe9a3b27ae4810be8b1beb42fa12f4c49127243b0b02db355125b1.png
# import matplotlib.animation as animation

# # Create a figure and axis for the animation
# fig, ax = plt.subplots(figsize=(6, 6))
# ax.set_aspect('equal', 'box')

# # Initialize the plot with the first frame
# density = result.states[0][0, :, :]
# norm = LogNorm(vmin=jnp.min(density), vmax=jnp.max(density), clip=True)
# im = ax.imshow(density.T, norm=norm, cmap="jet", origin="lower", extent=[0, box_size, 0, box_size])
# ax.set_title("Density")

# # Add a color bar
# # cbar = fig.colorbar(im, ax=ax)
# # cbar.set_label('Density')

# # Update function for the animation
# def update(frame):
#     density = result.states[frame][0, :, :]
#     im.set_data(density.T)
#     return [im]

# # Create the animation
# ani = animation.FuncAnimation(fig, update, frames=len(result.states), blit=True)

# # Display the animation
# ani.save("kelvin_helmholtz.gif", fps=24)