

Training Multi-Billion-Parameter AI Weather Models with Optimized Domain and Tensor Parallelism
Wednesday, June 24, 2026 3:45 PM to 5:15 PM · 1 hr. 30 min. (Europe/Berlin)
Foyer D-G - 2nd Floor
Women in HPC Poster
Earth, Climate and Weather Modeling · HW and SW Design for Scalable Machine Learning
Information
Poster is on display and will be presented at the poster pitch session.
AI-based methods have rapidly revolutionized atmospheric modeling, with recent successes in medium-range forecasting spurring the development of foundation models. However, accurate modeling of complex atmospheric dynamics at high spatial resolutions requires billions of neural network parameters and gigabyte-sized data samples, making accelerator memory and I/O bandwidth the bottlenecks for model training. To overcome these limitations, we introduce Jigsaw, a distributed training and inference scheme that leverages domain and tensor parallelism to eliminate memory redundancy across model-parallel processes and reduce I/O demands. We apply the Jigsaw parallelization scheme to an MLP-Mixer architecture, WeatherMixer, a multi-layer-perceptron-based model with global vision that is well-suited for learning weather phenomena. Using Jigsaw, we train WeatherMixer with up to 3.2B parameters, achieving predictive performance competitive with numerical weather prediction and state-of-the-art AI models. To highlight the computational performance, we perform scaling experiments on global 0.25° (~30 km resolution) ERA5 data across two HPC systems. Anticipating that future reanalysis datasets will include even higher resolutions, we demonstrate, for the first time, training on 0.125° data.
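As a rough illustration of the domain-parallel idea described above (each model-parallel process holds only its own spatial subdomain of a sample, so per-rank memory and I/O shrink with the domain-parallel group size), the following Python sketch splits a global lat–lon field along latitude. The sample shape, the helper names (`local_lat_slice`, `load_local_subdomain`), and the halo-free split are illustrative assumptions, not the authors' implementation.

```python
import torch

# Hypothetical global sample of shape (channels, lat, lon) on the 0.25° grid
# (721 x 1440 points). In this domain-parallel sketch, each rank keeps only
# its own latitude band, so per-rank memory and I/O scale down with the
# domain-parallel group size. Illustrative assumption, not the authors'
# exact decomposition.

def local_lat_slice(n_lat: int, group_rank: int, group_size: int) -> slice:
    """Contiguous latitude band owned by one domain-parallel rank."""
    base, rem = divmod(n_lat, group_size)
    start = group_rank * base + min(group_rank, rem)
    length = base + (1 if group_rank < rem else 0)
    return slice(start, start + length)

def load_local_subdomain(sample: torch.Tensor,
                         group_rank: int, group_size: int) -> torch.Tensor:
    """Keep only this rank's latitude band of a (C, lat, lon) sample."""
    band = local_lat_slice(sample.shape[1], group_rank, group_size)
    return sample[:, band, :].contiguous()

if __name__ == "__main__":
    # Stand-alone demo: pretend we are rank 1 of a 4-way domain-parallel group.
    sample = torch.randn(73, 721, 1440)   # e.g. ~73 channels at 0.25° resolution
    local = load_local_subdomain(sample, group_rank=1, group_size=4)
    print(local.shape)                     # torch.Size([73, 180, 1440])
```

In practice the slicing would happen inside the dataloader so that each rank reads only its subdomain from disk, which is where the I/O savings described above would come from.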
The scaling experiments demonstrate that high-resolution input samples benefit from domain parallelism, which improves per-GPU computational throughput by reducing data-loading bottlenecks. In compute- and communication-limited regimes, Jigsaw achieves state-of-the-art performance in distributed model training, reaching 97% of theoretical peak performance on 4 GPUs and a strong-scaling speedup of 6.4 when training across 8 GPUs. By combining domain, tensor, and data parallelism at larger scales, training on 256 GPUs reaches 11 PFLOP/s with a scaling efficiency of 72%, compared to 51% without Jigsaw.
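To make the combination of data, domain, and tensor parallelism concrete, the sketch below shows one possible way to lay out the three process groups with `torch.distributed`. The (data × domain × tensor) rank ordering and the helper `build_parallel_groups` are assumptions for illustration; they are not the Jigsaw implementation.

```python
import torch.distributed as dist

def build_parallel_groups(domain_size: int, tensor_size: int) -> dict:
    """Partition all ranks into data x domain x tensor process groups.

    Illustrative sketch only: assumes dist.init_process_group() has already
    been called (e.g. via torchrun) and that the world size is divisible by
    domain_size * tensor_size; the remaining factor is the data-parallel size.
    """
    world = dist.get_world_size()
    rank = dist.get_rank()
    model_size = domain_size * tensor_size
    assert world % model_size == 0, "world size must factor into the 3D grid"
    data_size = world // model_size

    my_groups = {}

    # Tensor-parallel groups: consecutive ranks that shard one layer's weights.
    for start in range(0, world, tensor_size):
        ranks = list(range(start, start + tensor_size))
        g = dist.new_group(ranks)          # every rank must call new_group
        if rank in ranks:
            my_groups["tensor"] = g

    # Domain-parallel groups: ranks holding different spatial subdomains
    # of the same sample and the same tensor shard.
    for dp in range(data_size):
        for tp in range(tensor_size):
            ranks = [dp * model_size + dd * tensor_size + tp
                     for dd in range(domain_size)]
            g = dist.new_group(ranks)
            if rank in ranks:
                my_groups["domain"] = g

    # Data-parallel groups: ranks with identical model/domain shards,
    # used to all-reduce gradients computed on different samples.
    for dd in range(domain_size):
        for tp in range(tensor_size):
            ranks = [dp * model_size + dd * tensor_size + tp
                     for dp in range(data_size)]
            g = dist.new_group(ranks)
            if rank in ranks:
                my_groups["data"] = g

    return my_groups
```

A typical usage pattern would then be to all-reduce gradients only over `my_groups["data"]` and to restrict activation or halo exchanges to `my_groups["domain"]` and `my_groups["tensor"]`, respectively.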
Contributors:
Format
on-demand · on-site
