GRAWA: Gradient-based Weighted Averaging for Distributed Training of Deep Learning Models

Tolga Dimlioglu, Anna Choromanska

Research output: Contribution to journal › Conference article › peer-review


We study distributed training of deep learning models in time-constrained environments. We propose a new algorithm that periodically pulls workers towards a center variable computed as a weighted average of the workers. Each worker's weight is inversely proportional to its gradient norm, so recovering flat regions of the optimization landscape is prioritized. We develop two asynchronous variants of the algorithm, Model-level and Layer-level Gradient-based Weighted Averaging (MGRAWA and LGRAWA, respectively), which differ in whether the weights are computed with respect to the entire model or separately for each layer. On the theoretical front, we prove convergence guarantees for the proposed approach in both convex and non-convex settings. We then demonstrate experimentally that our algorithms outperform competing methods, converging faster and recovering higher-quality, flatter local optima. We also carry out an ablation study to analyze the scalability of the proposed algorithms in more crowded distributed training environments. Finally, we report that our approach requires less frequent communication and fewer distributed updates than state-of-the-art baselines.

Original language: English (US)
Pages (from-to): 2251-2259
Number of pages: 9
Journal: Proceedings of Machine Learning Research
State: Published - 2024
Event: 27th International Conference on Artificial Intelligence and Statistics, AISTATS 2024 - Valencia, Spain
Duration: May 2 2024 to May 4 2024

ASJC Scopus subject areas

  • Artificial Intelligence
  • Software
  • Control and Systems Engineering
  • Statistics and Probability

