Hamiltonian Neural Networks (HNNs) are a powerful approach to learning physics-preserving dynamics by modeling the Hamiltonian function \(H(q,p)\) of a system, rather than its vector field directly. This tutorial will guide you through the theory and implementation of HNNs using JAX and Flax.
Categories: Physics Informed ML, JAX, Flax
Published: June 9, 2025
Modified: June 15, 2025
Introduction
Neural networks are known to be fantastic universal approximators. Give them enough data, and they can learn almost any function. But what happens when that function describes a physical system? Consider a simple satellite orbiting a planet. We could train a standard neural network to predict its position and velocity at the next time step. It might work well for a short while, but over many time steps, tiny errors accumulate. The satellite might slowly spiral into the planet or drift away into space—violating a fundamental law of physics: the conservation of energy.
This is a common failure mode for standard neural networks. Despite their impressive capabilities, they often struggle with physical systems because they lack the built-in understanding that energy should be conserved, that certain symmetries should be preserved, and that the laws of physics are inviolable. They attempt to learn physics from scratch, and the approximations they learn are imperfect. Over long predictions, these small imperfections compound into large errors.
What if we could teach neural networks to automatically learn and respect the fundamental laws of physics? This isn’t just a theoretical curiosity—it’s a practical necessity. Fortunately, a lot of progress has been made in this area of study, named Physics Informed Machine Learning (PIML). One of the most remarkable works in this field is Hamiltonian Neural Networks (HNN) (Greydanus, Dzamba, and Yosinski 2019), an elegant solution that bridges centuries-old physics with cutting-edge machine learning. But to understand why HNNs work so well, we need to take a journey from Newton’s familiar \(\vb{F} = m\vb{a}\) to the more abstract—but incredibly powerful—world of Hamiltonian mechanics.
The Journey from Newton to Hamilton
Starting Point: Newton’s Second Law
We all know Newton’s second law. For a system with \(N\) particles, we have:
\[\vb{F}_i = m_i \ddot{\vb{r}}_i,\]
where \(\vb{r}_i\) is the position of particle \(i\) and \(\ddot{\vb{r}}_i\equiv\dv*[2]{\vb{r}_i}{t}\). This gives us a system of \(3N\) second-order differential equations—one for each spatial dimension of each particle—that we can solve to find the trajectories of the particles over time.
While Newton’s approach is intuitive, it has limitations. Forces can be complex to specify, especially when dealing with constraints. Moreover, Newton’s formulation doesn’t immediately reveal conserved quantities like energy and momentum.
The Lagrangian Detour
The path to Hamiltonian mechanics goes through Lagrangian mechanics, which reformulates classical mechanics in terms of energies rather than the forces acting on each particle.
The Lagrangian is defined as:
\[L(\vb{q}, \dot{\vb{q}}, t) = T - V,\]
where:
\(T\) is the kinetic energy.
\(V\) is the potential energy.
\(\vb{q} = (q_1, q_2, \ldots, q_n)\) are generalized coordinates.
\(\dot{\vb{q}} \equiv \dv*{\vb{q}}{t}\) are the generalized velocities.
The genius of this approach is that \(\vb{q}\) can be any coordinates that completely describe the system’s configuration—not necessarily Cartesian positions. For a pendulum, we might use the angle, \(\theta\), and angular velocity \(\dot{\theta}\) as generalized coordinates and velocities, respectively, instead of \((x, y)\) coordinates and linear velocity. This flexibility allows us to choose coordinates that simplify the problem.
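The equations of motion come from the Euler-Lagrange equations:
\[\dv{t}\left(\pdv{L}{\dot{q}_i}\right) - \pdv{L}{q_i} = 0, \qquad i = 1, \ldots, n.\]
Defining the canonical momenta \(p_i = \partial L / \partial \dot{q}_i\) and the Hamiltonian as the Legendre transform of \(L\),
\[H(\vb{q}, \vb{p}, t) = \sum_i p_i \dot{q}_i - L,\]
turns these \(n\) second-order equations into \(2n\) first-order ones, Hamilton's equations:
\[\dot{q}_i = \pdv{H}{p_i}, \qquad \dot{p}_i = -\pdv{H}{q_i}.\]
Using these equations, it is possible to show that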
\[\dv{H}{t} = \pdv{H}{t}.\]
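Therefore, if \(H\) doesn't depend explicitly on time, it is a constant of motion. For the systems in this tutorial, where \(H = T + V\), that constant is the total energy, so energy is automatically conserved.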
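The claim \(\dv*{H}{t} = \pdv*{H}{t}\) follows from a short chain-rule argument. It assumes Hamilton's equations in their standard form, \(\dot{q}_i = \partial H/\partial p_i\) and \(\dot{p}_i = -\partial H/\partial q_i\):

```latex
\begin{aligned}
\dv{H}{t} &= \sum_i \left( \pdv{H}{q_i}\,\dot{q}_i + \pdv{H}{p_i}\,\dot{p}_i \right) + \pdv{H}{t} \\
          &= \sum_i \left( \pdv{H}{q_i}\pdv{H}{p_i} - \pdv{H}{p_i}\pdv{H}{q_i} \right) + \pdv{H}{t} \\
          &= \pdv{H}{t}.
\end{aligned}
```

The two terms in each bracket cancel identically, whatever the form of \(H\). This is the property HNNs exploit: any scalar function, including a neural network, generates dynamics along which that function is conserved.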
Learning Physical Dynamics
The Problem with Traditional Neural Networks
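Now, let's examine what happens when we try to learn physical dynamics with standard neural networks. Suppose we have trajectory data \(\{(\vb{q}^{(i)}, \vb{p}^{(i)}, \dot{\vb{q}}^{(i)}, \dot{\vb{p}}^{(i)})\}\) from a physical system. A typical approach would be to train a neural network \(f_{\vb*{\theta}}\) to predict the time derivatives directly,
\[(\dot{\vb{q}}, \dot{\vb{p}}) \approx f_{\vb*{\theta}}(\vb{q}, \vb{p}),\]
by minimizing a regression loss such as the mean squared error
\[\mathcal{L}(\vb*{\theta}) = \frac{1}{N}\sum_{i=1}^{N} \left\| f_{\vb*{\theta}}(\vb{q}^{(i)}, \vb{p}^{(i)}) - (\dot{\vb{q}}^{(i)}, \dot{\vb{p}}^{(i)}) \right\|^2.\]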
While this approach can fit the training data well, it has a critical flaw: nothing in its structure enforces conservation laws. When we integrate the learned dynamics forward in time, small errors accumulate and the system drifts away from conserved quantities like energy.
Consider a simple harmonic oscillator, with Hamiltonian
\[H = \frac{p^2}{2m} + \frac{1}{2}kq^2.\]
The true trajectories are ellipses in phase space, corresponding to constant energy levels. A traditional neural network might learn to approximate these trajectories, but small errors will cause the predicted trajectory to spiral inward or outward, violating energy conservation.
This is more than just a numerical issue; it reflects a fundamental mismatch between the neural network’s structure and the underlying physics.
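The failure mode can be seen concretely by integrating the harmonic oscillator above twice, using the illustrative choice \(m = k = 1\). The first run uses the exact vector field. The second uses a field with a small, hypothetical error of the kind an unconstrained network might learn. The error term \(\varepsilon(q, p)\) is an assumption chosen for clarity, not the output of a trained model. Its divergence is \(2\varepsilon \neq 0\), so no Hamiltonian can generate it.

```python
import numpy as np
from scipy.integrate import solve_ivp

m, k = 1.0, 1.0  # illustrative mass and spring constant
eps = 0.01       # hypothetical size of a learned model's error

def energy(q, p):
    return p**2 / (2 * m) + 0.5 * k * q**2

def true_field(t, z):
    q, p = z
    return [p / m, -k * q]

def perturbed_field(t, z):
    # Stand-in for an imperfect learned vector field: the true field plus a
    # small term eps * (q, p), which is not divergence-free and hence not Hamiltonian.
    q, p = z
    return [p / m + eps * q, -k * q + eps * p]

z0 = [0.5, 0.0]
t_eval = np.linspace(0.0, 50.0, 501)
energies = {}
for name, field in [('true', true_field), ('perturbed', perturbed_field)]:
    sol = solve_ivp(field, (0.0, 50.0), z0, t_eval=t_eval, rtol=1e-10, atol=1e-12)
    energies[name] = energy(sol.y[0], sol.y[1])
    print(f'{name:>9}: E(0) = {energies[name][0]:.4f}, E(50) = {energies[name][-1]:.4f}')
```

For the perturbed field, the energy obeys \(\dv*{E}{t} = 2\varepsilon E\) in exact arithmetic. It therefore grows by a factor of \(e^{2\varepsilon t}\), about \(e \approx 2.7\) over 50 time units, even though the error is only 1%. The exact field keeps \(E\) constant to integrator precision.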
Enter Hamiltonian Neural Networks
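Hamiltonian Neural Networks (HNNs), introduced in (Greydanus, Dzamba, and Yosinski 2019), provide an elegant solution. Instead of directly learning the vector field \((\dot{\vb{q}}, \dot{\vb{p}})\), we use a neural network, \(H_{\vb*{\theta}}\), to approximate the Hamiltonian itself. Then, the dynamics are computed using Hamilton's equations:
\[\dot{\vb{q}} = \pdv{H_{\vb*{\theta}}}{\vb{p}}, \qquad \dot{\vb{p}} = -\pdv{H_{\vb*{\theta}}}{\vb{q}}.\]
The network is trained by matching these derivatives to the observed ones,
\[\mathcal{L}(\vb*{\theta}) = \frac{1}{N}\sum_{i=1}^{N} \left( \left\| \pdv{H_{\vb*{\theta}}}{\vb{p}} - \dot{\vb{q}}^{(i)} \right\|^2 + \left\| \pdv{H_{\vb*{\theta}}}{\vb{q}} + \dot{\vb{p}}^{(i)} \right\|^2 \right),\]
with the gradients evaluated at \((\vb{q}^{(i)}, \vb{p}^{(i)})\). The crucial point is that we are optimizing the gradients of the neural network output, not the output itself.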
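Because the vector field is derived from a single scalar function through Hamilton's equations, its exact flow preserves the symplectic structure of phase space and conserves \(H_{\vb*{\theta}}\) exactly. In practice, \(H_{\vb*{\theta}}\) only approximates the true Hamiltonian, and a numerical integrator adds small errors of its own. Even so, errors in the model itself can no longer make trajectories spiral inward or outward the way an unconstrained vector field can. That is the key insight that makes HNNs so powerful.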
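Exact conservation of \(H_{\vb*{\theta}}\) along its own flow does not depend on training or on the particular weights. The sketch below illustrates this in plain JAX rather than Flax, for brevity. It builds a small, randomly initialized MLP \(H\) for one degree of freedom (`init_params`, `hamiltonian` and `hamiltonian_field` are hypothetical helper names) and derives the vector field from Hamilton's equations. It then checks that \(\dv*{H}{t} = \nabla H \cdot (\partial H/\partial p, -\partial H/\partial q)\) vanishes at a thousand random states.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(2, 16, 16, 1)):
    # Random MLP weights; no training, so H is an arbitrary smooth function
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        W = jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((W, jnp.zeros(n_out)))
    return params

def hamiltonian(params, z):
    # z = (q, p) for one degree of freedom; returns a scalar
    h = z
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return (h @ W + b)[0]

def hamiltonian_field(params, z):
    # Hamilton's equations: (dq/dt, dp/dt) = (dH/dp, -dH/dq)
    dH_dq, dH_dp = jax.grad(hamiltonian, argnums=1)(params, z)
    return jnp.array([dH_dp, -dH_dq])

params = init_params(jax.random.PRNGKey(0))
states = jax.random.normal(jax.random.PRNGKey(1), (1000, 2))
grad_H = jax.vmap(jax.grad(hamiltonian, argnums=1), in_axes=(None, 0))(params, states)
field = jax.vmap(hamiltonian_field, in_axes=(None, 0))(params, states)

# dH/dt along the flow is grad(H) . field, which cancels term by term
dH_dt = jnp.sum(grad_H * field, axis=-1)
print('max |dH/dt|:', jnp.max(jnp.abs(dH_dt)))
```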
Implementation with JAX and Flax
Now that we’ve explored the theory behind Hamiltonian Neural Networks, it’s time to bring these ideas to life with code. In this section, we’ll walk through a hands-on implementation of an HNN using JAX and Flax. We’ll start by setting up our environment and dependencies, then build the model architecture, define the loss function, and finally implement some auxiliary functions to train and evaluate our network. Each step will be explained in detail, so you can follow along and understand not just how to implement an HNN, but also why each part is necessary. Let’s dive in.
Dependencies
Before we can build and train our Hamiltonian Neural Network, we need to set up our computational environment. Here, we import all the necessary libraries for numerical computation, neural network modeling, optimization, and visualization. Setting a random seed ensures that our experiments are reproducible, which is especially important when comparing results or debugging.
```python
import jax
from jax import numpy as jnp
from flax import nnx
import optax
from scipy.integrate import solve_ivp
import numpy as np
import matplotlib.pyplot as plt

# Set random seed for reproducibility
seed = 123
rngs = nnx.Rngs(seed)

# Check available devices and the default backend
print(jax.devices())
print(jax.default_backend())
```

```text
[CpuDevice(id=0)]
cpu
```
Architecture
Now, let’s define the core of our HNN: the neural network architecture. Our model is a multilayer perceptron (MLP) that takes as input the position and momentum of the system and outputs a single scalar value—the Hamiltonian. The key innovation is that, instead of predicting the time derivatives directly, we use the gradients of this scalar output to recover the system’s dynamics via Hamilton’s equations. This structure is what allows the network to respect the underlying physics.
```python
class HNN(nnx.Module):
    def __init__(self, system_dim, hidden_features, *, rngs):
        super().__init__()
        # MLP: input (q, p) of size 2 * system_dim, tanh hidden layers, scalar output H
        self.layers = []
        in_features = 2 * system_dim
        for out_features in hidden_features:
            self.layers.append(nnx.Linear(in_features, out_features, rngs=rngs))
            in_features = out_features
        self.layers.append(nnx.Linear(in_features, 1, rngs=rngs))

    def __call__(self, q, p):
        # Returns H(q, p) for a single state or a batch of shape (N, system_dim)
        q, p = jnp.atleast_1d(q, p)
        z = jnp.concatenate([q, p], axis=-1)
        for layer in self.layers[:-1]:
            z = nnx.tanh(layer(z))
        h = self.layers[-1](z)
        return h.squeeze(-1)

    def symplectic_grad(self, q, p):
        # Hamilton's equations per sample, vmapped over the batch axis:
        # dq/dt = dH/dp, dp/dt = -dH/dq
        dot_q = jax.vmap(jax.grad(self.__call__, argnums=1))(q, p)
        dot_p = -jax.vmap(jax.grad(self.__call__, argnums=0))(q, p)
        return dot_q, dot_p

    def solve_ivp(self, q_0, p_0, t_span, *args, **kwargs):
        def f(t, z):
            # SciPy passes a flat state z = (q, p); add a batch axis so that
            # symplectic_grad also works when system_dim > 1
            q, p = jnp.split(jnp.asarray(z)[None, :], 2, axis=-1)
            dot_q, dot_p = self.symplectic_grad(q, p)
            return np.asarray(jnp.concatenate([dot_q, dot_p], axis=-1)[0])

        q_0, p_0 = jnp.atleast_1d(q_0, p_0)
        z_0 = np.asarray(jnp.concatenate([q_0, p_0], axis=-1))
        # Calls scipy.integrate.solve_ivp (the module-level import), not this method
        sol = solve_ivp(f, t_span, z_0, *args, **kwargs)
        t, z = sol.t, sol.y.T
        q, p = np.split(z, 2, axis=-1)
        return t, q, p
```
The HNN class contains several important methods:
__init__: Builds the stack of nnx.Linear layers. The input has 2 * system_dim features (the concatenated \(q\) and \(p\)), the hidden widths come from hidden_features, and the last layer outputs a single value, the Hamiltonian.
__call__: Defines the forward pass. It concatenates position and momentum, applies each hidden layer followed by a tanh activation, and returns the scalar Hamiltonian value (one per sample for batched inputs).
symplectic_grad: Computes the time derivatives of position and momentum from the gradients of the learned Hamiltonian, \(\dot{q} = \partial H_{\theta}/\partial p\) and \(\dot{p} = -\partial H_{\theta}/\partial q\). It uses jax.grad, vectorized over the batch with jax.vmap.
solve_ivp: Integrates the learned dynamics forward in time with SciPy's solve_ivp, allowing us to simulate trajectories from given initial conditions.
Loss function
With the model defined, we need a way to measure how well it captures the true dynamics of the system. The loss function does this by comparing the time derivatives predicted by the HNN (via the gradients of its Hamiltonian output) to the observed derivatives from our data. By minimizing this loss, we encourage the network to learn a Hamiltonian that generates the correct equations of motion.
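This section does not show the code for loss_fn, which eval_step calls as loss_fn(model, batch). A plausible form, following the HNN paper, is the mean squared mismatch between the derivatives given by Hamilton's equations and the observed ones. The sketch below writes it in plain JAX for any scalar callable hamiltonian(q, p); the name hnn_loss is hypothetical. It then checks the loss on the small-angle pendulum with \(m = l = 1\), using the exact Hamiltonian and a deliberately wrong one.

```python
import jax
import jax.numpy as jnp

def hnn_loss(hamiltonian, q, p, dot_q, dot_p):
    """Mean squared mismatch between Hamilton's equations and observed derivatives.

    hamiltonian: callable (q_i, p_i) -> scalar for a single state.
    q, p, dot_q, dot_p: arrays of shape (N, d).
    """
    dH_dq = jax.vmap(jax.grad(hamiltonian, argnums=0))(q, p)
    dH_dp = jax.vmap(jax.grad(hamiltonian, argnums=1))(q, p)
    return jnp.mean((dH_dp - dot_q) ** 2) + jnp.mean((-dH_dq - dot_p) ** 2)

# Small-angle pendulum with m = l = 1: H = p^2 / 2 + g q^2 / 2
g = 9.81
true_H = lambda q, p: jnp.sum(p**2) / 2 + g * jnp.sum(q**2) / 2
wrong_H = lambda q, p: jnp.sum(p**2) / 2 + 0.8 * g * jnp.sum(q**2) / 2  # wrong stiffness

q = jnp.linspace(-0.5, 0.5, 64).reshape(-1, 1)
p = jnp.linspace(-1.5, 1.5, 64).reshape(-1, 1)
dot_q, dot_p = p, -g * q  # exact derivatives from Hamilton's equations

print(hnn_loss(true_H, q, p, dot_q, dot_p))   # essentially zero
print(hnn_loss(wrong_H, q, p, dot_q, dot_p))  # positive: the q-gradient is off
```

With the HNN class, the same quantity could be computed from model.symplectic_grad(q, p), provided each batch unpacks to (q, p, dot_q, dot_p). That batch layout is an assumption. Because only gradients enter the loss, adding a constant to the Hamiltonian leaves it unchanged.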
Training a neural network involves more than just defining a model and a loss function. In this section, we set up the essential steps for training and evaluating our HNN, including batching the data, updating the model parameters, and tracking performance metrics. These utility functions help organize the training loop, making the process efficient and transparent. Let’s break down each component.
First, we define the training step. This function computes the loss and its gradients with respect to the model parameters, applies the resulting update through the optimizer, and records the loss in the metrics for monitoring. JIT-compiling the step with JAX keeps it fast on modern hardware.
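The sketch below illustrates what such a step does in plain JAX, without dependencies beyond JAX itself. It is not the Flax train_step used by train further down. It keeps the same structure (compute the loss and its gradients, then apply an update), but makes two substitutions for clarity. The MLP is replaced by a hypothetical two-parameter quadratic Hamiltonian \(H = a p^2 + b q^2\), and Adam is replaced by plain gradient descent. The data are exact derivatives of the small-angle pendulum with \(m = l = 1\), so the fit should recover \(a = 1/2\) and \(b = g/2\).

```python
import jax
import jax.numpy as jnp

g = 9.81  # pendulum with m = l = 1, so the true H = p**2 / 2 + g * q**2 / 2

def hamiltonian(params, q, p):
    # Hypothetical two-parameter model H = a p^2 + b q^2 (a stand-in for the MLP)
    return params['a'] * jnp.sum(p**2) + params['b'] * jnp.sum(q**2)

def loss_fn(params, batch):
    q, p, dot_q, dot_p = batch
    dH_dq = jax.vmap(jax.grad(hamiltonian, argnums=1), in_axes=(None, 0, 0))(params, q, p)
    dH_dp = jax.vmap(jax.grad(hamiltonian, argnums=2), in_axes=(None, 0, 0))(params, q, p)
    return jnp.mean((dH_dp - dot_q) ** 2) + jnp.mean((-dH_dq - dot_p) ** 2)

@jax.jit
def train_step(params, batch, lr=0.1):
    # Loss and gradients w.r.t. the parameters, then one gradient-descent update
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda w, dw: w - lr * dw, params, grads)
    return params, loss

key_q, key_p = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.uniform(key_q, (256, 1), minval=-0.5, maxval=0.5)
p = jax.random.uniform(key_p, (256, 1), minval=-2.0, maxval=2.0)
batch = (q, p, p, -g * q)  # exact derivatives: dq/dt = p, dp/dt = -g q

params = {'a': jnp.array(0.1), 'b': jnp.array(1.0)}
for step in range(300):
    params, loss = train_step(params, batch)
print(params, loss)  # a should approach 0.5 and b should approach g / 2
```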
Next, we define the evaluation step. This function is similar to the training step, but it does not update the model parameters. Instead, it simply computes the loss on a batch of validation or test data, allowing us to track the model’s performance during training.
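```python
@nnx.jit
def eval_step(model, metrics, batch):
    model.eval()  # Evaluation mode; has no effect here, as the model has no dropout or batch norm
    loss = loss_fn(model, batch)
    metrics.update(values=loss)  # Accumulate the batch loss into the running average
```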
To train efficiently on large datasets, we split our data into mini-batches. The batch_iterator function handles this: it optionally shuffles the data and yields batches of positions, momenta, and their time derivatives. Mini-batching keeps each update cheap, and the gradient noise it introduces often helps the optimizer find solutions that generalize well.
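This section does not show the batch_iterator code. The sketch below is a minimal NumPy version that matches the call batch_iterator(z, dot_z, batch_size, shuffle=...) used in train. The data layout is an assumption: z = (q, p) and dot_z = (dot_q, dot_p), each an array of shape (N, d). The optional rng argument is a hypothetical addition for reproducibility.

```python
import numpy as np

def batch_iterator(z, dot_z, batch_size, shuffle=False, rng=None):
    """Yield (q, p, dot_q, dot_p) mini-batches.

    Assumed (hypothetical) layout: z = (q, p) and dot_z = (dot_q, dot_p),
    each an array of shape (N, d) with one row per sampled state.
    """
    q, p = z
    dot_q, dot_p = dot_z
    n = q.shape[0]
    if shuffle:
        rng = np.random.default_rng() if rng is None else rng
        idx = rng.permutation(n)  # a new order on every call, i.e. every epoch
    else:
        idx = np.arange(n)
    for start in range(0, n, batch_size):
        sel = idx[start:start + batch_size]  # the last batch may be smaller
        yield q[sel], p[sel], dot_q[sel], dot_p[sel]

# Tiny usage example with 10 states of dimension 1
q = np.arange(10.0).reshape(-1, 1)
batches = list(batch_iterator((q, 2 * q), (3 * q, 4 * q), batch_size=4,
                              shuffle=True, rng=np.random.default_rng(0)))
print([b[0].shape for b in batches])  # [(4, 1), (4, 1), (2, 1)]
```

simulate_experiment returns arrays of shape (n_trajectories, n_points, 1). Arrays taken directly from it would first need flattening, for example with q.reshape(-1, 1). A smaller final batch triggers one extra trace of a jitted step function, which is usually harmless.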
During training, it is helpful to keep track of progress and spot any issues early. The print_logs function prints out the current epoch and the latest values of the tracked metrics at regular intervals, giving us a clear view of how training is proceeding.
```python
def print_logs(epoch, history, period=10):
    # Print the latest value of every tracked metric every `period` epochs
    if epoch % period == 0:
        logs = []
        for k, v in history.items():
            key = k.replace('_', ' ').capitalize()
            value = f'{v[-1]:.4f}'
            logs.append(' = '.join([key, value]))
        print(f'Epoch {epoch + 1}', *logs, sep=', ')
```
Finally, we bring everything together in the main training loop. The train function orchestrates the entire process: it iterates over epochs, performs training and evaluation steps, updates the metrics, and logs progress. At the end, it returns a history of the training and test losses, which can be used to analyze the model’s learning behavior.
```python
def train(
    model, optimizer, z_train, dot_z_train, z_test, dot_z_test, epochs, batch_size=32
):
    history = {'train_loss': [], 'test_loss': []}
    metrics = nnx.MultiMetric(loss=nnx.metrics.Average())

    for epoch in range(epochs):
        # Training
        for batch in batch_iterator(z_train, dot_z_train, batch_size, shuffle=True):
            train_step(model, optimizer, metrics, batch)
        for metric, value in metrics.compute().items():
            history[f'train_{metric}'].append(value)
        metrics.reset()  # Reset the metrics for the test set.

        # Evaluation
        for batch in batch_iterator(z_test, dot_z_test, batch_size, shuffle=False):
            eval_step(model, metrics, batch)
        for metric, value in metrics.compute().items():
            history[f'test_{metric}'].append(value)
        metrics.reset()  # Reset the metrics for the next training epoch.

        print_logs(epoch, history)

    return history
```
Example: Simple pendulum
Let’s apply our HNN to the classic simple pendulum problem. We’ll derive the theoretical foundations and then generate datasets to train and evaluate our model.
Theory
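Consider a simple pendulum of length \(l\) and mass \(m\), with angle \(\theta\) from the vertical. Polar coordinates are a natural choice here. The bob sits at \(x = l\sin\theta\), \(y = -l\cos\theta\), so the single generalized coordinate is \(q = \theta\). The kinetic and potential energies are
\[T = \frac{1}{2}ml^2\dot{\theta}^2, \qquad V = mgl(1 - \cos\theta) \approx \frac{1}{2}mgl\theta^2,\]
where the last step is the small-angle approximation \(\cos\theta \approx 1 - \theta^2/2\). The canonical momentum is \(p = \partial L/\partial\dot{\theta} = ml^2\dot{\theta}\), and the Hamiltonian becomes
\[H(\theta, p) = \frac{p^2}{2ml^2} + \frac{1}{2}mgl\theta^2.\]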
This gives us Hamilton’s equations: \[\dot{\theta} = \frac{\partial H}{\partial p} = \frac{p}{ml^2},\]\[\dot{p} = -\frac{\partial H}{\partial \theta} = -mgl\theta.\]
The equation of motion becomes: \[\ddot{\theta} + \frac{g}{l}\theta = 0.\]
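This is simple harmonic motion with angular frequency \(\omega = \sqrt{{g}/{l}}\). Given initial conditions \(\theta(0) = \theta_0\) and \(\dot{\theta}(0) = \dot{\theta}_0\), the general solution is:
\[\theta(t) = \theta_0\cos(\omega t) + \frac{\dot{\theta}_0}{\omega}\sin(\omega t), \qquad p(t) = ml^2\dot{\theta}(t) = ml^2\left(-\theta_0\omega\sin(\omega t) + \dot{\theta}_0\cos(\omega t)\right).\]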
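The closed-form solution \(\theta(t) = \theta_0\cos(\omega t) + (\dot{\theta}_0/\omega)\sin(\omega t)\) generates all the data below, so it is worth checking that it satisfies Hamilton's equations for the pendulum. The sketch uses JAX's automatic differentiation in \(t\). The initial conditions are arbitrary illustrative values. Since \(p = ml^2\dot{\theta}\) by definition, the first equation holds by construction, and only \(\dot{p} = -mgl\theta\) needs checking.

```python
import jax
import jax.numpy as jnp

m, l, g = 1.0, 1.0, 9.81
omega = jnp.sqrt(g / l)

def theta(t, theta_0, dot_theta_0):
    # Closed-form small-angle solution
    return theta_0 * jnp.cos(omega * t) + (dot_theta_0 / omega) * jnp.sin(omega * t)

theta_dot = jax.grad(theta)        # d(theta)/dt by autodiff (w.r.t. the first argument)
theta_ddot = jax.grad(theta_dot)   # second time derivative

t = jnp.linspace(0.0, 5.0, 101)
theta_0, dot_theta_0 = 0.3, -0.2
q = jax.vmap(theta, in_axes=(0, None, None))(t, theta_0, dot_theta_0)
dp_dt = m * l**2 * jax.vmap(theta_ddot, in_axes=(0, None, None))(t, theta_0, dot_theta_0)

# Residual of Hamilton's second equation, dp/dt = -m g l theta
residual = jnp.max(jnp.abs(dp_dt + m * g * l * q))
print('max |dp/dt + m g l theta|:', residual)
print('theta(0) =', theta(0.0, theta_0, dot_theta_0),
      ', dtheta/dt(0) =', theta_dot(0.0, theta_0, dot_theta_0))
```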
Let’s walk through the implementation step by step, explaining the reasoning and intuition behind each part of the code.
Data
We’ll start by defining the physical parameters for our simple pendulum. These parameters set up the environment in which our pendulum will swing and will be used throughout the simulation and training process.
```python
l = 1.0                  # length (m)
m = 1.0                  # mass (kg)
g = 9.81                 # gravity (m/s^2)
omega = jnp.sqrt(g / l)  # natural angular frequency (rad/s)
```
Next, we define the true Hamiltonian for the pendulum under the small angle approximation. This function serves as our ground truth, allowing us to compare the neural network’s predictions with the exact analytical solution.
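```python
@jax.vmap
def pendulum_hamiltonian(q, p):
    # True Hamiltonian under the small-angle approximation, vectorized over the batch axis
    T = p**2 / (2 * m * l**2)  # kinetic energy
    V = m * g * l * q**2 / 2   # potential energy, using cos(q) ~ 1 - q^2 / 2
    H = T + V
    return H
```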
To generate data for training and testing, we need to simulate the pendulum’s motion. The following function computes the analytical solution for the pendulum’s position, momentum, and their time derivatives, given initial conditions. This is crucial for creating a dataset that reflects the true physics of the system.
```python
def pendulum_kinematics(t, q_0, dot_q_0):
    # Analytical small-angle solution at a single time t
    q = q_0 * jnp.cos(omega * t) + (dot_q_0 / omega) * jnp.sin(omega * t)
    dot_q = -q_0 * omega * jnp.sin(omega * t) + dot_q_0 * jnp.cos(omega * t)
    p = m * l**2 * dot_q
    dot_p = -m * g * l * q
    return q, p, dot_q, dot_p

# Vectorize in t
pendulum_kinematics = jax.vmap(pendulum_kinematics, in_axes=(0, None, None))
```
Now, let’s put everything together in a function that simulates multiple experimental trajectories. This function generates noisy measurements, mimicking real-world data, and computes the corresponding velocities and momenta using finite differences. This dataset will be used to train and evaluate our HNN.
```python
def simulate_experiment(n_trajectories, n_points, duration, noise_std=0.01, seed=123):
    np.random.seed(seed)
    t = np.linspace(0, duration, n_points)
    dt = t[1] - t[0]
    # Random initial angles; the pendulum is released from rest
    q_0 = np.random.uniform(-0.5, 0.5, (n_trajectories, 1))
    dot_q_0 = np.zeros((n_trajectories, 1))
    q = np.empty((n_trajectories, n_points, 1))

    # Generate clean analytical solutions
    for i in range(n_trajectories):
        q[i], _, _, _ = pendulum_kinematics(t, q_0[i], dot_q_0[i])

    # Add noise to position measurements (simulating real experiments)
    q = q + np.random.normal(0, noise_std, q.shape)

    # Compute velocity using finite differences
    dot_q = np.gradient(q, dt, axis=1)

    # Compute momentum and its derivative
    p = m * l**2 * dot_q
    dot_p = np.gradient(p, dt, axis=1)

    return t, (q, p), (dot_q, dot_p)
```
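In simulate_experiment, only \(q\) is observed with noise, so both derivative targets come from np.gradient, and this amplifies the noise. A central difference divides by \(2\Delta t\), and dot_p applies it twice. The sketch below quantifies this for pure Gaussian noise. It uses \(\sigma = 0.01\) (the function's default noise_std) and an illustrative \(\Delta t = 0.05\); that time step is an assumption, since the dataset's n_points and duration are not given here.

```python
import numpy as np

sigma, dt, n = 0.01, 0.05, 100_000  # noise level, illustrative time step, sample count
rng = np.random.default_rng(0)
noise = rng.normal(0.0, sigma, n)   # pure measurement noise on q (no signal)

d1 = np.gradient(noise, dt)  # noise in the estimated dq/dt
d2 = np.gradient(d1, dt)     # noise in the estimated d^2q/dt^2, i.e. dot_p / (m l^2)

# Interior points use central differences:
#   d1_i = (x_{i+1} - x_{i-1}) / (2 dt)            -> std = sigma / (sqrt(2) dt)
#   d2_i = (x_{i+2} - 2 x_i + x_{i-2}) / (4 dt^2)  -> std = sqrt(6) sigma / (4 dt^2)
predicted = (sigma / (np.sqrt(2) * dt), np.sqrt(6) * sigma / (4 * dt**2))
measured = (d1[1:-1].std(), d2[2:-2].std())  # skip the one-sided edge points
print('predicted std:', predicted)
print('measured std: ', measured)
```

Whatever the actual \(\Delta t\), the noise in dot_p scales as \(\sigma/\Delta t^2\). It therefore sets a floor that a squared-error loss cannot go below, even for the exact Hamiltonian.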
With our data generation pipeline ready, we can now create training and test datasets. This step ensures that our model will be evaluated on data it has not seen during training, providing a fair assessment of its generalization ability.
With the data in place, it is time to instantiate our Hamiltonian Neural Network and set up the optimizer. The model is an MLP with three hidden layers of 64 tanh units, and we train it with the Adam optimizer at a learning rate of \(10^{-3}\).
```python
model = HNN(system_dim=1, hidden_features=[64, 64, 64], rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
```
Now, let’s train the model! The following code runs the training loop, logging the loss at each epoch so we can monitor progress. This is where the HNN learns to approximate the true Hamiltonian from data.
```python
history = train(
    model, optimizer, z_train, dot_z_train, z_test, dot_z_test, epochs=100
)
```

```text
Epoch 1, Train loss = 2.6282, Test loss = 2.3213
Epoch 11, Train loss = 0.3642, Test loss = 0.3542
Epoch 21, Train loss = 0.3473, Test loss = 0.3672
Epoch 31, Train loss = 0.3585, Test loss = 0.3399
Epoch 41, Train loss = 0.3587, Test loss = 0.3422
Epoch 51, Train loss = 0.3474, Test loss = 0.3432
Epoch 61, Train loss = 0.3686, Test loss = 0.3602
Epoch 71, Train loss = 0.3692, Test loss = 0.3410
Epoch 81, Train loss = 0.3557, Test loss = 0.3433
Epoch 91, Train loss = 0.3435, Test loss = 0.3390
```
To visualize the training process, we plot the loss curves for both the training and test sets. A decreasing loss indicates that the model is successfully learning the underlying dynamics of the system.
As the training history shows, both losses drop sharply within the first ten epochs, from about 2.6 to about 0.35, and then stay roughly level between 0.34 and 0.37 for the rest of training. The training and test curves remain close to each other, which indicates that the model is not overfitting the training trajectories. The plateau itself is not a sign of failure. The targets \(\dot{q}\) and \(\dot{p}\) are finite-difference estimates computed from noisy positions, so even the exact Hamiltonian would not drive the loss to zero. Whether the network has captured physically meaningful structure is better judged by the evaluations that follow.
Evaluation
With a trained model in hand, let’s see how well it has learned the physics of the pendulum. We’ll use a series of visualizations and analyses to assess its performance.
Visualizing the Learned Hamiltonian
First, we will visualize the Hamiltonian surface learned by our neural network and compare it to the true analytical Hamiltonian. This gives us a direct look at what the model has captured about the system’s energy landscape.
Code
```python
# Create a grid for visualization
q_grid = jnp.linspace(-0.6, 0.6, 50)
p_grid = jnp.linspace(-2.0, 2.0, 50)
Q, P = jnp.meshgrid(q_grid, p_grid)

# Flatten for batch evaluation
q = Q.reshape(-1, 1)
p = P.reshape(-1, 1)

# Compute Hamiltonians
H_pred = model(q, p).reshape(Q.shape)
H_true = pendulum_hamiltonian(q, p).reshape(Q.shape)

# Place the minimum energy at 0 in both cases to compare
# (a Hamiltonian is only defined up to an additive constant)
H_pred = H_pred - np.min(H_pred)
H_true = H_true - np.min(H_true)
abs_error = jnp.abs(H_pred - H_true)

# Create 2D contour plots
fig, axes = plt.subplots(1, 3, figsize=(9, 2.5), sharex=True, sharey=True)

# Learned Hamiltonian plot
cf1 = axes[0].contourf(Q, P, H_pred, cmap='viridis')
axes[0].set_xlabel(r'$q$ ($\mathrm{rad}$)')
axes[0].set_ylabel(r'$p$ ($\mathrm{kg\cdot m^2/s}$)')
axes[0].set_title('Predicted Hamiltonian')
fig.colorbar(cf1, ax=axes[0])

# True Hamiltonian plot
cf2 = axes[1].contourf(Q, P, H_true, cmap='viridis')
axes[1].set_xlabel(r'$q$ ($\mathrm{rad}$)')
axes[1].set_title('True Hamiltonian')
fig.colorbar(cf2, ax=axes[1])

# Absolute error plot
cf3 = axes[2].contourf(Q, P, abs_error, cmap='magma')
axes[2].set_xlabel(r'$q$ ($\mathrm{rad}$)')
axes[2].set_title('Absolute Error')
fig.colorbar(cf3, ax=axes[2])

plt.show()
```
The comparison between the predicted and true Hamiltonian surfaces shows that the HNN has learned an energy landscape that closely matches the analytical one across the plotted region. The contour lines of the predicted Hamiltonian align well with those of the true Hamiltonian, so the model has captured both the shape and the scale of the landscape. The absolute error plot confirms this: the error stays small over most of the domain, and the largest discrepancies appear near the boundaries. The plotted grid (\(|q| \le 0.6\), \(|p| \le 2\)) extends beyond the states reached by the training trajectories (\(|q_0| \le 0.5\) rad released from rest, hence \(|p| \lesssim 1.6\) for the noise-free motion). This suggests that the HNN is learning the underlying structure of the system rather than only fitting the training points.
Energy Conservation Analysis
One of the main advantages of HNNs is their ability to conserve energy over long time horizons. Let’s test this by integrating the learned dynamics and comparing the energy along the predicted and true trajectories. If the HNN has truly learned the physics, the energy should remain nearly constant over time.
Code
```python
q_0 = np.random.uniform(-0.5, 0.5, (1,))
dot_q_0 = p_0 = np.zeros((1,))
duration = 10.0  # Long integration to test stability
dt = 0.01
t = np.arange(0, duration, dt)

# Compute the predicted trajectory
t, q_pred, p_pred = model.solve_ivp(q_0, p_0, (0, duration), t_eval=t)

# Also get analytical solution for comparison
q_true, p_true, _, _ = pendulum_kinematics(t, q_0, dot_q_0)

# Compute energy along trajectory
energy_pred_for_trajectory_pred = model(q_pred, p_pred)
energy_pred_for_trajectory_true = model(q_true, p_true)
energy_true_for_trajectory_pred = pendulum_hamiltonian(q_pred, p_pred)
energy_true_for_trajectory_true = pendulum_hamiltonian(q_true, p_true)

fig, ax = plt.subplots(figsize=(9, 4))
results = {
    r'$H_{\mathbf{\theta}}(q_{\text{pred}}, p_{\text{pred}})$': energy_pred_for_trajectory_pred,
    r'$H_{\mathbf{\theta}}(q_{\text{true}}, p_{\text{true}})$': energy_pred_for_trajectory_true,
    r'$H(q_{\text{pred}}, p_{\text{pred}})$': energy_true_for_trajectory_pred,
    r'$H(q_{\text{true}}, p_{\text{true}})$': energy_true_for_trajectory_true,
}
for label, energy in results.items():
    ax.plot(t, energy, label=label)
ax.set_xlabel(r'Time ($\mathrm{s}$)')
ax.set_ylabel(r'Energy ($\mathrm{J}$)')
ax.legend()
plt.show()
```
The energy conservation plot demonstrates one of the key strengths of Hamiltonian Neural Networks. All four energy curves remain nearly flat over the 10 s integration window, which is about five oscillation periods since \(2\pi/\omega \approx 2.0\) s. This holds whether the energy is computed from the predicted or the true trajectory, and with either the learned or the analytical Hamiltonian. Two of these curves are especially informative. \(H(q_{\text{pred}}, p_{\text{pred}})\) staying flat means the predicted trajectory conserves the true energy. \(H_{\vb*{\theta}}(q_{\text{true}}, p_{\text{true}})\) staying flat means the learned Hamiltonian is conserved along the true motion. Together, they indicate that the model has captured the physics of the system rather than just fitting the data. This kind of stability is difficult to achieve with a standard neural network that predicts the vector field directly.
The energy curves in the plot above are flat, but the curves computed with the learned Hamiltonian \(H_{\theta}\) are offset vertically from those computed with the true Hamiltonian \(H\). This is not a bug or a sign of model error; it is a fundamental property of Hamiltonian mechanics. The Hamiltonian is only defined up to an additive constant: adding a constant to \(H\) leaves its gradients unchanged, and with them Hamilton's equations and the dynamics. The HNN is trained only through the gradients of \(H_{\theta}\), so nothing in the loss fixes that constant, and the network is free to learn a Hamiltonian that differs from the true one by an offset. What matters is the shape of the energy landscape and the conservation of energy, not the absolute value. This is why the curves sit at different vertical levels but remain flat and parallel over time.
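The additive-constant argument is easy to check directly. The sketch below (plain JAX) evaluates the vector field of the small-angle pendulum Hamiltonian, and of the same Hamiltonian shifted by an arbitrary constant of 3.7, at one illustrative state. `hamiltonian_field` is a hypothetical helper name.

```python
import jax

m, l, g = 1.0, 1.0, 9.81

def H(q, p):
    return p**2 / (2 * m * l**2) + 0.5 * m * g * l * q**2

def H_shifted(q, p):
    return H(q, p) + 3.7  # arbitrary constant offset

def hamiltonian_field(hamiltonian, q, p):
    # (dq/dt, dp/dt) = (dH/dp, -dH/dq)
    dH_dq, dH_dp = jax.grad(hamiltonian, argnums=(0, 1))(q, p)
    return dH_dp, -dH_dq

q, p = 0.3, -0.8
print(hamiltonian_field(H, q, p))
print(hamiltonian_field(H_shifted, q, p))  # same vector field
print(H(q, p), H_shifted(q, p))            # energies differ by the offset
```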
Dynamics Comparison
To further evaluate our HNN, we will compare the predicted and true trajectories for several different initial conditions. This will show us how well the model generalizes across the phase space, both in terms of position, momentum, and the overall shape of the trajectories.
Code
```python
from matplotlib.lines import Line2D

n_tests = 5
q_0 = np.linspace(0.1, 0.5, n_tests).reshape(-1, 1)
dot_q_0 = p_0 = np.zeros((n_tests, 1))
duration = 10.0  # Long integration to test stability
dt = 0.01
t = np.arange(0, duration, dt)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

fig, axes = plt.subplots(3, 1, figsize=(4, 12))
for i in range(n_tests):
    # Integrate with HNN
    t, q_pred, p_pred = model.solve_ivp(q_0[i], p_0[i], (0, duration), t_eval=t)
    # Get analytical solution
    q_true, p_true, _, _ = pendulum_kinematics(t, q_0[i], dot_q_0[i])
    # Plot trajectories (dotted: HNN, solid: analytical)
    axes[0].plot(t, q_pred.squeeze(), linestyle='dotted', color=colors[i])
    axes[0].plot(t, q_true.squeeze(), color=colors[i])
    axes[1].plot(t, p_pred.squeeze(), linestyle='dotted', color=colors[i])
    axes[1].plot(t, p_true.squeeze(), color=colors[i])
    # Phase space
    axes[2].plot(q_pred.squeeze(), p_pred.squeeze(), linestyle='dotted', color=colors[i])
    axes[2].plot(q_true.squeeze(), p_true.squeeze(), color=colors[i])

axes[0].set_xlabel(r'Time ($\mathrm{s}$)')
axes[0].set_ylabel(r'$q$($\mathrm{rad}$)')
axes[0].set_title('Position Evolution')
axes[1].set_xlabel(r'Time ($\mathrm{s}$)')
axes[1].set_ylabel(r'$p$($\mathrm{kg\cdot m^2/s}$)')
axes[1].set_title('Momentum Evolution')
axes[2].set_xlabel(r'$q$($\mathrm{rad}$)')
axes[2].set_ylabel(r'$p$($\mathrm{kg\cdot m^2/s}$)')
axes[2].set_title('Phase Space Trajectories')

# Custom legend elements
legend_elements = [
    Line2D([0], [0], color='grey', lw=2, linestyle='-', label='True'),
    Line2D([0], [0], color='grey', lw=2, linestyle='dotted', label='Predicted'),
]
# Shared legend placed below the bottom subplot
axes[2].legend(
    handles=legend_elements, loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=2
)

# Adjust layout to make space for legend
plt.tight_layout()
plt.show()
```
The plots above give a direct visual comparison between the trajectories predicted by the Hamiltonian Neural Network (dotted lines) and the true analytical solutions (solid lines) for several initial conditions. The top and middle panels show the evolution of position and momentum over time. The close overlap between the dotted and solid curves indicates that the HNN reproduces the true dynamics of the system across a range of initial states, not just for a single trajectory.
The main difference to notice is in the period of the \(q\) and \(p\) curves: the predicted trajectories oscillate with a slightly different frequency than the analytical ones. This suggests that the curvature of the learned Hamiltonian is slightly off. For the quadratic Hamiltonian \(H = p^2/(2ml^2) + \frac{1}{2}mglq^2\), the angular frequency satisfies \(\omega^2 = \frac{\partial^2 H}{\partial q^2}\frac{\partial^2 H}{\partial p^2} = g/l\). A small error in either curvature therefore acts like a slightly different effective pendulum length \(l\), even though the network never learns \(l\) explicitly. Despite this small discrepancy in period, the amplitudes of both \(q\) and \(p\) are correctly captured, so the model reproduces the range of motion of each trajectory.
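For one degree of freedom, the small-oscillation frequency of a Hamiltonian around its minimum can be read off its Hessian. Linearizing Hamilton's equations there gives eigenvalues \(\pm i\sqrt{\det \nabla^2 H}\), so \(\omega = \sqrt{\det \nabla^2 H}\), which reduces to \(\sqrt{g/l}\) for the pendulum. The sketch below (plain JAX; small_oscillation_frequency is a hypothetical helper) checks this. It also shows that rescaling \(H\) by a factor rescales \(\omega\) by the same factor without changing the shape of the orbits. This is one simple way a learned Hamiltonian can get the period wrong while still matching the amplitudes.

```python
import jax
import jax.numpy as jnp

m, l, g = 1.0, 1.0, 9.81

def pendulum_H(z):
    # z = (q, p); small-angle pendulum Hamiltonian
    q, p = z[0], z[1]
    return p**2 / (2 * m * l**2) + 0.5 * m * g * l * q**2

def small_oscillation_frequency(H, z_eq=jnp.zeros(2)):
    # Linearizing Hamilton's equations at a minimum z_eq gives eigenvalues
    # +/- i sqrt(det Hess H), so omega = sqrt(det Hess H) for one degree of freedom
    hess = jax.hessian(H)(z_eq)
    return jnp.sqrt(jnp.linalg.det(hess))

omega_true = small_oscillation_frequency(pendulum_H)
omega_scaled = small_oscillation_frequency(lambda z: 1.02 * pendulum_H(z))
print(omega_true, jnp.sqrt(g / l))  # the two values agree
print(omega_scaled / omega_true)    # scaling H by 1.02 scales omega by 1.02
```

The trained model could be probed the same way by passing lambda z: model(z[:1], z[1:]) and evaluating at the learned minimum, which lies near, but not necessarily at, the origin. This variant is untested.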
The bottom panel shows the phase-space trajectories, where each orbit represents the evolution of position and momentum for a given initial condition. Here, the predicted (dotted) and true (solid) orbits are nearly indistinguishable. A small frequency error shifts where along the orbit the state is at a given time, but not the shape of the orbit itself. The HNN has therefore learned the correct geometry of the motion, and the agreement across initial conditions suggests that it generalizes well throughout the explored region of phase space. In summary, the HNN is not simply memorizing the training data; it has learned a good approximation of the physical law governing the pendulum's motion.
Vector Field Visualization
Finally, let’s visualize the learned vector field of the HNN and compare it to the true vector field. This gives us a qualitative sense of how well the model has captured the underlying dynamics everywhere in phase space, not just along the training trajectories.
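The code for the vector-field figure is not included in this section. The sketch below shows one way to compute the true Hamiltonian vector field on a grid with JAX, using the same range as the Hamiltonian plot. vector_field is a hypothetical helper name, and the resulting arrays can be passed to Matplotlib's quiver.

```python
import jax
import jax.numpy as jnp

m, l, g = 1.0, 1.0, 9.81

def pendulum_H(q, p):
    # Small-angle Hamiltonian for a single state (scalars)
    return p**2 / (2 * m * l**2) + 0.5 * m * g * l * q**2

def vector_field(hamiltonian, Q, P):
    """Evaluate (dq/dt, dp/dt) = (dH/dp, -dH/dq) on a meshgrid."""
    grad_H = jax.vmap(jax.grad(hamiltonian, argnums=(0, 1)))
    dH_dq, dH_dp = grad_H(Q.ravel(), P.ravel())
    return dH_dp.reshape(Q.shape), -dH_dq.reshape(Q.shape)

Q, P = jnp.meshgrid(jnp.linspace(-0.6, 0.6, 15), jnp.linspace(-2.0, 2.0, 15))
U, V = vector_field(pendulum_H, Q, P)
# plt.quiver(Q, P, U, V) would draw the arrows
```

For the learned field, a wrapper such as lambda q, p: model(q, p) could replace pendulum_H, since HNN.__call__ promotes scalar inputs with jnp.atleast_1d. This variant is untested.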
The vector field visualizations above provide a qualitative assessment of how well the Hamiltonian Neural Network has learned the underlying dynamics of the system. On the left, the predicted vector field generated by the HNN is shown, while the right panel displays the true analytical vector field. Each arrow represents the direction and magnitude of motion in phase space for a given position and momentum.
The close agreement between the predicted and true vector fields demonstrates that the HNN has not only learned to reproduce individual trajectories, but has also captured the global structure of the system’s dynamics throughout the entire phase space. The main differences are found at the boundaries of the plotted region, where the predicted vector field may deviate slightly from the true field. This is expected, as neural networks often generalize less accurately in regions with less training data. Nevertheless, the model faithfully represents the flow of the system in the central region, indicating that it has internalized the correct physical laws governing the pendulum’s motion.
Key Insights
When are Hamiltonian Neural Networks (HNNs) especially useful?
Long-term stability: HNNs shine when you need your model to respect energy conservation over long time horizons. This is crucial for simulating physical systems where even tiny errors can add up and cause unrealistic results.
Limited data: Because HNNs encode physical laws directly into their structure, they can often learn accurate dynamics from less data than a standard neural network would require.
Physical consistency: If you care about interpretability and want your model’s predictions to make physical sense (e.g., obeying conservation laws), HNNs are a great choice.
But HNNs are not a silver bullet. Some limitations to keep in mind:
Conservative systems only: HNNs assume the system doesn’t lose energy (no friction or dissipation). They’re not designed for systems where energy leaks away.
Need phase space: You must be able to describe your system in terms of positions and momenta, \((\vb{q}, \vb{p})\). This isn’t always straightforward for every problem.
Training complexity: The loss function involves gradients of the neural network output, which can make training trickier and sometimes slower than standard approaches.
Tips for successful HNN implementation:
Choose your coordinates wisely: Pick generalized coordinates \(\vb{q}\) that naturally fit your system. The right choice can make learning much easier.
Use symplectic integrators: When integrating trajectories, symplectic methods help preserve the geometric structure of phase space, leading to more accurate long-term predictions.
Preprocess your data: Make sure momentum and position variables are properly scaled. Good preprocessing can make a big difference in training stability and performance.
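The tip on symplectic integrators can be illustrated by comparing explicit Euler with the Störmer-Verlet (leapfrog) scheme on the small-angle pendulum. The sketch below is written in plain JAX, and its helper names are hypothetical. Both integrators use the same step size of 0.01 s over 100 s, and the energy is tracked after every step.

```python
import jax
import jax.numpy as jnp

m, l, g = 1.0, 1.0, 9.81  # small-angle pendulum, separable H = T(p) + V(q)

def kinetic(p):
    return jnp.sum(p**2) / (2 * m * l**2)

def potential(q):
    return 0.5 * m * g * l * jnp.sum(q**2)

dT = jax.grad(kinetic)    # dH/dp
dV = jax.grad(potential)  # dH/dq

def leapfrog_step(q, p, dt):
    # Stormer-Verlet: half kick, drift, half kick; symplectic for separable H
    p = p - 0.5 * dt * dV(q)
    q = q + dt * dT(p)
    p = p - 0.5 * dt * dV(q)
    return q, p

def euler_step(q, p, dt):
    # Explicit Euler: not symplectic
    return q + dt * dT(p), p - dt * dV(q)

def rollout(step, q, p, dt, n_steps):
    # Returns the total energy after every step
    def body(carry, _):
        q, p = step(*carry, dt)
        return (q, p), kinetic(p) + potential(q)
    _, energies = jax.lax.scan(body, (q, p), None, length=n_steps)
    return energies

q0, p0 = jnp.array([0.5]), jnp.array([0.0])
E0 = float(kinetic(p0) + potential(q0))
e_leap = rollout(leapfrog_step, q0, p0, 0.01, 10_000)  # 100 s, about 50 periods
e_euler = rollout(euler_step, q0, p0, 0.01, 10_000)
print('leapfrog: max relative energy error', float(jnp.max(jnp.abs(e_leap - E0))) / E0)
print('euler:    final / initial energy   ', float(e_euler[-1]) / E0)
```

In exact arithmetic, each explicit Euler step on this oscillator multiplies the energy by \(1 + \omega^2\Delta t^2\), so the drift is exponential, while the leapfrog energy error stays bounded. This explicit leapfrog form assumes a separable Hamiltonian. The MLP \(H_{\vb*{\theta}}\) in this tutorial is not separable in general, so applying the tip to it would call for an implicit symplectic scheme (such as the implicit midpoint rule) or a separable architecture. SciPy's solve_ivp, used in HNN.solve_ivp, defaults to RK45, which is not symplectic.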
Conclusion
Hamiltonian Neural Networks beautifully demonstrate how classical physics and modern machine learning can be unified. By parameterizing the Hamiltonian—the fundamental quantity that generates dynamics—rather than the dynamics themselves, HNNs automatically incorporate conservation laws and symplectic structure.
This is not just mathematically elegant; it is also practically useful. In the pendulum example, the HNN generalized beyond its training trajectories and kept the energy essentially constant over the integration window. More generally, the Hamiltonian structure provides a long-horizon consistency that an unconstrained network lacks. HNNs represent a paradigm where we don't just use neural networks to approximate physical systems, but embed physical principles directly into the learning architecture.
The journey from Newton’s \(\vb{F} = m\vb{a}\) to Hamiltonian Neural Networks shows how different formulations of the same physics can lead to dramatically different computational approaches. As we continue to develop AI systems that interact with the physical world, this marriage of physics and learning will become increasingly important.
The future of scientific computing lies not in replacing physical knowledge with pure data-driven approaches, but in finding clever ways to combine the best of both worlds. Hamiltonian Neural Networks point the way forward.
References
Goldstein, Herbert, Charles P. Poole, and John L. Safko. 2008. Classical Mechanics. 3rd ed., reprint. San Francisco; Munich: Addison Wesley.