Set Norm and Equivariant Skip Connections: Putting the Deep in Deep Sets

Lily H. Zhang, Veronica Tozzo, John M. Higgins, Rajesh Ranganath

Research output: Contribution to journal › Conference article › peer-review

Abstract

Permutation invariant neural networks are a promising tool for making predictions from sets. However, we show that existing permutation invariant architectures, Deep Sets and Set Transformer, can suffer from vanishing or exploding gradients when they are deep. Additionally, layer norm, the normalization of choice in Set Transformer, can hurt performance by removing information useful for prediction. To address these issues, we introduce the “clean path principle” for equivariant residual connections and develop set norm (SN), a normalization tailored for sets. With these, we build Deep Sets++ and Set Transformer++, models that reach high depths with performance better than or comparable to that of their original counterparts on a diverse suite of tasks. We additionally introduce Flow-RBC, a new single-cell dataset and real-world application of permutation invariant prediction. We open-source our data and code here: https://github.com/rajeshlab/deep_permutation_invariant.

Original language: English (US)
Pages (from-to): 26559-26574
Number of pages: 16
Journal: Proceedings of Machine Learning Research
Volume: 162
State: Published - 2022
Event: 39th International Conference on Machine Learning, ICML 2022 - Baltimore, United States
Duration: Jul 17 2022 to Jul 23 2022

ASJC Scopus subject areas

  • Artificial Intelligence
  • Software
  • Control and Systems Engineering
  • Statistics and Probability
