On the Implicit Bias of Gradient Descent for Temporal Extrapolation

Edo Cohen-Karlik, Avichai Ben David, Nadav Cohen, Amir Globerson

Research output: Contribution to journal › Conference article › peer-review

1 Scopus citation

Abstract

When using recurrent neural networks (RNNs), it is common practice to apply trained models to sequences longer than those seen in training. This “extrapolating” usage deviates from the traditional statistical learning setup, where guarantees are provided under the assumption that train and test distributions are identical. Here we set out to understand when RNNs can extrapolate, focusing on a simple case where the data generating distribution is memoryless. We first show that even with infinite training data, there exist RNN models that interpolate perfectly (i.e., they fit the training data) yet extrapolate poorly to longer sequences. We then show that if gradient descent is used for training, learning will converge to perfect extrapolation under certain assumptions on initialization. Our results complement recent studies on the implicit bias of gradient descent, showing that it plays a key role in extrapolation when learning temporal prediction models.
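To make the setting concrete, the following is a minimal sketch of the setup described in the abstract; the architecture (a linear RNN with hidden dimension d), the memoryless target y = x_T, and all hyperparameters are illustrative assumptions rather than the paper's exact construction.

```python
import jax
import jax.numpy as jnp

# Hypothetical illustration of the setup in the abstract (not the paper's
# exact construction or hyperparameters): a linear RNN
#   h_t = A h_{t-1} + b x_t,   y_hat = <c, h_T>,
# trained by gradient descent from a small random initialization on sequences
# of length T_TRAIN whose target is memoryless (y = x_T, the last input only),
# then evaluated on much longer sequences to probe extrapolation.

d, T_TRAIN, T_TEST = 8, 5, 50          # hidden size and sequence lengths (assumed)
lr, steps = 0.1, 3000                  # plain full-batch gradient descent

def predict(params, x):
    A, b, c = params
    def step(h, x_t):
        return A @ h + b * x_t, None   # linear recurrence
    h_T, _ = jax.lax.scan(step, jnp.zeros(d), x)
    return c @ h_T

def loss(params, X):
    preds = jax.vmap(lambda x: predict(params, x))(X)
    return jnp.mean((preds - X[:, -1]) ** 2)   # memoryless target y = x_T

key = jax.random.PRNGKey(0)
k1, k2, k3, k4, k5 = jax.random.split(key, 5)

# Small initialization: the paper's guarantee hinges on assumptions of this kind.
params = (0.01 * jax.random.normal(k1, (d, d)),
          0.01 * jax.random.normal(k2, (d,)),
          0.01 * jax.random.normal(k3, (d,)))
X_train = jax.random.normal(k4, (256, T_TRAIN))

grad_fn = jax.jit(jax.grad(loss))
for _ in range(steps):
    g = grad_fn(params, X_train)
    params = jax.tree_util.tree_map(lambda p, gp: p - lr * gp, params, g)

# Sequences 10x longer than anything seen in training.
X_long = jax.random.normal(k5, (256, T_TEST))
print("train loss (length %d):" % T_TRAIN, loss(params, X_train))
print("extrapolation loss (length %d):" % T_TEST, loss(params, X_long))
```

The interesting regime is exactly this one: with hidden dimension larger than needed, interpolating solutions exist whose behavior beyond the training length is unconstrained, and the question studied in the paper is whether gradient descent selects one that keeps matching the memoryless teacher on longer sequences.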

Original language: English
Pages (from-to): 10966-10981
Number of pages: 16
Journal: Proceedings of Machine Learning Research
Volume: 151
State: Published - 2022
Event: 25th International Conference on Artificial Intelligence and Statistics, AISTATS 2022 - Virtual, Online, Spain
Duration: 28 Mar 2022 - 30 Mar 2022

Funding

Funders and funder numbers:
Amnon and Anat Shashua
Blavatnik Family Foundation
European Research Council
European Union's Horizon 2020 research and innovation programme (ERC HOLI 819080)
Google
Google Research Gift
Israel Science Foundation (1780/21)
Yandex Initiative in Machine Learning
