FlashSchNet: Fast and Accurate Coarse-Grained Neural Network Molecular Dynamics

2026-02-13

Machine Learning; Computational Engineering, Finance, and Science
AI summary

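The authors improved SchNet, a type of AI model used to simulate how molecules move, making it much faster and more memory-efficient on GPUs. They did this in three ways. First, they carefully managed how data moves between the GPU's large main memory and its small, fast on-chip memory. Second, they combined multiple steps into single passes, so intermediate results are not repeatedly written out and read back. Third, they compressed model weights to 16-bit precision channel by channel, which speeds things up with negligible loss of accuracy. Their new system, FlashSchNet, runs simulations 6.5 times faster than the original SchNet-based model (CGSchNet) while using 80% less peak memory. It is also faster than traditional force fields such as MARTINI, and it keeps SchNet-level accuracy. This means researchers can run molecular dynamics simulations more efficiently on GPUs.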

Graph Neural Networks, Molecular Dynamics, SchNet, GPU Memory, Radial Basis Functions, Message Passing, Quantization, High-Bandwidth Memory, Protein Coarse-Graining, Simulation Throughput
Authors
Pingzhi Li, Hongxuan Li, Zirui Liu, Xingcheng Lin, Tianlong Chen
Abstract
Graph neural network (GNN) potentials such as SchNet improve the accuracy and transferability of molecular dynamics (MD) simulation by learning many-body interactions. However, they remain slower than classical force fields because of fragmented kernels and memory-bound pipelines that underutilize GPUs. We show that a missing principle is making GNN-MD IO-aware, that is, carefully accounting for reads and writes between GPU high-bandwidth memory (HBM) and on-chip SRAM. We present FlashSchNet, an efficient and accurate IO-aware SchNet-style GNN-MD framework built on four techniques: (1) flash radial basis, which fuses pairwise distance computation, Gaussian basis expansion, and the cosine envelope into a single tiled pass, computing each distance once and reusing it across all basis functions; (2) flash message passing, which fuses the cutoff, neighbor gather, filter multiplication, and reduction to avoid materializing edge tensors in HBM; (3) flash aggregation, which reformulates scatter-add as a CSR segment reduce, reducing atomic writes by a factor equal to the feature dimension and enabling contention-free accumulation in both the forward and backward passes; and (4) channel-wise 16-bit quantization, which exploits the low per-channel dynamic range of SchNet MLP weights to further improve throughput with negligible accuracy loss. On a single NVIDIA RTX PRO 6000, FlashSchNet achieves an aggregate simulation throughput of 1000 ns/day over 64 parallel replicas of a coarse-grained (CG) protein containing 269 beads. This is 6.5× faster than the CGSchNet baseline, with an 80% reduction in peak memory. FlashSchNet surpasses classical force fields (e.g., MARTINI) while retaining SchNet-level accuracy and transferability.
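The flash radial basis idea (technique 1) can be illustrated with a minimal pure-Python sketch. It is not the authors' kernel. It assumes the usual SchNet Gaussian expansion exp(-γ(d - μ_k)²) and a cosine envelope of the form 0.5(cos(πd/r_c) + 1) inside the cutoff. The names `flash_radial_basis` and `unfused_radial_basis` and the `tile` parameter are hypothetical, and Python loops stand in for a GPU kernel. The sketch is about data flow. Each distance is computed once per edge and reused across all K basis functions. The unfused version instead writes distances and raw basis values to separate arrays and then reads them back.

```python
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def cosine_envelope(d: float, r_cut: float) -> float:
    """Assumed envelope: 0.5 * (cos(pi * d / r_cut) + 1) inside the cutoff, 0 outside."""
    if d >= r_cut:
        return 0.0
    return 0.5 * (math.cos(math.pi * d / r_cut) + 1.0)


def flash_radial_basis(pos: Sequence[Vec3], edges: Sequence[Tuple[int, int]],
                       centers: Sequence[float], gamma: float, r_cut: float,
                       tile: int = 4) -> List[List[float]]:
    """Fused distance -> Gaussian basis -> envelope, one tile of edges at a time.

    Each pairwise distance is computed exactly once and reused for every basis
    function; no standalone distance array or un-enveloped basis array is built.
    """
    out: List[List[float]] = [[0.0] * len(centers) for _ in edges]
    for start in range(0, len(edges), tile):              # one tile of edges per step
        for e in range(start, min(start + tile, len(edges))):
            i, j = edges[e]
            d = math.dist(pos[i], pos[j])                  # distance: computed once
            env = cosine_envelope(d, r_cut)                # envelope: shared by all K bases
            out[e] = [math.exp(-gamma * (d - mu) ** 2) * env for mu in centers]
    return out


def unfused_radial_basis(pos: Sequence[Vec3], edges: Sequence[Tuple[int, int]],
                         centers: Sequence[float], gamma: float,
                         r_cut: float) -> List[List[float]]:
    """Reference version: three passes, each materializing an intermediate array."""
    dists = [math.dist(pos[i], pos[j]) for i, j in edges]                 # pass 1
    basis = [[math.exp(-gamma * (d - mu) ** 2) for mu in centers]
             for d in dists]                                              # pass 2
    return [[v * cosine_envelope(d, r_cut) for v in row]                 # pass 3
            for d, row in zip(dists, basis)]


# Usage: three beads, four directed edges, five Gaussian centers.
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
edges = [(0, 1), (1, 0), (0, 2), (2, 0)]
centers = [0.0, 0.5, 1.0, 1.5, 2.0]
rbf = flash_radial_basis(pos, edges, centers, gamma=10.0, r_cut=2.5)
```

In a real GPU kernel, the tile of bead positions and the per-edge distance would stay in SRAM or registers. The sketch mirrors only the loop structure, not the memory hierarchy.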
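Flash message passing (technique 2) can be sketched by contrasting two versions of a SchNet-style continuous-filter convolution. The unfused version stores edge-length filter, cutoff, gathered-feature, and message tensors. The fused version applies the cutoff, gathers the neighbor feature, multiplies by the filter, and accumulates into the destination node in one loop, without storing any per-edge tensor.

This is a hedged sketch under several assumptions:
- The filter network is simplified to a single linear layer, whereas SchNet uses a small MLP.
- Distances and RBF rows are passed in as inputs.
- The cosine cutoff form is assumed.
- All function names are hypothetical.

```python
import math
import random
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]  # (i, j): message flows from source j into destination i


def cosine_cutoff(d: float, r_cut: float) -> float:
    # Assumed SchNet-style cosine cutoff; exactly zero at and beyond r_cut.
    return 0.5 * (math.cos(math.pi * d / r_cut) + 1.0) if d < r_cut else 0.0


def linear(W: List[List[float]], b: List[float], v: Sequence[float]) -> List[float]:
    # y = W v + b with W of shape (F, K); a simplified stand-in for the filter MLP.
    return [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]


def unfused_cfconv(x, edges: Sequence[Edge], rbf, dist, W, b, r_cut):
    """Reference: every stage materializes an edge-sized intermediate."""
    filt = [linear(W, b, r) for r in rbf]                          # E x F filters
    cut = [cosine_cutoff(d, r_cut) for d in dist]                  # E cutoff values
    gathered = [x[j] for _, j in edges]                            # E x F gathered features
    msgs = [[g * f * c for g, f in zip(gr, fr)]
            for gr, fr, c in zip(gathered, filt, cut)]            # E x F messages
    out = [[0.0] * len(x[0]) for _ in x]
    for (i, _), m in zip(edges, msgs):                             # scatter-add
        for f, v in enumerate(m):
            out[i][f] += v
    return out


def flash_cfconv(x, edges: Sequence[Edge], rbf, dist, W, b, r_cut):
    """Fused: cutoff, gather, filter multiply and reduction in one loop over edges.

    Only an F-length temporary exists per edge (registers/SRAM in a GPU kernel);
    no E x F filter, gathered-feature or message tensor is stored.
    """
    out = [[0.0] * len(x[0]) for _ in x]
    for e, (i, j) in enumerate(edges):
        c = cosine_cutoff(dist[e], r_cut)
        if c == 0.0:                  # outside the cutoff: skip filter and gather entirely
            continue
        filt = linear(W, b, rbf[e])   # transient filter row for this edge only
        xj, acc = x[j], out[i]
        for f in range(len(acc)):
            acc[f] += xj[f] * filt[f] * c
    return out


# Usage: 4 nodes, all directed pairs, random features/filters (seeded for reproducibility).
random.seed(0)
N, F, K = 4, 3, 5
x = [[random.uniform(-1, 1) for _ in range(F)] for _ in range(N)]
edges = [(i, j) for i in range(N) for j in range(N) if i != j]
dist = [random.uniform(0.5, 3.0) for _ in edges]
rbf = [[random.random() for _ in range(K)] for _ in edges]
W = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(F)]
b = [0.1] * F
y_fused = flash_cfconv(x, edges, rbf, dist, W, b, r_cut=2.0)
y_ref = unfused_cfconv(x, edges, rbf, dist, W, b, r_cut=2.0)
```

The fused loop still reduces by adding directly into `out[i]`, which on a GPU would mean atomic updates. Replacing that reduction with a contention-free scheme is the role of technique 3 (flash aggregation).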
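Flash aggregation (technique 3) replaces scatter-add with a CSR segment reduce. The following pure-Python sketch shows the general idea, not the paper's kernel. Edges are grouped by destination node using a row-pointer array. Each node's edges are then summed by a single owner into a local accumulator, which is written out once. The helper names are hypothetical.

The write counts compare the two approaches. Scatter-add performs E·F updates to shared memory, which would be atomic on a GPU. Segment reduce performs N·F plain stores. These counts illustrate that segment reduce needs no atomics, but they do not reproduce the paper's exact "factor of feature dimension" accounting.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def build_csr(dst: Sequence[int], num_nodes: int) -> Tuple[List[int], List[int]]:
    """Group edges by destination: edges perm[row_ptr[i]:row_ptr[i + 1]] all target node i."""
    counts = [0] * num_nodes
    for i in dst:
        counts[i] += 1
    row_ptr = [0] * (num_nodes + 1)
    for i in range(num_nodes):
        row_ptr[i + 1] = row_ptr[i] + counts[i]
    perm = sorted(range(len(dst)), key=lambda e: dst[e])   # stable sort by destination
    return row_ptr, perm


def scatter_add(msgs: Matrix, dst: Sequence[int], num_nodes: int,
                num_features: int) -> Tuple[Matrix, int]:
    """Baseline: every (edge, feature) pair updates shared output memory."""
    out = [[0.0] * num_features for _ in range(num_nodes)]
    updates = 0
    for e, i in enumerate(dst):
        for f in range(num_features):
            out[i][f] += msgs[e][f]    # on a GPU this would be an atomicAdd
            updates += 1
    return out, updates


def segment_reduce(msgs: Matrix, row_ptr: Sequence[int], perm: Sequence[int],
                   num_features: int) -> Tuple[Matrix, int]:
    """CSR segment reduce: each destination's edges are summed by a single owner."""
    num_nodes = len(row_ptr) - 1
    out: Matrix = []
    stores = 0
    for i in range(num_nodes):          # one owner per node, so no write contention
        acc = [0.0] * num_features      # local accumulator (registers in a real kernel)
        for p in range(row_ptr[i], row_ptr[i + 1]):
            m = msgs[perm[p]]
            for f in range(num_features):
                acc[f] += m[f]
        out.append(acc)                 # one plain store per (node, feature)
        stores += num_features
    return out, stores


# Usage: 3 nodes, 6 edges, 3 features; integer-valued messages keep sums exact.
dst = [1, 0, 2, 0, 1, 0]
msgs = [[float(e + f) for f in range(3)] for e in range(6)]
row_ptr, perm = build_csr(dst, num_nodes=3)
out_scatter, n_updates = scatter_add(msgs, dst, num_nodes=3, num_features=3)
out_seg, n_stores = segment_reduce(msgs, row_ptr, perm, num_features=3)
```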
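The benefit of channel-wise 16-bit quantization (technique 4) when channels have low individual dynamic range but different overall magnitudes can be shown with a small sketch. The abstract does not say whether the 16-bit format is integer or floating-point. This example therefore assumes symmetric int16 with one scale per output channel (one row of the weight matrix) and compares it against a single per-tensor scale. All function names are hypothetical.

```python
from typing import List, Sequence, Tuple

QMAX = 2 ** 15 - 1   # symmetric signed 16-bit range: [-32767, 32767]


def _quantize_row(row: Sequence[float], scale: float) -> List[int]:
    # Round to the nearest integer level and clamp to the symmetric int16 range.
    return [max(-QMAX, min(QMAX, round(w / scale))) for w in row]


def quantize_per_channel(W: List[List[float]]) -> Tuple[List[List[int]], List[float]]:
    """One scale per output channel (row of W), set by that row's max |w|."""
    scales = [(max(abs(w) for w in row) / QMAX) or 1.0 for row in W]
    return [_quantize_row(row, s) for row, s in zip(W, scales)], scales


def quantize_per_tensor(W: List[List[float]]) -> Tuple[List[List[int]], List[float]]:
    """A single scale for the whole matrix, for comparison."""
    s = (max(abs(w) for row in W for w in row) / QMAX) or 1.0
    return [_quantize_row(row, s) for row in W], [s] * len(W)


def dequantize(q: List[List[int]], scales: Sequence[float]) -> List[List[float]]:
    return [[v * s for v in row] for row, s in zip(q, scales)]


def qlinear(q: List[List[int]], scales: Sequence[float], x: Sequence[float]) -> List[float]:
    """y_c = scale_c * sum_i q_ci * x_i: the per-channel scale factors out of the dot product."""
    return [s * sum(v * xi for v, xi in zip(row, x)) for row, s in zip(q, scales)]


def max_channel_rel_error(W: List[List[float]], W_hat: List[List[float]]) -> float:
    """Worst channel's max |w - w_hat|, relative to that channel's max |w|."""
    return max(max(abs(a - b) for a, b in zip(r, rh)) / (max(abs(a) for a in r) or 1.0)
               for r, rh in zip(W, W_hat))


# Usage: two channels whose magnitudes differ by four orders of magnitude.
W = [[1e-3, -7e-4, 3e-4], [10.0, -4.0, 2.5]]
q_c, s_c = quantize_per_channel(W)
q_t, s_t = quantize_per_tensor(W)
err_channel = max_channel_rel_error(W, dequantize(q_c, s_c))
err_tensor = max_channel_rel_error(W, dequantize(q_t, s_t))
```

With a per-tensor scale, the small channel is represented by only a few integer levels, so its relative error is large. With a per-channel scale, every channel uses the full 16-bit range, and its relative error stays below about 1/(2·32767).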
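The headline numbers in the abstract imply a few derived quantities, worked out below. The second line assumes that the 6.5× speedup is measured on the same aggregate-throughput metric. The abstract does not state this explicitly, so that figure is an estimate rather than a reported result.

```latex
\begin{aligned}
\text{per-replica throughput:}\quad & \frac{1000\ \text{ns/day}}{64} = 15.625\ \text{ns/day} \\
\text{implied CGSchNet aggregate throughput:}\quad & \frac{1000\ \text{ns/day}}{6.5} \approx 154\ \text{ns/day} \\
\text{peak memory:}\quad & M_{\text{FlashSchNet}} \approx (1 - 0.8)\, M_{\text{CGSchNet}} = 0.2\, M_{\text{CGSchNet}}
\end{aligned}
```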