Appendix N — Hydrogen Atomic Orbitals

import warnings
from pathlib import Path

import numpy as np
from scipy.special import sph_harm, assoc_laguerre, factorial

import plotly.graph_objects as go
import plotly.subplots as sp
import plotly.io as pio

# Make "notebook" the default Plotly renderer used by fig.show()
pio.renderers.default = "notebook"

# ------------------------------
# Hydrogenic wavefunction ψ_nlm
# ------------------------------
def hydrogen_wavefunction(n, l, m, r, phi, theta):
  """
  Calculate hydrogen atom wavefunction ψ_nlm in spherical coordinates

  The value is the product of the normalized radial function R_nl(r) and the
  spherical harmonic Y_l^m. Lengths are in units of the Bohr radius (the
  scaled radius used below is rho = 2r/n).

  Note the angle naming, which follows the legacy scipy.special.sph_harm
  convention rather than the usual physics one.

  Parameters:
    n: principal quantum number
    l: angular momentum quantum number
    m: magnetic quantum number
    r: radial distance (scalar or array)
    phi: polar angle measured from the +z axis (scalar or array)
    theta: azimuthal angle in the xy-plane (scalar or array)

  Returns:
    Complex value(s) of ψ_nlm, broadcast over r, phi and theta.
    The quantum numbers are not validated here.
  """
  rho = 2 *r /n
  # Normalization constant matching the generalized Laguerre polynomial
  # L_{n-l-1}^{2l+1} as implemented by scipy.special.assoc_laguerre
  N = np.sqrt((2 /n)**3 *factorial(n -l -1) /(2 *n *factorial(n +l)))
  R = N *np.exp(-rho /2) *rho**l *assoc_laguerre(rho, n -l -1, 2 *l +1)

  try:
    # Preferred API (SciPy >= 1.15): sph_harm_y(degree, order, polar, azimuth)
    from scipy.special import sph_harm_y
  except ImportError:
    # Older SciPy: sph_harm(order, degree, azimuth, polar). Some versions emit
    # a DeprecationWarning for it, which is silenced here.
    from scipy.special import sph_harm
    with warnings.catch_warnings():
      warnings.filterwarnings('ignore', category=DeprecationWarning)
      Y = sph_harm(m, l, theta, phi)
  else:
    Y = sph_harm_y(l, m, phi, theta)

  return R *Y

def estimate_params(n, use_probability=False):
  """
  Pick the plotting extent, isosurface level and grid resolution

  Parameters:
    n: principal quantum number; the extent grows as 2.5 *n**2
    use_probability: True when |ψ|² is plotted, which uses a lower
      isosurface fraction and a finer grid than the wavefunction itself

  Returns:
    (rmax, iso_frac, grid): plot half-extent, isosurface level as a
    fraction of the data maximum, and number of samples per axis
  """
  if use_probability:
    iso_frac, grid = 0.02, 120
  else:
    iso_frac, grid = 0.08, 100

  return 2.5 *n**2, iso_frac, grid

def resize_fig_with_colorbar(fig, scale):
  """
  Scale figure size, base font and heatmap colorbars by one factor

  Parameters:
    fig: figure to modify in place; an unset (or zero) width/height is
      treated as 700/450 before scaling
    scale: multiplicative factor applied to all sizes

  Returns:
    The same figure object, for chaining
  """
  base_width = fig.layout.width or 700
  base_height = fig.layout.height or 450

  fig.update_layout(
    width=base_width *scale,
    height=base_height *scale,
    font_size=14 *scale,
  )

  # Only heatmap traces are selected for the colorbar thickness change
  heatmap_only = dict(type='heatmap')
  fig.update_traces(colorbar_thickness=15 *scale, selector=heatmap_only)

  return fig

# ------------------------------
# Main orbital plotter
# ------------------------------
def plot_orbital(
  n=2, l=1, m=0,
  plot_slices=True,
  use_probability=False,
  show_3d_axes=False,
  show_2d_axes=True,
  show_plot=True,
  save_html=False,
  filename="hydrogen_orbital.html",
):
  """
  Plot hydrogen atomic orbital with 3D isosurface and 2D slices

  The wavefunction is sampled on a cubic Cartesian grid whose extent,
  isosurface level and resolution come from estimate_params(). With
  plot_slices=True the figure has a 3D scene on top and three heatmaps
  (xy, yz and zx planes through the origin) below. With plot_slices=False
  the same three planes are drawn as semi-transparent surfaces inside the
  3D scene. The quantum numbers are not validated here.

  Parameters:
    n: principal quantum number (n ≥ 1)
    l: angular momentum quantum number (0 ≤ l ≤ n -1)
    m: magnetic quantum number (-l ≤ m ≤ l)
    plot_slices: include 2D slice plots
    use_probability: plot |ψ|² instead of Re(ψ)
    show_3d_axes: show axes in 3D plot
    show_2d_axes: show grid/zero lines in 2D slices
    show_plot: display plot interactively
    save_html: save plot as HTML file
    filename: output filename

  Returns:
    None; the figure is shown and/or written to disk as requested
  """
  rmax, iso_frac, grid = estimate_params(n, use_probability)
  margin = 1.1

  # Use an odd number of samples so that index grid // 2 sits on coordinate 0
  # (up to floating-point rounding). With an even count the cuts labelled
  # "z = 0", "x = 0" and "y = 0" would be taken half a grid step off-plane.
  if grid % 2 == 0:
    grid += 1
  mid = grid // 2

  # Create 3D grid (same symmetric sampling along every axis)
  half_width = rmax *margin
  x = np.linspace(-half_width, half_width, grid)
  y = np.linspace(-half_width, half_width, grid)
  z = np.linspace(-half_width, half_width, grid)
  X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

  # Convert to spherical coordinates: phi is the polar angle from +z and
  # theta the azimuth, as hydrogen_wavefunction expects. Where r == 0 the
  # ratio Z/r is replaced by 0 to avoid a 0/0 division.
  r = np.sqrt(X**2 +Y**2 +Z**2)
  phi = np.arccos(np.divide(Z, r, out=np.zeros_like(r), where=r!=0))
  theta = np.arctan2(Y, X)

  # Calculate wavefunction
  psi = hydrogen_wavefunction(n, l, m, r, phi, theta)

  # Configure plotting mode
  if use_probability:
    plot_data = np.abs(psi)**2
    data_label = "|ψ|²"
    colorscale_2d = 'Reds'
  else:
    plot_data = np.real(psi)
    data_label = "ψ"
    colorscale_2d = 'RdBu_r'

  absmax = np.max(np.abs(plot_data))
  iso = iso_frac *absmax

  # One colour range shared by every planar cut and the colorbar. For the
  # signed wavefunction it is symmetric so that zero maps to the middle of
  # the diverging colorscale.
  vmin, vmax = (0, absmax) if use_probability else (-absmax, absmax)

  title_text = f"Hydrogen Atomic Orbital: {data_label} (n={n}, l={l}, m={m})"

  # ---------------------------------
  # Create subplot layout
  # ---------------------------------
  if plot_slices:
    fig = sp.make_subplots(
      rows=2, cols=3,
      specs=[
        [{"type": "scene", "colspan": 3}, None, None],
        [{"type": "xy"}, {"type": "xy"}, {"type": "xy"}]
      ],
      row_heights=[0.60, 0.40],
      column_widths=[0.33, 0.33, 0.34],
      horizontal_spacing=0.001,
      vertical_spacing=0.05,
      subplot_titles=(
        title_text,
        "xy slice (z = 0)",
        "yz slice (x = 0)",
        "zx slice (y = 0)"
      )
    )
    # 3D traces must be addressed to the scene cell of the subplot grid
    scene_cell = dict(row=1, col=1)
  else:
    fig = go.Figure()
    scene_cell = {}

  # -----------------------------
  # Add 3D Isosurface
  # -----------------------------
  iso_params = dict(
    x=X.flatten(), y=Y.flatten(), z=Z.flatten(), value=plot_data.flatten(),
    opacity=0.25,
    caps=dict(x_show=False, y_show=False, z_show=False),
    showscale=False,
    hoverinfo='skip'
  )

  # isomin == isomax draws a single shell at that level; the signed
  # wavefunction gets one shell per sign
  if use_probability:
    shells = [(iso, 'Reds', "probability")]
  else:
    shells = [(iso, 'Reds', "+lobe"), (-iso, 'Blues', "-lobe")]

  for level, shell_colorscale, shell_name in shells:
    fig.add_trace(
      go.Isosurface(
        isomin=level, isomax=level,
        colorscale=shell_colorscale,
        name=shell_name,
        **iso_params
      ),
      **scene_cell
    )

  if plot_slices:
    # -----------------------------
    # Add 2D slices
    # -----------------------------
    fig.update_scenes(
      xaxis=dict(visible=show_3d_axes),
      yaxis=dict(visible=show_3d_axes),
      zaxis=dict(visible=show_3d_axes),
      aspectmode="cube",
      row=1, col=1
    )

    # Heatmap convention: z[row][col] is drawn at (x[col], y[row]), so each
    # 2D array below must be indexed [vertical, horizontal].
    # plot_data is indexed [ix, iy, iz]:
    #   xy: [ix, iy] -> transpose so rows follow y
    #   yz: [iy, iz] -> transpose so rows follow z
    #   zx: [ix, iz] -> already rows = x (vertical), cols = z (horizontal)
    slices = [
      (1, plot_data[:, :, mid].T, x, y, "x", "y"),
      (2, plot_data[mid, :, :].T, y, z, "y", "z"),
      (3, plot_data[:, mid, :], z, x, "z", "x"),
    ]

    for col, values, h_coords, v_coords, h_name, v_name in slices:
      # Only the first heatmap carries the (shared) colorbar
      colorbar_kwargs = {}
      if col == 1:
        colorbar_kwargs = dict(
          colorbar=dict(x=0.5, y=-0.2, len=0.8, thickness=15,
                        orientation="h", title=dict(text=data_label, side="bottom"))
        )

      fig.add_trace(go.Heatmap(
        z=values,
        x=h_coords, y=v_coords,
        showscale=(col == 1),
        hovertemplate=f'{h_name}: %{{x:.2f}}<br>{v_name}: %{{y:.2f}}<br>{data_label}: %{{z:.4f}}<extra></extra>',
        colorscale=colorscale_2d,
        zmin=vmin,
        zmax=vmax,
        **colorbar_kwargs
      ), row=2, col=col)

      # Square panels cropped to ±rmax; the y axis is locked to its own x axis
      x_axis_id = "x" if col == 1 else f"x{col}"
      fig.update_xaxes(range=[-rmax, rmax], constrain="domain",
                       showgrid=show_2d_axes, zeroline=show_2d_axes, row=2, col=col)
      fig.update_yaxes(range=[-rmax, rmax], scaleanchor=x_axis_id, scaleratio=1,
                       showgrid=show_2d_axes, zeroline=show_2d_axes, row=2, col=col)

  else:
    # -----------------------------
    # Add 3D cross-section slice
    # -----------------------------
    # Central window of the grid, symmetric about the middle index
    i1 = grid // 5
    inner = slice(i1, grid -i1)

    planes = [
      (inner, inner, mid),  # z = 0 plane
      (mid, inner, inner),  # x = 0 plane
      (inner, mid, inner),  # y = 0 plane
    ]
    for plane in planes:
      fig.add_trace(
        go.Surface(
          x=X[plane],
          y=Y[plane],
          z=Z[plane],
          surfacecolor=plot_data[plane],
          colorscale=colorscale_2d,
          cmin=vmin,
          cmax=vmax,
          showscale=False,
          opacity=0.3,
          hoverinfo='skip'
        )
      )

    # Fully transparent heatmap whose only purpose is to provide a horizontal
    # colorbar for the surfaces above (their own colorbars are disabled)
    fig.add_trace(
      go.Heatmap(
        z=[[vmin, vmax]],
        zmin=vmin,
        zmax=vmax,
        colorscale=colorscale_2d,
        showscale=True,
        opacity=0.0,
        hoverinfo="skip",
        colorbar=dict(
          x=0.5,
          y=-0.15,
          orientation="h",
          thickness=12,
          len=0.6,
          xanchor="center"
        )
      )
    )
    # Hide the cartesian axes that the helper heatmap brings along
    fig.update_layout(
      xaxis=dict(visible=False, showgrid=False, zeroline=False),
      yaxis=dict(visible=False, showgrid=False, zeroline=False)
    )

    fig.update_scenes(
      xaxis=dict(visible=show_3d_axes),
      yaxis=dict(visible=show_3d_axes),
      zaxis=dict(visible=show_3d_axes),
      aspectmode="cube"
    )
    fig.update_layout(
      title={
        "text": title_text,
        "x": 0.5,
        "xanchor": "center",
        "yanchor": "top",
        "font": dict(size=18)
      }
    )

  # -----------------------------
  # Finalize layout
  # -----------------------------
  fig.update_layout(
    width=1300,
    height=950,
    margin=dict(l=20, r=20, t=80, b=20)
  )

  fig.update_annotations(font_size=12)
  fig = resize_fig_with_colorbar(fig, 0.55)

  if save_html:
    fig.write_html(filename)

  if show_plot:
    fig.show()

# -----------------------------
# User-friendly wrapper
# -----------------------------
def hydrogen_orbital_gui(
  n, l, m, 
  use_probability=False, 
  plot_slices=True, 
  show_plot=True, 
  save_html=False
):
  """
  Generate hydrogen orbital visualization with automatic file naming

  The quantum numbers are validated first. The output file is
  figures/orbital/orbital_<mode>_<style>_n<n>_l<l>_m<m>.html, where mode is
  "pd" (probability density) or "wf" (wavefunction) and style is "d"
  (with 2D slices) or "s" (3D scene only).

  Parameters:
    n: principal quantum number
    l: angular momentum quantum number
    m: magnetic quantum number
    use_probability: plot probability density instead of wavefunction
    plot_slices: add the three 2D slice plots below the 3D scene
    show_plot: display plot
    save_html: save as HTML

  Raises:
    ValueError: if n < 1, l is outside 0..n-1, or m is outside -l..l
  """
  if n < 1:
    raise ValueError("n must satisfy n >= 1")
  # Check the lower bound explicitly; otherwise a negative l is only caught
  # by the m check below, with a message that blames m
  if l < 0:
    raise ValueError("l must satisfy l >= 0")
  if l > n -1:
    raise ValueError("l must be =< n −1")
  if not (-l <= m <= l):
    raise ValueError("m must satisfy -l =< m =< l")

  results_folder = Path("figures/orbital")
  # Touch the filesystem only when a file is actually going to be written
  if save_html:
    results_folder.mkdir(exist_ok=True, parents=True)

  mode = "pd" if use_probability else "wf"
  style = "d" if plot_slices else "s"
  fname = results_folder / f"orbital_{mode}_{style}_n{n}_l{l}_m{m}.html"
  
  plot_orbital(
    n=n, l=l, m=m,
    plot_slices=plot_slices,
    use_probability=use_probability,    
    show_plot=show_plot,
    show_3d_axes=False,
    show_2d_axes=True,
    save_html=save_html,
    filename=fname,
  )
# Batch-render the example orbitals to HTML files without displaying them.
# Each entry is (n, l, m) plus the keyword options that differ from the defaults.
_EXAMPLE_ORBITALS = (
  ((2, 0, 0), dict(plot_slices=False)),
  ((2, 1, 1), {}),
  ((3, 1, 0), {}),
  ((3, 1, 0), dict(plot_slices=False)),
  ((4, 2, -2), {}),
  ((4, 2, 0), dict(plot_slices=False)),
  ((2, 1, 0), dict(use_probability=True)),
  ((4, 2, -2), dict(use_probability=True)),
)
for _quantum_numbers, _options in _EXAMPLE_ORBITALS:
  hydrogen_orbital_gui(*_quantum_numbers, show_plot=False, save_html=True, **_options)