Attention, as popularized by the landmark paper Attention Is All You Need (2017), is arguably the most important architectural trend in machine learning right now. Originally intended for sequence-to-sequence modeling, attention has exploded into virtually every sub-discipline of machine learning.

This post will describe a particular flavor of attention that preceded the transformer style of attention. We’ll discuss how it works and why it’s useful. We’ll also go over some literature and a tutorial implementing this form of attention in PyTorch. By reading this post, you will gain a more thorough understanding of attention as a general concept, which is useful in exploring more cutting-edge applications.

The attention mechanism was originally popularized in Neural Machine Translation by Jointly Learning to Align and Translate (2014), which is the guiding reference for this particular post. This paper employs an encoder-decoder architecture for English-to-French translation.

This is a very common architecture, but the exact details can change drastically from implementation to implementation. For instance, some of the earlier sequence-to-sequence encoder-decoders were recurrent networks that would incrementally “build” and then “deconstruct” the embedding.

This general idea, with minor variations, was state of the art for several years. However, one problem with this approach is that the entire input sequence has to be compressed into the embedding, which is generally a fixed-size vector. As a result, these models can easily forget content from long sequences. **The attention mechanism was designed to alleviate the problem of needing to fit the entire input sequence into the embedding space. It does this by telling the model which inputs are related to which outputs.** Or, in other words, the attention mechanism allows a model to focus on relevant portions of the input and disregard the rest.
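To make the bottleneck concrete, here is a small sketch (assuming a plain `torch.nn.GRU` encoder with arbitrary sizes, not the setup of any particular paper). Whether the input has 5 steps or 500, the final hidden state handed to an attention-free decoder has the same fixed size:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions): 32-dim word embeddings, 64-dim hidden state
EMBED, HIDDEN = 32, 64
encoder = nn.GRU(input_size=EMBED, hidden_size=HIDDEN, batch_first=True)

def encode(seq_len):
    """Encode a random sequence and return only the final hidden state,
    which is all a classic (attention-free) decoder gets to see."""
    x = torch.randn(1, seq_len, EMBED)   # (batch, time, features)
    _, h_n = encoder(x)                  # h_n: (num_layers, batch, hidden)
    return h_n.squeeze(0).squeeze(0)     # (hidden,)

short_summary = encode(5)
long_summary = encode(500)
print(short_summary.shape, long_summary.shape)  # torch.Size([64]) torch.Size([64])
```

Everything the encoder read has to fit into those 64 numbers, no matter how long the sentence was.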

Practically, the attention mechanism we will be discussing ends up being a matrix of scores, called “alignment” scores. These alignment scores encode the degree to which a word in an input sequence relates to a word in the output sequence.

The alignment score can be computed many ways. We’ll stick with our 2014 paper and pick apart its particular alignment function:

$$e_{ij} = a(s_{i-1}, h_j) = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$$

When calculating the alignment for the $i$th output, this approach uses the previous hidden state of the decoder ($s_{i-1}$), the embedding of an input word ($h_j$), and the learnable parameters $W_a$, $U_a$, and $v_a$ to calculate the alignment of the $i$th output with the $j$th input. The tanh activation function adds non-linearity, which is vital in training a model to understand complex relationships.

**In other words, the function above calculates a score between the next output word and a single input word, denoting how relevant that input word is to the current output.** This function is run across all input words ($h_j$) to calculate an alignment score for every input word given the current output.
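Because the same function is applied to every $h_j$, the scores for one output can also be computed in a single broadcasted expression. A minimal sketch, assuming randomly initialized parameters and illustrative sizes (`enc_dim`, `dec_dim`, and `attn_dim` are placeholders, not values from the paper):

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not values from the paper)
enc_dim, dec_dim, attn_dim, seq_len = 6, 4, 8, 5

W_a = torch.randn(attn_dim, dec_dim)  # projects the previous decoder state s_{i-1}
U_a = torch.randn(attn_dim, enc_dim)  # projects each input embedding h_j
v_a = torch.randn(attn_dim)           # collapses the result to one score per input

s_prev = torch.randn(dec_dim)         # s_{i-1}
H = torch.randn(seq_len, enc_dim)     # row j is h_j

# All scores at once: W_a @ s_prev has shape (attn_dim,) and broadcasts
# against H @ U_a.T, which has shape (seq_len, attn_dim)
scores = torch.tanh(W_a @ s_prev + H @ U_a.T) @ v_a   # (seq_len,)

# The same scores computed one input at a time, exactly as in the equation
loop_scores = torch.stack([v_a @ torch.tanh(W_a @ s_prev + U_a @ h) for h in H])

print(scores.shape)                                    # torch.Size([5])
print(torch.allclose(scores, loop_scores, atol=1e-6))  # True
```

The vectorized form avoids a Python loop over the input, which matters once sequences get long.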

A softmax function is applied across all of the computed alignments, turning them into a probability distribution over the inputs. This is referred to in the literature as a “soft-search” or “soft-alignment”.
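In the paper’s notation, with $e_{ij}$ the alignment score between output $i$ and input $j$, and $T_x$ the length of the input, the softmax step reads:

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}$$

Every $\alpha_{ij}$ is positive, and the weights for a given output sum to 1.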

Exactly how attention is used can vary from implementation to implementation. In Neural Machine Translation by Jointly Learning to Align and Translate (2014), the attention mechanism decides which input embeddings to provide to the decoder.

It performs this selection with a weighted sum. All input embeddings are multiplied by their respective alignment scores (in practice most of the alignment scores are close to zero, while one or two might be large, say 0.8 and 0.2), then those weighted embeddings are added together to create the context vector for a particular output.
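Here is that weighted sum with invented numbers, written in the paper’s notation as $c_i = \sum_j \alpha_{ij} h_j$. The four 3-dimensional embeddings and the weights are made up for illustration, with most of the weight on two inputs as described above:

```python
import torch

# Four hypothetical input embeddings h_1..h_4, one per row
H = torch.tensor([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [1.0, 1.0, 1.0]])

# Hypothetical softmaxed alignment weights for one output (they sum to 1)
alpha = torch.tensor([0.01, 0.79, 0.19, 0.01])

# Context vector: scale each embedding by its weight, then add them up
context = (alpha[:, None] * H).sum(dim=0)
print(context)  # approximately [0.02, 0.80, 0.20]

# The same weighted sum written as a vector-matrix product
assert torch.allclose(context, alpha @ H)
```

The context vector is dominated by $h_2$ and $h_3$, while the near-zero weights let $h_1$ and $h_4$ contribute almost nothing.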

**The context vector is the combination of all of the inputs which are relevant to the current output.**

The following figure ties together how attention fits into the bigger picture:

- The inputs are embedded into some initial vector representation (using word2vec, for instance).
- Those are passed through a bidirectional LSTM to create a bit of context awareness between the embeddings.
- Alignment scores are calculated for each input using the previous decoder hidden state and the learned parameters within the alignment function.
- The softmaxed alignments are multiplied against each input, added together, and used to construct the context vector.
- The decoder uses the previous decoder hidden state, along with the context vector, to generate the prediction of the current word.
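The five steps above can be condensed into one decoding step. The following is a minimal sketch under stated assumptions, not the paper’s exact model: it assumes a bidirectional `torch.nn.GRU` encoder, a `torch.nn.GRUCell` decoder fed the previous word embedding concatenated with the context vector, and a plain linear output layer. All sizes and names (`decoder_step`, `out_proj`, and so on) are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes
EMB, ENC_HID, DEC_HID, ATTN, VOCAB, T = 16, 12, 20, 10, 50, 7

# Steps 1-2: embedded inputs pass through a bidirectional encoder
encoder = nn.GRU(EMB, ENC_HID, bidirectional=True, batch_first=True)
# Step 3: parameters of the additive alignment function
W_a = nn.Linear(DEC_HID, ATTN, bias=False)
U_a = nn.Linear(2 * ENC_HID, ATTN, bias=False)
v_a = nn.Linear(ATTN, 1, bias=False)
# Step 5: decoder cell and output projection
decoder_cell = nn.GRUCell(EMB + 2 * ENC_HID, DEC_HID)
out_proj = nn.Linear(DEC_HID, VOCAB)

def decoder_step(prev_word_emb, s_prev, H):
    """One decoding step: returns (logits, new_state, attention_weights)."""
    # Step 3: e_ij = v_a^T tanh(W_a s_{i-1} + U_a h_j) for every j
    scores = v_a(torch.tanh(W_a(s_prev) + U_a(H))).squeeze(-1)  # (T,)
    # Step 4: softmax, then weighted sum of the encoder states
    alpha = torch.softmax(scores, dim=0)                        # (T,)
    context = alpha @ H                                         # (2*ENC_HID,)
    # Step 5: update the decoder state from the previous word and the context
    x = torch.cat([prev_word_emb, context]).unsqueeze(0)        # (1, EMB + 2*ENC_HID)
    s_new = decoder_cell(x, s_prev.unsqueeze(0)).squeeze(0)     # (DEC_HID,)
    return out_proj(s_new), s_new, alpha

# Steps 1-2 on a random "embedded sentence"
source = torch.randn(1, T, EMB)
H, _ = encoder(source)   # (1, T, 2*ENC_HID)
H = H.squeeze(0)         # (T, 2*ENC_HID)

logits, s1, alpha = decoder_step(torch.randn(EMB), torch.zeros(DEC_HID), H)
print(logits.shape, s1.shape, alpha.shape)
```

Calling `decoder_step` in a loop, feeding back `s1` and the embedding of the predicted word, would yield one row of the attention matrix (`alpha`) per output.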

In the next section we will implement this attention mechanism in PyTorch.

While I originally set out to implement the entire English-to-French example, it became apparent that the implementation would be excessively long and contain many intricacies that are irrelevant to the explanation of attention itself. As a result, I created a toy problem that mimics the grammatical aspects of English-to-French translation to showcase the attention mechanism specifically, without the bulk of implementing LSTMs, embeddings, utility tokens, batching, masking, and other problem-specific components.

The full code can be found here, for those who are interested.

As previously mentioned, English-to-French translation can be thought of as two subproblems: alignment and translation. The various networks within the encoder and decoder translate values, while the attention mechanism re-orients the vectors. In other words, **attention is all about alignment.** To emulate the alignment problem of English-to-French translation, the following toy problem was defined:

Given some shuffled input of values:

`[[ 0. 1.], [17. 18.], [10. 11.], [13. 14.], [14. 15.], [ 2. 3.] ... ]`

Organize them into a sequential output:

`[[0. 1.], [1. 2.], [ 2. 3.], [ 3. 4.], [ 4. 5.], [ 5. 6.] ...]`

Practically, the toy question posed to the attention mechanism is: **Given the previous output vector, which output should come next, given a selection of possible outputs?** This is very similar to the grammatical question in English-to-French translation: **Given a previous output word, which inputs are relevant to the next one?** Thus, by solving this toy problem, we can show the power of the attention mechanism without getting too far into the weeds.
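For concreteness, a possible generator for such pairs is sketched below. It assumes the same conventions as the training code later in the post (pairs of the form [k, k+1], with [0, 1] pinned to the front as a start token); `make_toy_pair` is a hypothetical helper name.

```python
import random
import numpy as np

def make_toy_pair(min_len=5, max_len=20, rng=random):
    """Return (x, y): x holds [k, k+1] pairs in shuffled order, y the same
    pairs in sorted order. [0, 1] always comes first and acts as a start token."""
    n = rng.randint(min_len, max_len)
    y = [[k, k + 1] for k in range(1, n + 1)]
    x = y.copy()          # shallow copy, so shuffling x leaves y sorted
    rng.shuffle(x)
    x = np.array([[0, 1]] + x, dtype=np.float32)
    y = np.array([[0, 1]] + y, dtype=np.float32)
    return x, y

random.seed(0)
x, y = make_toy_pair()
print(x[:4])  # start token, then shuffled pairs
print(y[:4])  # [[0, 1], [1, 2], [2, 3], [3, 4]]
```

Sorting `x` recovers `y`, so the only thing a model has to learn is the ordering, which is exactly the alignment half of translation.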

## Defining the Alignment Function

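Recall the alignment function:

$$e_{ij} = a(s_{i-1}, h_j) = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$$

This function, essentially, scores how relevant an input ($h_j$) is given the previous decoder state ($s_{i-1}$); after the softmax, that score becomes the weight ($\alpha_{ij}$) of the input. This can be implemented directly in PyTorch: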

```python
"""
Implementing the alignment function. The whole idea is that, given an embedding
for the encoder and decoder, the alignment function outputs a scalar rating the
alignment. So, the shapes of v, W, and U should be such that the output is a scalar.
"""
import torch
import torch.nn.functional as F

# defining the size of the input and output vectors of the attention mechanism
EMBED_DIM = 100

# These need to be sized in such a way that matrix multiplication yields a scalar;
# otherwise, they're just general learnable parameters. Different alignment
# functions might have different parameters. For instance, "Attention Is All You
# Need" uses projection heads that generate a query, key, and value, which are
# used in a different alignment function. Because U and W both project into the
# same EMBED_DIM space, this function can align vectors of different lengths.
encoder_embedding_dim = EMBED_DIM * 2
decoder_embedding_dim = EMBED_DIM

U_attention = torch.rand(EMBED_DIM, encoder_embedding_dim)  # projects h_j
W_attention = torch.rand(EMBED_DIM, decoder_embedding_dim)  # projects s_{i-1}
v_attention = torch.rand(1, EMBED_DIM)                      # collapses to a scalar

def alignment_func(s, h, W=W_attention, U=U_attention, v=v_attention):
    """
    s: s_{i-1} from the paper, the previous decoder state
    h: h_j from the paper, an input embedding
    W, U, v: trainable parameters

    Calculates v * tanh(W*s + U*h) and returns the (unnormalized) alignment
    score e_ij as a tensor of shape (1,). A softmax over all inputs turns
    these scores into the weights alpha.
    """
    v1 = torch.matmul(W, s)     # (EMBED_DIM,)
    v2 = torch.matmul(U, h)     # (EMBED_DIM,)
    v3 = torch.tanh(v1 + v2)    # (EMBED_DIM,)
    return torch.matmul(v, v3)  # (1,)

# testing the alignment function between one embedded word and another,
# dividing by 50 to keep the values in a good range for tanh
s = torch.rand(decoder_embedding_dim) / 50
h = torch.rand(encoder_embedding_dim) / 50
alignment_func(s, h)
```

## Defining Attention

For a given previous output, the task of the attention mechanism is to calculate which inputs to pay attention to. This can be done by calculating the alignment for all inputs and passing that vector of alignments through a softmax.

```python
"""
Defining attention, which is a list of softmaxed alignment scores for all input
embeddings (h_j) given the previous decoder embedding (s_{i-1}). This is equivalent
to a row of the attention matrix, hence the name of the function.
"""
def compute_attention_row(s, hs, W=W_attention, U=U_attention, v=v_attention):
    """
    Computes alignments for all h values given s.
    s is a vector of length decoder embedding size
    hs is a tensor of shape (sequence length, encoder embedding size)
    The output is a vector of length sequence length that sums to 1.
    """
    return F.softmax(torch.cat([alignment_func(s, h, W, U, v) for h in hs]), dim=0)

# testing the computation of an alignment row between the previous decoder
# embedding and all encoder embeddings
compute_attention_row(torch.rand(decoder_embedding_dim) / 50,
                      torch.rand(10, encoder_embedding_dim) / 50)
```
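Each call above produces one row of the attention matrix. When every previous decoder state is known up front (as with the teacher-forced toy problem in this post), all rows can be computed at once with broadcasting. A sketch, assuming the same parameter shapes as above but freshly randomized values; `attention_matrix` is a hypothetical helper, checked against a row-by-row reference:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Same parameter shapes as above (values are random placeholders)
EMBED_DIM, enc_dim, dec_dim = 100, 200, 100
U = torch.rand(EMBED_DIM, enc_dim)
W = torch.rand(EMBED_DIM, dec_dim)
v = torch.rand(1, EMBED_DIM)

def attention_matrix(S, H, W=W, U=U, v=v):
    """Hypothetical vectorized helper.
    S: (num_outputs, dec_dim), one previous decoder state per output
    H: (seq_len, enc_dim), the encoder embeddings
    Returns (num_outputs, seq_len); row i is the softmaxed alignment for output i."""
    ws = S @ W.T                                                    # (num_outputs, EMBED_DIM)
    uh = H @ U.T                                                    # (seq_len, EMBED_DIM)
    e = torch.tanh(ws[:, None, :] + uh[None, :, :]) @ v.squeeze(0)  # (num_outputs, seq_len)
    return F.softmax(e, dim=1)

S = torch.rand(4, dec_dim) / 50
H = torch.rand(10, enc_dim) / 50
A = attention_matrix(S, H)

# Row-by-row reference: softmax over v * tanh(W s + U h) for each h, for each s
ref = torch.stack([
    F.softmax(torch.stack([v.squeeze(0) @ torch.tanh(W @ s + U @ h) for h in H]), dim=0)
    for s in S
])
print(A.shape)                            # torch.Size([4, 10])
print(torch.allclose(A, ref, atol=1e-5))  # True
```

Each row of `A` sums to 1, which is what makes it a set of soft alignments rather than raw scores.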

## Defining A Learnable Attention Module

Now we need to wrap the previous function in a PyTorch `nn.Module`. This is implemented so that computing attention creates a traceable gradient, allowing the U, W, and v parameters to be updated through backpropagation.

Also, this module supports different encoder and decoder embedding dimensions, which may be useful in adapting this module to different applications.

```python
"""
Defining the attention module
"""
from torch import nn

# defining the input dimension from the encoder (h) and decoder (s)
encoder_embedding_dim = 10
decoder_embedding_dim = 20

# defining an example sequence length
sequence_length = 15

class Attention(nn.Module):
    """
    - computes an alignment for all encoder embeddings
    - constructs a context vector using those embeddings
    - outputs that context vector
    """
    def __init__(self, embed_dim=EMBED_DIM, encoder_embedding_dim=encoder_embedding_dim,
                 decoder_embedding_dim=decoder_embedding_dim):
        super().__init__()
        # learnable attention parameters (nn.Parameter tracks gradients by default)
        self.U = nn.Parameter(torch.rand(embed_dim, encoder_embedding_dim))
        self.W = nn.Parameter(torch.rand(embed_dim, decoder_embedding_dim))
        self.v = nn.Parameter(torch.rand(1, embed_dim))
        self.encoder_embedding_dim = encoder_embedding_dim
        # the module stays on the CPU; to use a GPU, move the module and
        # all of its inputs together, e.g. with .to(device)

    def forward(self, s, hn):
        """
        Computes a single context vector given all encoder embeddings (hn)
        and the previous decoder embedding (s).
        """
        # one softmaxed alignment weight per encoder embedding
        weights = compute_attention_row(s, hn, W=self.W, U=self.U, v=self.v)
        # weighted sum of the encoder embeddings -> context vector
        return torch.sum(hn * weights[:, None], dim=0)

print('==== Testing Attention ====')
# testing if the attention mechanism can support different sequence lengths
# and embedding dimensions
test_attention = Attention()

# defining previous decoder state
s = torch.rand(decoder_embedding_dim) / 50

# defining input embeddings
hn = torch.rand(sequence_length, encoder_embedding_dim) / 50

test_attention(s, hn).shape  # torch.Size([10])
```
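To check the claim that gradients flow back to U, W, and v, one can run a backward pass and inspect each parameter’s `.grad`. This sketch assumes a condensed, vectorized stand-in for the `Attention` class (the hypothetical `TinyAttention`) so that it runs on its own, and uses `context.sum()` as a placeholder loss:

```python
import torch
from torch import nn

torch.manual_seed(0)

class TinyAttention(nn.Module):
    """Hypothetical condensed stand-in for the Attention module above:
    same U, W, v shapes, but with the alignment scores vectorized."""
    def __init__(self, embed_dim, enc_dim, dec_dim):
        super().__init__()
        self.U = nn.Parameter(torch.rand(embed_dim, enc_dim))
        self.W = nn.Parameter(torch.rand(embed_dim, dec_dim))
        self.v = nn.Parameter(torch.rand(1, embed_dim))

    def forward(self, s, hn):
        # alignment score for every encoder embedding: (seq_len,)
        e = torch.tanh(self.W @ s + hn @ self.U.T) @ self.v.squeeze(0)
        weights = torch.softmax(e, dim=0)
        # weighted sum of encoder embeddings -> context vector (enc_dim,)
        return (hn * weights[:, None]).sum(dim=0)

attn = TinyAttention(embed_dim=100, enc_dim=10, dec_dim=20)
s = torch.rand(20) / 50
hn = torch.rand(15, 10) / 50

context = attn(s, hn)
context.sum().backward()  # a stand-in scalar loss, just to trigger backprop

for name, p in attn.named_parameters():
    print(name, tuple(p.shape), p.grad is not None)
```

If any parameter printed `False`, the optimizer would silently never update it, so this is a cheap sanity check before training.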

## Training

Now we can train the attention module to solve the toy problem. This is done by generating an X/Y pair (a corresponding shuffled and unshuffled set), then stepping through what each output should be and adjusting the weights based on the model’s error at every step.

```python
"""
Training Attention

Essentially, this generates random X/Y pairs, and trains the model to predict
each output given the previous correct output and all of the inputs.
This is a proof of concept. In reality, using minibatches, better initializations,
and occasionally feeding the model its own previous prediction instead of the
true previous output would probably improve convergence and generalizability.
"""
import random
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

min_len = 5
max_len = 20

# embed_dim=20; encoder and decoder embeddings are both 2-dimensional pairs
test_attention = Attention(20, 2, 2)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(test_attention.parameters(), lr=1e-3, momentum=0.9)
lr_phase = 0

# training on some number of random sequences
batch_losses = []
for i in tqdm(range(800)):

    # generating x (shuffled) and y (sorted), both starting with [0, 1]
    y = []
    x = []
    for j in range(random.randint(min_len, max_len)):
        y.append([j + 1, j + 2])
        x.append([j + 1, j + 2])
    random.shuffle(x)
    x = np.array([[0, 1]] + x).astype(np.float32)
    y = np.array([[0, 1]] + y).astype(np.float32)
    x = torch.from_numpy(x)
    y = torch.from_numpy(y)

    # iterating over all training examples (given y[j-1], predict y[j])
    s_in = x[0]
    sample_losses = []
    for j in range(1, len(x)):
        y_this = y[j]
        optimizer.zero_grad()
        s_out = test_attention(s_in, x)
        loss = loss_fn(s_out, y_this)
        sample_losses.append(loss.item())
        loss.backward()
        optimizer.step()
        # the next step is given the true previous output
        s_in = y_this.clone()

    batch_loss = np.mean(sample_losses)
    batch_losses.append(batch_loss)

    # hacking together a simple learning rate scheduler
    if batch_loss < 0.05 and lr_phase == 0:
        optimizer = torch.optim.SGD(test_attention.parameters(), lr=1e-4, momentum=0.9)
        lr_phase += 1

    # stopping training when loss is good enough
    if batch_loss < 0.03:
        break

plt.plot(batch_losses)
```
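The learning-rate “scheduler” above rebuilds the optimizer, which also discards SGD’s momentum buffers. If that matters, one possible alternative is to lower the rate in place through `optimizer.param_groups`. This sketch uses a stand-in `nn.Linear` model rather than the attention module, and `drop_lr` is a hypothetical helper:

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(2, 2)  # stand-in for the attention module
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

def drop_lr(optimizer, new_lr):
    """Lower the learning rate in place, keeping SGD's momentum buffers."""
    for group in optimizer.param_groups:
        group["lr"] = new_lr

# one training step so that the momentum buffers exist
loss = model(torch.randn(4, 2)).pow(2).mean()
loss.backward()
optimizer.step()

drop_lr(optimizer, 1e-4)
print(optimizer.param_groups[0]["lr"])  # 0.0001
```

For anything beyond a one-off drop, the schedulers in `torch.optim.lr_scheduler` (for example `ReduceLROnPlateau`) handle this bookkeeping.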

Using the following code, we can generate a randomly shuffled sequence and ask our attention model to sort it.

```python
"""
Visualizing alignment
"""
from matplotlib.ticker import MaxNLocator

# generating x: a shuffled sequence of [j, j+1] pairs, starting with [0, 1]
x = []
for j in range(1, random.randint(min_len, max_len)):
    x.append([j, j + 1])
random.shuffle(x)
x = np.array([[0, 1]] + x).astype(np.float32)
x = torch.from_numpy(x)

# extracting learned parameters for generating the alignment visual
W = test_attention.W
U = test_attention.U
v = test_attention.v

s = x[0]
y_hat = []
rows = []

# predicting the next element in the sequence,
# skipping over the trivial first element and not predicting one after the last
with torch.no_grad():
    for _ in range(len(x) - 1):
        # computing attention weights for this output, for visualization purposes
        row = compute_attention_row(s, x, W=W, U=U, v=v).numpy()
        rows.append(row)
        # predicting what should be in this location (rounded to whole numbers)
        s = torch.round(test_attention(s, x))
        y_hat.append(s.tolist())

# converting to numpy arrays
y_hat = np.array(y_hat)
x_p = x.numpy()

# printing inputs and predicted outputs
print('input: ')
print(x_p)
print('output: ')
print(y_hat)

# generating attention matrix plot
alignments = np.array(rows)
plt.pcolormesh(alignments, edgecolors='k', linewidth=2)
ax = plt.gca()
ax.set_aspect('equal')
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
plt.title('Alignment scores used in attention')
plt.ylabel('output index (each row is attention for an output)')
plt.xlabel('input index')
plt.show()
```

There are some artifacts based on the way the output is processed, but as you can see the attention mechanism is properly un-shuffling the input!

We created a general-purpose attention module that can be used within a larger network to focus on key information. It is directly inspired by the one used in English-to-French translation, but can be used in a variety of applications.

In future posts, I’ll be describing several landmark papers in the ML space, with an emphasis on practical and intuitive explanations. The attention mechanism used in transformers is a bit different from this one; I’ll be covering it in a future post.

**Please like, share, and follow. As an independent author, your support really makes a huge difference!**

**Attribution:** All of the images in this document were created by Daniel Warfield, unless a source is otherwise provided. You can use any images in this post for your own non-commercial purposes, so long as you reference this article, https://danielwarfield.dev, or both.
