EF1 – Extracting Dynamical Laws from Complex Data

Project

EF1-13

Stochastic and Rough Aspects in Deep Neural Networks

Project Heads

Christian Bayer, Peter-Karl Friz

Project Members

Nikolas Tapia

Project Duration

01.01.2021 – 31.12.2022

Located at

WIAS / TU Berlin

Description

We analyse residual neural networks using rough path theory. Extending the worst-case stability analysis developed in the first funding period, we now embrace the stochastic nature of training networks (random initialization, stochastic optimization), which we see as the ultimate justification for a rough path analysis of deep networks.

Deep residual neural networks are a recent and major development in the field of deep learning. Citations of the key paper by He, Zhang, Ren and Sun quadrupled over the first funding period, to more than 47K.

The basic idea is rather simple: instead of the familiar architecture x(i+1) = F(x(i)), one models only the increments, i.e. x(i+1) = x(i) + F(x(i)) for i = 0,…,N-1, where x(i) denotes the state of the system at layer i.
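
A minimal sketch contrasting the two update rules, assuming the common form F(x) = σ(Wx + b) with σ = tanh and small hand-picked weights. The names `layer`, `plain_forward` and `residual_forward`, and all numbers, are hypothetical. Plain Python lists stand in for vectors so the snippet has no dependencies.

```python
import math

def layer(W, b, x):
    """One map F(x) = tanh(W x + b); W is a list of rows, b and x are lists."""
    return [math.tanh(sum(w_jk * x_k for w_jk, x_k in zip(row, x)) + b_j)
            for row, b_j in zip(W, b)]

def plain_forward(params, x):
    """Familiar architecture: x(i+1) = F(x(i))."""
    for W, b in params:
        x = layer(W, b, x)
    return x

def residual_forward(params, x):
    """Residual architecture: x(i+1) = x(i) + F(x(i)), i.e. only the increment is modelled."""
    for W, b in params:
        x = [x_j + f_j for x_j, f_j in zip(x, layer(W, b, x))]
    return x

# N = 3 layers acting on a 2-dimensional state (hypothetical weights and biases).
params = [([[0.5, -0.2], [0.1, 0.3]], [0.0, 0.1])] * 3
x0 = [1.0, -1.0]
print(plain_forward(params, x0))
print(residual_forward(params, x0))
```

With all weights and biases set to zero, the residual network reduces to the identity map, whereas the plain network sends every input to zero. This is one informal way to see why stacking many residual blocks does not automatically destroy the input signal.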

The specific form of F is usually F(x) = σ(Wx + b), for a weight matrix W, a bias vector b and a fixed non-linearity σ. Without much loss of generality, we can write the increments as x(i+1) = x(i) + f(x(i))W' = x(i) + f(x(i))(w(i+1) - w(i)), interpreting the weights W' as increments of a path w.
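
The reinterpretation of the layer weights as increments of a path w can be sketched as follows. For illustration we take f(x) to be the diagonal matrix diag(tanh(x)), so that the k-th state coordinate is driven by the k-th coordinate of w. This choice of f and the helper names `path_from_increments` and `driven_resnet` are assumptions made for the example, not taken from the project.

```python
import math

def path_from_increments(increments, w0=None):
    """Return w(0), ..., w(N) with w(i+1) - w(i) equal to the i-th weight increment W'."""
    dim = len(increments[0])
    path = [list(w0) if w0 is not None else [0.0] * dim]
    for inc in increments:
        path.append([w_k + dw_k for w_k, dw_k in zip(path[-1], inc)])
    return path

def f(x):
    """Hypothetical vector field f(x) = diag(tanh(x)), stored as its diagonal."""
    return [math.tanh(x_k) for x_k in x]

def driven_resnet(path, x):
    """Apply x(i+1) = x(i) + f(x(i)) (w(i+1) - w(i)) along consecutive points of the path."""
    for w_prev, w_next in zip(path, path[1:]):
        fx = f(x)
        x = [x_k + fx_k * (b - a) for x_k, fx_k, a, b in zip(x, fx, w_prev, w_next)]
    return x

incs = [[0.1, -0.2], [0.05, 0.3], [-0.1, 0.1]]   # W'_0, W'_1, W'_2 (hypothetical)
w = path_from_increments(incs)
x_final = driven_resnet(w, [1.0, -1.0])
print(x_final)
```

Only the differences w(i+1) - w(i) enter the update. Shifting the whole path by a constant therefore leaves the network output unchanged: the information sits in the increments of the weight path, not in its values.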

Both He et al. and E view this as an Euler approximation of an ODE system controlled by w = w(t). This allows standard ODE results for such systems to be applied to ResNets, provided that w is regular. In view of the standard IID initialisation of weights, which amounts to (discrete) white noise, this is a strong assumption.
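
To see why the regularity of w matters, consider the scalar linear equation dx = x dw with x(0) = 1. We choose it here only because both limits are explicit. For a smooth driver, the Euler scheme x(i+1) = x(i) + x(i)(w(i+1) - w(i)) converges to the ODE solution exp(w(T) - w(0)). For IID Gaussian increments of variance T/N, i.e. a scaled random walk, the same scheme instead converges to exp(w(T) - w(0) - T/2), the Itô (stochastic exponential) solution. The ODE picture therefore misses a correction term. The function name `euler_linear`, the driver sin(t) and the sizes N and T are illustrative choices.

```python
import math
import random

def euler_linear(increments, x0=1.0):
    """Euler scheme x(i+1) = x(i) + x(i) * dw_i for the controlled ODE dx = x dw.

    Returns log x(N), computed as a sum of log1p terms for numerical stability.
    """
    return math.log(x0) + sum(math.log1p(dw) for dw in increments)

N, T = 100_000, 1.0
times = [T * i / N for i in range(N + 1)]

# Regular driver w(t) = sin(t): the ODE solution satisfies log x(T) = w(T) - w(0).
smooth = [math.sin(b) - math.sin(a) for a, b in zip(times, times[1:])]
print(euler_linear(smooth), math.sin(T))

# IID Gaussian increments with variance T/N: a scaled random walk (discrete white noise).
rng = random.Random(0)
noise = [rng.gauss(0.0, math.sqrt(T / N)) for _ in range(N)]
w_T = sum(noise)
# The Euler scheme tracks w_T - T/2 here, not w_T.
print(euler_linear(noise), w_T, w_T - T / 2)
```

The gap comes from log(1 + Δ) ≈ Δ - Δ²/2. For smooth w the squared increments sum to O(1/N), while for white-noise increments they sum to roughly T.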

We provide a general stability analysis of residual neural networks and some related architectures, taking into account their full stochastic dynamics, including random initialization and training. Mathematical results obtained in this project include a unified control theory of discrete and continuous rough differential equations and a theory of gradient descent on rough path space.

Rough path analysis provides a new toolkit for analysing the stability of DNNs. Incorporating the stochastic features into the analysis improves our understanding and opens up exciting new possibilities for improving architectures. In the context of our project, we envision expected signature methods (cf. Lyons 2014 ICM Lecture) as a new tool to identify the dynamics of stochastic neural equations.

We propose three work packages for this project.

WP A: Time warping and network architectures. Our stability analysis only applies to ResNets with a fixed number of hidden layers. Indeed, the data traverse the layers through an artificially introduced (fake) time. When two networks have different depths, measuring their distance by the rough path distance no longer makes sense, even if the discrete rough paths driving them are parametrized on the same time interval. To overcome this obstacle, we propose a time-warping distance of rough ResNets, which is invariant under time reparametrization. The discrete signature contains all time-warping invariant features of the weight path, so it is natural to use these features to compare networks of different depths.
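
The invariance can be illustrated with iterated sums of increments, in the spirit of the time-warping invariants of Diehl, Ebrahimi-Fard and Tapia (reference 3 under Related Publications). We assume here that time warping means repeating points of a time series, which amounts to holding the weights constant over extra fake time steps. Every repetition inserts a zero increment, and sums over strictly increasing indices i_1 < … < i_k are unaffected by zero increments. The sketch uses only plain words over the coordinates, whereas the construction in that paper is richer. The names `iterated_sum` and `features` are hypothetical.

```python
from typing import List, Sequence

def iterated_sum(series: Sequence[Sequence[float]], word: Sequence[int]) -> float:
    """Sum over i_1 < ... < i_k of dX_{i_1}[word[0]] * ... * dX_{i_k}[word[k-1]].

    dX_i = X_i - X_{i-1} are the increments of the multidimensional time series.
    """
    k = len(word)
    partial = [1.0] + [0.0] * k          # partial[m]: iterated sum for the prefix word[:m]
    for prev, cur in zip(series, series[1:]):
        dx = [c - p for p, c in zip(prev, cur)]
        # Update longer prefixes first so each increment is used at most once per term,
        # which enforces the strict ordering i_1 < i_2 < ... < i_k.
        for m in range(k, 0, -1):
            partial[m] += partial[m - 1] * dx[word[m - 1]]
    return partial[k]

def features(series, words) -> List[float]:
    """Fixed-length feature vector, independent of the length (depth) of the series."""
    return [iterated_sum(series, w) for w in words]

words = [(0,), (1,), (0, 1), (1, 0), (0, 0, 1)]
x = [(0, 0), (1, 0), (1, 2), (3, 3)]
# Time-warped copy: some points are repeated, as if weights were held over extra steps.
x_warped = [(0, 0), (0, 0), (1, 0), (1, 0), (1, 2), (3, 3), (3, 3)]
print(features(x, words))
print(features(x_warped, words))
```

Both series give the same feature vector although they have different lengths. Comparing such vectors is one way to compare weight paths of networks with different numbers of layers.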

WP B: Embracing stochasticity. As weights are usually initialized in an i.i.d. way, the weights are nothing but the increments of a high-dimensional random walk. By a functional central limit theorem, this implies that the initial weights of deep networks are close to white noise. As a consequence, classical ODE or PDE models, as successfully proposed by Haber–Ruthotto, E and others, will inevitably miss an important feature of real DNNs. This brings us into the realm of stochastic differential equations, to be analyzed in a discrete/continuous and pathwise setting. The above stability analysis is a worst-case analysis: while the estimates are sharp, they may be far from optimal for most driving paths. Hence, we are also interested in an accompanying average-case analysis. This, however, poses considerable difficulty. Even when the noise is Gaussian, which induces Gaussian concentration for the p-variation norms, the right-hand side of the stability estimate will have infinite expectation for p > 2. In the continuous case, similar problems have been overcome, restoring integrability for any p in the Gaussian case. We will extend this analysis to the case of discrete random rough paths, later distorted by learning (cf. WP C).
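
The p-variation norms mentioned above can be computed exactly for a discrete scalar path by dynamic programming over all partitions of its index set. The sketch below applies this to random walks with IID Gaussian increments of variance 1/N. The name `p_variation` and the O(N^2) recursion are illustrative choices, not the project's code. The 1-variation, i.e. the total variation, grows roughly like sqrt(N). For p > 2 the p-variation stays of order one as N grows, reflecting the fact that Brownian paths have finite p-variation precisely when p > 2.

```python
import math
import random

def p_variation(path, p):
    """p-variation of a scalar discrete path, maximised over all partitions of its index set.

    best[j] is the largest sum of |path[t_{l+1}] - path[t_l]|**p over partitions
    0 = t_0 < ... < t_m = j; this O(N^2) recursion is exact for finite paths.
    """
    best = [0.0] * len(path)
    for j in range(1, len(path)):
        best[j] = max(best[i] + abs(path[j] - path[i]) ** p for i in range(j))
    return best[-1] ** (1.0 / p)

# Random walks with IID Gaussian increments of variance 1/N (discrete white noise).
rng = random.Random(1)
for N in (100, 400, 900):
    walk = [0.0]
    for _ in range(N):
        walk.append(walk[-1] + rng.gauss(0.0, 1.0 / math.sqrt(N)))
    print(N, round(p_variation(walk, 1.0), 2), round(p_variation(walk, 2.5), 2))
```

The quadratic cost is acceptable for a few thousand points. Longer paths would call for a more specialised algorithm.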

WP C: Learning as pathspace transformation. Much of stochastic analysis is concerned with distorting a reference measure on paths (typically Wiener measure) to other measures which describe the statistics of a diffusive particle in a changing environment. (Understanding this transformation has been revolutionized by rough path theory.) Call P(0) the (pathspace) law of weights upon IID initialization, effectively (discrete) white noise. Training a ResNet via (stochastic) gradient descent [13] can then be viewed as inducing a measure transformation P(n) → P(n+1) on the (path)space of network weights. Empirical evidence supports the preservation of roughness, hence any limit point of {P(n)} is expected to be supported on highly irregular paths. This dynamical system is intimately connected with the learning problem in ResNets. At the same time, it points to the law of typical network weights w relevant for the average-case stability analysis presented in WP B. Moreover, other works have pushed this perspective forward and consider training as an optimal control problem, whose solution can be characterised via the Pontryagin maximum principle, which relates training to the backward flow of an ODE in weight space; these methods are aligned with our recent work on rough transport.
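
The pushforward P(n) → P(n+1) can be mimicked with an empirical measure: draw a few weight paths from P(0) and push every sample through the same gradient-descent map. The sketch below rests on several assumptions. It uses a hypothetical scalar ResNet x(i+1) = x(i) + x(i) dw_i, a squared loss towards a single target, and deterministic full-batch gradient descent instead of SGD. The names, sizes and learning rate are illustrative. In this toy model the gradient moves every increment by nearly the same small amount, so the quadratic variation of each sample is essentially unchanged. This is a property of the toy model, not a substitute for the empirical evidence referred to above.

```python
import math
import random

def terminal_state(dw, x0=1.0):
    """Hypothetical scalar ResNet x(i+1) = x(i) + x(i) * dw_i; returns x(N)."""
    x = x0
    for inc in dw:
        x += x * inc
    return x

def loss(dw, target):
    """Squared error of the network output for one hypothetical training target."""
    return 0.5 * (terminal_state(dw) - target) ** 2

def gd_step(dw, target, lr):
    """One full-batch gradient-descent step on the weight increments.

    x(N) = x0 * prod(1 + dw_j), so d x(N) / d dw_i = x(N) / (1 + dw_i).
    """
    xN = terminal_state(dw)
    return [inc - lr * (xN - target) * xN / (1.0 + inc) for inc in dw]

def quadratic_variation(dw):
    """Sum of squared increments, a simple proxy for the roughness of a weight path."""
    return sum(inc * inc for inc in dw)

rng = random.Random(2)
N, M, target, lr = 500, 8, 2.0, 1e-5
# M samples from P(0): IID N(0, 1/N) increments, i.e. discrete white noise.
initial = [[rng.gauss(0.0, 1.0 / math.sqrt(N)) for _ in range(N)] for _ in range(M)]

# Each step maps every sample through the same GD map, pushing P(n) forward to P(n+1).
samples = initial
for _ in range(100):
    samples = [gd_step(dw, target, lr) for dw in samples]

for before, after in zip(initial, samples):
    print(f"loss {loss(before, target):.3f} -> {loss(after, target):.3f}, "
          f"QV {quadratic_variation(before):.3f} -> {quadratic_variation(after):.3f}")
```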

Related Publications

  1. C. Bayer, P. K. Friz, N. Tapia. Stability of Deep Neural Networks via discrete rough paths. (2022) arXiv:2201.07566 [cs.LG].
  2. C. Bellingeri, A. Djurdjevac, P. K. Friz, N. Tapia. Transport and continuity equations with (very) rough noise. Partial Differ. Equ. Appl. 2 (2021), no. 49.
  3. J. Diehl, K. Ebrahimi-Fard, N. Tapia. Time-warping invariants of multidimensional time series. Acta Appl. Math. 171(1):265–290 (2020). doi: 10.1007/s10440-020-00333-x.
