Project author: mcbal

Project description:
Experimental PyTorch implementations of deep implicit layers for self-attention
Programming language:
Project address: git://github.com/mcbal/deep-implicit-attention.git
Created: 2020-12-26T12:32:57Z
Project community: https://github.com/mcbal/deep-implicit-attention

License: MIT License


Deep Implicit Attention

Experimental implementation of deep implicit attention in PyTorch.

Summary: Using deep equilibrium models to implicitly solve a set of self-consistent mean-field equations of a random Ising model implements attention as a collective response 🤗 and provides insight into the transformer architecture, connecting it to mean-field theory, message-passing algorithms, and Boltzmann machines.

Blog post: Deep Implicit Attention: A Mean-Field Theory Perspective on Attention Mechanisms

Mean-field theory framework for transformer architectures

Transformer architectures can be understood as particular approximations of a parametrized mean-field description of a vector Ising model being probed by incoming data x_i:

  z_i = sum_j J_ij z_j - f(z_i) + x_i

where f is a neural network acting on every vector z_i and the z_i are solved for iteratively.
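
To make "solved for iteratively" concrete, here is a minimal sketch of a naive forward-iteration solver for the fixed-point equation above, assuming fixed couplings J and an MLP f. All names are hypothetical and this is not the package's API; a deep equilibrium model would additionally backpropagate through the fixed point via implicit differentiation rather than unrolling the loop.

  import torch
  import torch.nn as nn

  def solve_fixed_point(x, J, f, max_iter=100, tol=1e-5):
      # x: (batch, num_spins, dim) input injections x_i
      # J: (num_spins, num_spins) couplings J_ij
      # f: module acting on the feature dimension (the self-correction term)
      z = torch.zeros_like(x)
      for _ in range(max_iter):
          z_new = torch.einsum("ij,bjd->bid", J, z) - f(z) + x
          if torch.norm(z_new - z) < tol * (torch.norm(z) + 1e-8):
              return z_new
          z = z_new
      return z

  # Example: 16 spins carrying 32-dimensional vectors.
  f = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32))
  J = 0.05 * torch.randn(16, 16)
  x = torch.randn(2, 16, 32)
  z_star = solve_fixed_point(x, J, f)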

DEQMLPMixerAttention

A deep equilibrium version of MLP-Mixer transformer attention (https://arxiv.org/abs/2105.02723, https://arxiv.org/abs/2105.01601):

  z_i = g({z_j}) - f(z_i) + x_i

where g is an MLP acting across the sequence dimension instead of the feature
dimension (so across patches). The network f parametrizes the self-correction
term and acts across the feature dimension (so individually on every sequence element).

Compared to a vanilla softmax attention transformer module (see below), the
sum over couplings has been “amortized” and parametrized by an MLP.
The fixed-point variables z_i are also fed straight into the
feed-forward self-correction term. One could feed the naive mean-field update g({z_j}) + x_i instead to fully mimic the residual connection in the explicit MLP-Mixer architecture.
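
A minimal sketch of this update (hypothetical module and names, not the package's API): g mixes across the patch dimension by transposing, f is a channel-wise MLP acting as the self-correction term, and the naive unrolled iteration stands in for a proper deep equilibrium solver.

  import torch
  import torch.nn as nn

  class MixerFixedPoint(nn.Module):
      def __init__(self, num_patches, dim, hidden=64):
          super().__init__()
          # g: token-mixing MLP applied across the sequence (patch) dimension.
          self.g = nn.Sequential(
              nn.Linear(num_patches, hidden), nn.GELU(), nn.Linear(hidden, num_patches)
          )
          # f: channel-wise MLP, the self-correction term.
          self.f = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

      def forward(self, x, num_iter=50):
          z = torch.zeros_like(x)
          for _ in range(num_iter):
              mixed = self.g(z.transpose(1, 2)).transpose(1, 2)  # g({z_j}) across patches
              z = mixed - self.f(z) + x
          return z

  # Usage: 16 patches with 32-dimensional features.
  z_star = MixerFixedPoint(num_patches=16, dim=32)(torch.randn(2, 16, 32))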

DEQVanillaSoftmaxAttention

A deep equilibrium version of vanilla softmax transformer attention (https://arxiv.org/abs/1706.03762):

  z_i = sum_j J_ij z_j - f(z_i) + x_i

where

  J_ij = [softmax(X W_Q W_K^T X^T / sqrt(dim))]_ij

Transformer attention takes the couplings J_ij to depend parametrically on the x_i and treats the fixed-point equation above as a single-step update equation. Compared to the explicit vanilla softmax attention transformer module, there is no value projection, and the fixed-point variables z_i are fed straight into the feed-forward self-correction term.
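
A minimal sketch of this variant (hypothetical names, not the package's API): the couplings are computed once from the inputs x via the softmax of scaled query-key scores and then reused at every iteration, with no value projection; f is the feed-forward self-correction term.

  import math
  import torch
  import torch.nn as nn

  class SoftmaxCouplingFixedPoint(nn.Module):
      def __init__(self, dim, hidden=64):
          super().__init__()
          self.w_q = nn.Linear(dim, dim, bias=False)
          self.w_k = nn.Linear(dim, dim, bias=False)
          self.f = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

      def forward(self, x, num_iter=50):
          # Couplings J_ij depend parametrically on the inputs x, not on z.
          scores = self.w_q(x) @ self.w_k(x).transpose(1, 2) / math.sqrt(x.shape[-1])
          J = scores.softmax(dim=-1)
          z = torch.zeros_like(x)
          for _ in range(num_iter):
              z = J @ z - self.f(z) + x  # a standard transformer block stops after one step
          return z

  z_star = SoftmaxCouplingFixedPoint(dim=32)(torch.randn(2, 16, 32))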

DEQMeanFieldAttention

Fast and neural deep implicit attention as introduced in https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/.

Schematically, the fixed-point mean-field equations including the Onsager self-correction term look like:

  z_i = sum_j J_ij z_j - f(z_i) + x_i

where f is a neural network parametrizing the self-correction term for
every site and x_i denote the input injection or magnetic fields applied
at site i. Mean-field results are obtained by dropping the self-
correction term. This model generalizes the current generation of transformers in the sense that its couplings are free parameters independent of the incoming data x_i.
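
A minimal sketch of the difference with the softmax variant (hypothetical names, not the package's API): the couplings J_ij are a free, learnable parameter matrix that does not depend on the incoming data x_i; dropping f recovers the plain mean-field result.

  import torch
  import torch.nn as nn

  class MeanFieldFixedPoint(nn.Module):
      def __init__(self, num_spins, dim, hidden=64):
          super().__init__()
          # Free couplings, independent of the data x (unlike softmax attention).
          self.J = nn.Parameter(0.05 * torch.randn(num_spins, num_spins))
          # Neural self-correction term; omit it to recover plain mean field.
          self.f = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

      def forward(self, x, num_iter=50):
          z = torch.zeros_like(x)
          for _ in range(num_iter):
              z = torch.einsum("ij,bjd->bid", self.J, z) - self.f(z) + x
          return z

  z_star = MeanFieldFixedPoint(num_spins=16, dim=32)(torch.randn(2, 16, 32))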

DEQAdaTAPMeanFieldAttention

Slow and explicit deep implicit attention as introduced in https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/ (served as grounding and inspiration for the fast and neural variant above).

Ising-like vector model with multivariate Gaussian prior over spins. Generalization of the application of the adaptive TAP mean-field approach
from a system of binary/scalar spins to vector spins. Schematically, the
fixed-point mean-field equations including the Onsager term look like:

  S_i ~ sum_j J_ij S_j - V_i S_i + x_i

where the V_i are self-corrections obtained self-consistently and x_i
denote the input injection or magnetic fields applied at site i. The
linear response correction step involves solving a system of equations,
leading to a complexity ~ O(N^3*d^3). Mean-field results are obtained
by setting V_i = 0.
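
As a rough illustration of where the ~O(N^3*d^3) scaling comes from (this is not the actual adaptive TAP linear-response equations, only the cost of the dense solve they entail): the correction couples all N sites and d feature dimensions, so it amounts to solving a dense (N*d) x (N*d) linear system.

  import torch

  N, d = 16, 32                    # number of spins, vector dimension per spin
  A = torch.randn(N * d, N * d)    # stand-in for the dense linear-response operator
  b = torch.randn(N * d, 1)        # stand-in for the right-hand side
  V = torch.linalg.solve(A, b)     # dense solve is cubic in N*d, i.e. ~O(N^3 * d^3)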

Given the couplings between spins and a prior distribution for the single-
spin partition function, the adaptive TAP framework provides a closed-form
solution in terms of sets of equations that should be solved for a fixed
point. The algorithm is related to expectation propagation (see Section
4.3 in https://arxiv.org/abs/1409.6179) and boils down to matching the
first and second moments assuming a Gaussian cavity distribution.

Setup

Install package in editable mode:

  $ pip install -e .

Run tests with:

  $ python -m unittest

References

Selection of literature

On variational inference, iterative approximation algorithms, expectation propagation, mean-field methods and belief propagation:

On the adaptive Thouless-Anderson-Palmer (TAP) mean-field approach in disorder physics:

On Boltzmann machines and mean-field theory:

On deep equilibrium models:

On approximate message passing (AMP) methods in statistics:

  • A unifying tutorial on Approximate Message Passing (2021) by Oliver Y. Feng, Ramji Venkataramanan, Cynthia Rush, Richard J. Samworth: the example on page 2 basically describes how transformers implement approximate message passing: an iterative algorithm with a “denoising” step (attention) followed by a “memory term” or Onsager correction term (feed-forward layer)

Code inspiration