Attention from Alignment, Practically Explained | by Daniel Warfield


Learn from what matters, ignore what doesn’t.

Photo by Armand Khoury on Unsplash

Attention, as popularized by the landmark paper Attention Is All You Need (2017), is arguably the most important architectural trend in machine learning right now. Originally intended for sequence-to-sequence modeling, attention has since spread into virtually every sub-discipline of machine learning.

This post describes a particular flavor of attention which preceded transformer-style attention. We’ll discuss how it works and why it’s useful. We’ll also go over some literature and a tutorial implementing this form of attention in PyTorch. By reading this post, you will gain a more thorough understanding of attention as a general concept, which is useful when exploring more cutting-edge applications.

The attention mechanism was originally popularized in Neural Machine Translation by Jointly Learning to Align and Translate (2014), which is the guiding reference for this post. This paper employs an encoder-decoder architecture for English-to-French translation.

The encoder-decoder architecture in a nutshell, for a French-to-English translation task

This is a very common architecture, but the exact details can change drastically from implementation to implementation. For instance, some of the earlier sequence-to-sequence encoder-decoders were recurrent networks that would incrementally “build” and then “deconstruct” the embedding.

Conceptualization of the information flow of a simple sequence-to-sequence recurrent encoder-decoder. The encoder incrementally embeds the English words, one by one, into the embedding space, which is then deconstructed by the decoder. In this diagram, the circles represent the embeddings throughout the encoder (red), the intermediate embedding space (white), and throughout the decoder (blue). In this case, the embeddings are long, complex vectors with abstract content that isn’t easily human-interpretable.

This general idea, and minor variations of it, was state of the art for several years. However, one problem with this approach is that the entire input sequence has to be embedded into the embedding space, which is generally a fixed-size vector. As a result, these models can easily forget content from sequences that are too long. The attention mechanism was designed to alleviate the problem of needing to fit the entire input sequence into the embedding space. It does this by telling the model which inputs are related to which outputs. In other words, the attention mechanism allows a model to focus on relevant portions of the input and disregard the rest.

An example of what an attention-based thought process might look like. In French, “je” corresponds directly to the word “I”. “Suis” is a conjugation of the verb “être” (“to be”); the form “suis” is chosen based on the subject “I” and the verb “am”. The choice of “directeur” is mostly related to the word “manager”, but also depends on the context in which that word is used. Deciding which inputs relate to which outputs is the task of the attention mechanism.

Practically, the attention mechanism we will be discussing ends up being a matrix of scores, called “alignment” scores. These alignment scores encode the degree to which a word in an input sequence relates to a word in the output sequence.

Two examples of attention matrices for two different English-to-French examples, from Neural Machine Translation by Jointly Learning to Align and Translate (2014). This paper only tangentially mentions the term “attention”, and actually calls this an “alignment model”. The term “attention” seems to have been popularized after the fact.

The alignment score can be computed in many ways. We’ll stick with our 2014 paper and pick apart its particular alignment function:

Alignment score calculation from Neural Machine Translation by Jointly Learning to Align and Translate (2014). This is not the only alignment function that exists, but it’s the one we will focus on. “W_a” and “U_a” represent learnable transformations of the previous decoder state (s_{i-1}) and a particular input embedding (h_j), respectively, while “v_a” represents a learnable reduction of the combined vector down to the final alignment scalar.

When calculating the alignment for the ith output, this approach uses the previous hidden state of the decoder (s_{i-1}), the embedding of an input word (h_j), and the learnable parameters W_a, U_a, and v_a to calculate the alignment of the ith output with the jth input. The tanh activation function is included to add non-linearity, which is vital in training a model to understand complex relationships.
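Written out in the paper’s notation, the alignment model is shown below, where e_ij is the unnormalized score of the jth input for the ith output. In the code later in this post, v_a is stored as a 1 × d row vector, so the transpose becomes a plain matrix product.

$$
e_{ij} = a(s_{i-1}, h_j) = v_a^\top \tanh\left(W_a s_{i-1} + U_a h_j\right)
$$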

In other words, the function above calculates a score between the next output word and a single input word, denoting how relevant an input word is to the current output. This function is run across all input words (h_j) to calculate an alignment score for all input words given the current output.
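Running the function across all inputs does not require an explicit loop; because the W_a s_{i-1} term is shared by every input, broadcasting can add it to all of the U_a h_j terms at once. A minimal sketch, assuming hypothetical sizes and randomly initialized tensors standing in for W_a, U_a, and v_a:

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
seq_len, enc_dim, dec_dim, attn_dim = 6, 8, 4, 5

h = torch.rand(seq_len, enc_dim)   # all input embeddings h_j, one per row
s_prev = torch.rand(dec_dim)       # previous decoder state s_{i-1}

# Random stand-ins for the learnable parameters.
W_a = torch.rand(attn_dim, dec_dim)
U_a = torch.rand(attn_dim, enc_dim)
v_a = torch.rand(attn_dim)

def alignment_scores(s_prev, h, W_a, U_a, v_a):
    """Return e_i = [e_i1, ..., e_iT], one unnormalized score per input."""
    # W_a @ s_prev has shape (attn_dim,); h @ U_a.T has shape (seq_len, attn_dim).
    # Broadcasting adds the same decoder term to every input's term.
    hidden = torch.tanh(W_a @ s_prev + h @ U_a.T)
    return hidden @ v_a  # shape (seq_len,)

e = alignment_scores(s_prev, h, W_a, U_a, v_a)

# The same scores computed one input at a time, as in the equation.
looped = torch.stack([v_a @ torch.tanh(W_a @ s_prev + U_a @ h_j) for h_j in h])
print(e.shape)                                 # torch.Size([6])
print(torch.allclose(e, looped, atol=1e-6))    # True
```

The vectorized form is what a batched implementation would normally use; the looped form mirrors the equation more literally.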

A conceptual diagram of how the alignments for a given prediction (word 8) are calculated. The alignment function is calculated between the previous output embedding of the decoder and all the inputs, to calculate the attention for the current output. Modified from Neural Machine Translation by Jointly Learning to Align and Translate (2014).

A softmax function is applied across all of the computed alignments for a given output, turning them into a probability distribution over the inputs: every weight is positive and the weights sum to one. This is referred to in the literature as “soft-search” or “soft-alignment”.
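As a small illustration of this step, here is a softmax applied to a handful of made-up alignment scores for one output; the score values are hypothetical:

```python
import torch

# Hypothetical unnormalized alignment scores e_ij for one output over four inputs.
scores = torch.tensor([2.0, 0.5, -1.0, 0.0])

# alpha_ij: every weight is positive and the weights sum to 1.
weights = torch.softmax(scores, dim=0)
print(weights)
print(weights.sum())
```

The ranking of the inputs is preserved (the highest score gets the largest weight), but no input is ever assigned exactly zero weight, which is what makes the alignment “soft”.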

Exactly how attention is used can vary from implementation to implementation. In Neural Machine Translation by Jointly Learning to Align and Translate (2014), the attention mechanism decides which input embeddings to provide to the decoder.

A graphical illustration of the proposed model trying to generate the t-th target word y_t given a source sentence (x_1, x_2, ..., x_T). Each input embedding is multiplied by its respective alignment score, then they are summed together to form the context vector which is used for the current decoder output step. From Neural Machine Translation by Jointly Learning to Align and Translate (2014).

It does this selection with a weighted sum. Every input embedding is multiplied by its respective alignment score (in practice, most alignment scores are close to zero, while one or two might be around 0.8 and 0.2, for instance), and the weighted embeddings are then added together to create the context vector for a particular output.
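A tiny worked example of the weighted sum, using hand-picked weights that mirror the 0.8 / 0.2 case above (in a real model the weights come from the softmax and are only approximately zero) and three hypothetical 2-dimensional input embeddings:

```python
import torch

# Three hypothetical 2-d input embeddings h_j (one per row).
h = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [5.0, 5.0]])

# Alignment weights for one output: almost all the mass on the first two inputs.
alpha = torch.tensor([0.8, 0.2, 0.0])

# Context vector c_i = sum_j alpha_ij * h_j
context = (alpha[:, None] * h).sum(dim=0)
print(context)    # tensor([0.8000, 0.2000])

# The same weighted sum written as a vector-matrix product.
print(alpha @ h)  # tensor([0.8000, 0.2000])
```

The third embedding, [5, 5], contributes nothing because its weight is zero; this is how attention lets the decoder ignore irrelevant inputs.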

The context vector used for the ith decoder step, c_i. This is the weighted sum of all input embeddings based on their alignment scores. From Neural Machine Translation by Jointly Learning to Align and Translate (2014).
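Combining the softmax with the weighted sum, the weights and the context vector from the 2014 paper can be written as follows, where T_x is the length of the input sequence and e_ij is the alignment score computed by the alignment function:

$$
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}, \qquad c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
$$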

The context vector is the combination of all of the inputs which are relevant to the current output.

The following figure ties together how attention fits into the bigger picture:

A breakdown of the flow of information for a particular output.
  1. The inputs are embedded into some initial vector representation (using word2vec, for instance).
  2. Those are passed through a bidirectional LSTM, giving each embedding some awareness of its surrounding context.
  3. Alignment scores are calculated for each input using the previous decoder hidden state and the learned parameters within the alignment function.
  4. The softmaxed alignments are multiplied against each input, and the results are added together to construct the context vector.
  5. The decoder uses the previous decoder hidden state, along with the context vector, to generate the prediction of the current word.
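To see how the five steps fit together, here is a minimal, hypothetical PyTorch sketch of a single decoding step. It is not the paper’s exact model: nn.Embedding stands in for pretrained embeddings such as word2vec, nn.GRUCell is used as the decoder, the output layer is a plain nn.Linear, and the initial decoder state is simply zeros. The class name TinyAttnSeq2Seq and all sizes are made up for illustration.

```python
import torch
from torch import nn

class TinyAttnSeq2Seq(nn.Module):
    """Hypothetical one-step sketch of the pipeline above (not the paper's model)."""

    def __init__(self, vocab_in, vocab_out, emb_dim=16, enc_dim=32, dec_dim=32, attn_dim=24):
        super().__init__()
        self.src_embed = nn.Embedding(vocab_in, emb_dim)                 # step 1
        self.encoder = nn.LSTM(emb_dim, enc_dim, bidirectional=True,      # step 2
                               batch_first=True)
        # Additive alignment parameters (step 3); bias-free Linear layers act as W_a, U_a, v_a.
        self.W_a = nn.Linear(dec_dim, attn_dim, bias=False)
        self.U_a = nn.Linear(2 * enc_dim, attn_dim, bias=False)
        self.v_a = nn.Linear(attn_dim, 1, bias=False)
        self.tgt_embed = nn.Embedding(vocab_out, emb_dim)
        self.decoder = nn.GRUCell(emb_dim + 2 * enc_dim, dec_dim)        # step 5
        self.out = nn.Linear(dec_dim, vocab_out)

    def encode(self, src):
        # src: (batch, src_len) token ids -> h: (batch, src_len, 2 * enc_dim)
        h, _ = self.encoder(self.src_embed(src))
        return h

    def decode_step(self, prev_token, s_prev, h):
        # Step 3: e_ij = v_a^T tanh(W_a s_{i-1} + U_a h_j), for every j at once.
        e = self.v_a(torch.tanh(self.W_a(s_prev)[:, None, :] + self.U_a(h))).squeeze(-1)
        alpha = torch.softmax(e, dim=1)                    # (batch, src_len)
        # Step 4: context vector c_i = sum_j alpha_ij h_j.
        c = torch.bmm(alpha[:, None, :], h).squeeze(1)     # (batch, 2 * enc_dim)
        # Step 5: update the decoder state from the previous word and the context.
        s = self.decoder(torch.cat([self.tgt_embed(prev_token), c], dim=-1), s_prev)
        return self.out(s), s, alpha

torch.manual_seed(0)
model = TinyAttnSeq2Seq(vocab_in=50, vocab_out=60)
src = torch.randint(0, 50, (2, 7))            # a batch of two 7-token sentences
h = model.encode(src)
s0 = torch.zeros(2, 32)                       # assumed zero initial decoder state
start = torch.zeros(2, dtype=torch.long)      # assumed start-token id 0
logits, s1, alpha = model.decode_step(start, s0, h)
print(logits.shape, s1.shape, alpha.shape)    # torch.Size([2, 60]) torch.Size([2, 32]) torch.Size([2, 7])
```

Each call to decode_step produces one row of the attention matrix (alpha) per sentence; a full decoder would call it repeatedly, feeding back the previous word and state.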

In the next section we will implement this attention mechanism in PyTorch.

While I originally set out to implement the entire English-to-French example, it became apparent that the implementation would be excessively long and would contain many intricacies irrelevant to explaining attention itself. Instead, I created a toy problem that mimics the grammatical aspects of English-to-French translation to showcase the attention mechanism specifically, without the bulk of implementing LSTMs, embeddings, utility tokens, batching, masking, and other problem-specific components.

The full code can be found here, for those that are interested

As previously mentioned, English-to-French translation can be thought of as two subproblems: alignment and translation. The various networks within the encoder and decoder translate values, while the attention mechanism re-orients the vectors. In other words, attention is all about alignment. To emulate the alignment problem of English-to-French translation, the following toy problem was defined:

Given a shuffled input of value pairs (with [0, 1] always placed first as the starting point)

[[ 0.  1.], [17. 18.], [10. 11.], [13. 14.], [14. 15.], [ 2.  3.] ... ]

Organize them into a sequential output:

[[0. 1.], [1. 2.], [ 2.  3.], [ 3.  4.], [ 4.  5.], [ 5.  6.] ...]
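A minimal sketch of generating one X/Y pair for this toy problem, mirroring what the training code later in the post does; the helper name make_toy_pair is hypothetical:

```python
import random
import numpy as np

def make_toy_pair(min_len=5, max_len=20):
    """Hypothetical helper: return (x, y) for the unshuffling toy problem.

    y is the sorted sequence [0, 1], [1, 2], [2, 3], ...; x holds the same
    pairs, with [0, 1] kept first as the starting point and the rest shuffled.
    """
    n = random.randint(min_len, max_len)
    y = [[j, j + 1] for j in range(n + 1)]
    rest = y[1:]          # slicing makes a new list, so shuffling it leaves y in order
    random.shuffle(rest)
    x = [y[0]] + rest
    return np.array(x, dtype=np.float32), np.array(y, dtype=np.float32)

random.seed(0)
x, y = make_toy_pair()
print(x[:4])
print(y[:4])
```

Sorting x recovers y exactly; the model’s job is to learn that ordering through attention alone.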

Practically, the question this toy problem poses to the attention mechanism is: given the previous output vector and a selection of possible inputs, which input should come next? This is very similar to the grammatical question in English-to-French translation: given the previous output word, which inputs are relevant to the next one? Thus, by solving this toy problem, we can show the power of the attention mechanism without getting too far into the weeds.

Defining the Alignment Function

Recall the alignment function:

The previously discussed alignment function, from Neural Machine Translation by Jointly Learning to Align and Translate (2014)

Essentially, this function computes an unnormalized score (e_ij) for an input (h_j) given the previous output (s_{i-1}); the weight (α_ij) is then obtained by applying a softmax over all of those scores. This can be implemented directly in PyTorch:

```python
"""
Implementing the alignment function.

The whole idea is that, given an embedding from the encoder and one from the
decoder, the alignment function outputs a scalar rating their alignment. So the
shapes of v, W, and U must be such that the output is a scalar.
"""

import torch
import torch.nn.functional as F

# size of the hidden space the alignment function works in
EMBED_DIM = 100

# These need to be sized so that the matrix multiplications yield a scalar;
# otherwise, they're just general learnable parameters. Different alignment
# functions might have different parameters. For instance, "Attention Is All You
# Need" uses projection heads that generate a query, key, and value, which are
# used in a different self-alignment function. This one can align vectors of
# different lengths.
encoder_embedding_dim = EMBED_DIM*2
decoder_embedding_dim = EMBED_DIM

U_attention = torch.rand(EMBED_DIM, encoder_embedding_dim)  # maps h_j into the hidden space
W_attention = torch.rand(EMBED_DIM, decoder_embedding_dim)  # maps s_{i-1} into the hidden space
v_attention = torch.rand(1, EMBED_DIM)                      # reduces the hidden vector to a score

def alignment_func(s, h, W=W_attention, U=U_attention, v=v_attention):
    """
    s: s_{i-1} from the paper, the previous decoder state
    h: h_j from the paper, an input embedding
    W, U, v: trainable parameters

    Calculates v*tanh(W*s + U*h) and returns the unnormalized alignment
    score e_ij as a tensor of shape (1,).
    """
    v1 = torch.matmul(W, s)
    v2 = torch.matmul(U, h)
    v3 = torch.tanh(v1 + v2)

    return torch.matmul(v, v3)

# testing the alignment function between one embedded word and another;
# dividing by 50 to keep the values in a good range for tanh
s = torch.rand(decoder_embedding_dim)/50
h = torch.rand(encoder_embedding_dim)/50
alignment_func(s, h)
```

An example output of the alignment function. A scalar, which corresponds to a specific input-output pair.

Defining Attention

For a given previous output, the task of the attention mechanism is to calculate which inputs to pay attention to. This can be done by calculating the alignment for all inputs and passing that vector of alignments through a softmax.

```python
"""
Defining attention, which is a list of softmaxed alignment scores for all input
embeddings (h_j) given the previous decoder embedding (s_{i-1}). This is
equivalent to a row of the attention matrix, hence the name of the function.
"""

import torch
import torch.nn.functional as F

def compute_attention_row(s, hs, W=W_attention, U=U_attention, v=v_attention):
    """
    Computes the alignments for all h values given s.

    s is a vector of length decoder_embedding_dim
    hs is a tensor of shape (sequence length, encoder_embedding_dim)
    the output is a vector of length sequence length that sums to 1
    """
    return F.softmax(torch.cat([alignment_func(s, h, W, U, v) for h in hs]), dim=0)

# testing the computation of an alignment row between the previous decoder
# embedding and all encoder embeddings
compute_attention_row(torch.rand(decoder_embedding_dim)/50, torch.rand(10, encoder_embedding_dim)/50)
```

A single attention vector for a given output position.

Defining A Learnable Attention Module

Now we need to wrap the previous function in a PyTorch nn.Module. The computation of attention creates a traceable gradient, allowing the U, W, and v parameters to be updated through backpropagation.

Also, this module supports different encoder and decoder embedding dimensions, which may be useful in adapting it to different applications.

```python
"""
Defining the attention module
"""

import torch
from torch import nn

# defining the input dimension from the encoder (h) and decoder (s)
encoder_embedding_dim = 10
decoder_embedding_dim = 20

# defining an example sequence length
sequence_length = 15

class Attention(nn.Module):
    """
    - computes an alignment for all encoder embeddings
    - constructs a context vector using those embeddings
    - outputs that context vector
    """

    def __init__(self, embed_dim=EMBED_DIM, encoder_embedding_dim=encoder_embedding_dim, decoder_embedding_dim=decoder_embedding_dim):
        super().__init__()

        # learnable attention parameters (nn.Parameter requires grad by default)
        self.U = nn.Parameter(torch.rand(embed_dim, encoder_embedding_dim))
        self.W = nn.Parameter(torch.rand(embed_dim, decoder_embedding_dim))
        self.v = nn.Parameter(torch.rand(1, embed_dim))
        self.encoder_embedding_dim = encoder_embedding_dim

        # The module stays on the CPU. If you move it to a GPU (e.g. .to("cuda")),
        # s and hn must be moved to the same device as well.

    def forward(self, s, hn):
        """
        computes a context vector given all encoder embeddings (hn) and the
        previous decoder embedding (s)
        """
        # softmaxed alignment of every input with the previous decoder state
        weights = compute_attention_row(s, hn, W=self.W, U=self.U, v=self.v)

        # weighted sum of the input embeddings -> context vector
        return torch.sum(hn * weights[:, None], dim=0)

print('==== Testing Attention ====')
# testing whether the attention mechanism supports different sequence lengths
# and embedding dimensions
test_attention = Attention()

# defining the previous decoder state
s = torch.rand(decoder_embedding_dim)/50
# defining the input embeddings
hn = torch.rand(sequence_length, encoder_embedding_dim)/50

test_attention(s, hn).shape
```

Creates a context vector of length 10. This makes sense, as the input embeddings are of length 10, and the output of this attention technique is a context vector which is the weighted sum of all the input embeddings.
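The module above handles one sequence at a time. As a sketch of how it might be batched, here is a hypothetical BatchedAdditiveAttention module (not part of the original code) that computes the same v·tanh(W·s + U·h) score for a whole batch at once. It uses bias-free nn.Linear layers in place of the raw W, U, and v parameters; nn.Linear stores its weight as (out_features, in_features), so self.W.weight has the same (embed_dim, decoder_dim) shape as W above. The optional boolean mask is an assumption for padded, variable-length batches; every row must contain at least one True entry, or the softmax returns NaN.

```python
import torch
from torch import nn

class BatchedAdditiveAttention(nn.Module):
    """Hypothetical batched variant of the Attention module above.

    s:  (batch, decoder_dim)           previous decoder states s_{i-1}
    hn: (batch, seq_len, encoder_dim)  input embeddings h_j
    returns context (batch, encoder_dim) and weights (batch, seq_len)
    """

    def __init__(self, embed_dim, encoder_dim, decoder_dim):
        super().__init__()
        self.W = nn.Linear(decoder_dim, embed_dim, bias=False)
        self.U = nn.Linear(encoder_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, 1, bias=False)

    def forward(self, s, hn, mask=None):
        # (batch, 1, embed) + (batch, seq_len, embed) -> scores (batch, seq_len)
        scores = self.v(torch.tanh(self.W(s).unsqueeze(1) + self.U(hn))).squeeze(-1)
        if mask is not None:
            # mask: (batch, seq_len) bool, True for real inputs; padding gets zero weight
            scores = scores.masked_fill(~mask, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        # weighted sum of the inputs for each batch element
        context = torch.bmm(weights.unsqueeze(1), hn).squeeze(1)
        return context, weights

torch.manual_seed(0)
attn = BatchedAdditiveAttention(embed_dim=20, encoder_dim=10, decoder_dim=20)
s = torch.rand(4, 20)
hn = torch.rand(4, 15, 10)
mask = torch.ones(4, 15, dtype=torch.bool)
mask[0, 12:] = False  # pretend the first sequence has only 12 real inputs
context, weights = attn(s, hn, mask)
print(context.shape, weights.shape)  # torch.Size([4, 10]) torch.Size([4, 15])
```

As in the single-sequence version, each context vector has the encoder’s embedding size, and each row of weights is one row of an attention matrix.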

Training

Now we can train the attention module to solve the toy problem. This is done by generating an X/Y pair (a corresponding shuffled and unshuffled set), then stepping through each output position: the model predicts that output from the previous correct output and all of the inputs, and its weights are adjusted based on the error.

```python
""" Training Attention

Essentially, this generates random X/Y pairs, and trains the model to predict
each output given the previous correct output and all of the inputs.

This is a proof of concept. In reality, using minibatches, better initializations,
and stochastically providing the true previous output occasionally would probably
improve convergence and generalizability.
"""

import random
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn

min_len = 5
max_len = 20

test_attention = Attention(20, 2, 2)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(test_attention.parameters(), lr=1e-3, momentum=0.9)
lr_phase = 0

# training on some number of random sequences
batch_losses = []
for i in tqdm(range(800)):

    # generating x (shuffled) and y (sorted); [0, 1] always comes first
    y = []
    x = []
    for j in range(random.randint(min_len, max_len)):
        y.append([j+1, j+2])
        x.append([j+1, j+2])
    random.shuffle(x)
    x = np.array([[0, 1]] + x).astype(np.float32)
    y = np.array([[0, 1]] + y).astype(np.float32)
    x = torch.from_numpy(x)
    y = torch.from_numpy(y)

    # iterating over all training examples (given s, predict the element after s)
    s_in = x[0]
    sample_losses = []
    for j in range(1, len(x)):

        y_this = y[j]

        optimizer.zero_grad()
        s_out = test_attention(s_in, x)

        loss = loss_fn(s_out, y_this)
        sample_losses.append(loss.item())

        loss.backward()
        optimizer.step()

        # teacher forcing: the next step is conditioned on the true output
        s_in = y_this.clone()

    batch_loss = np.mean(sample_losses)
    batch_losses.append(batch_loss)

    # hacking together a simple learning rate scheduler
    if batch_loss < 0.05 and lr_phase == 0:
        for group in optimizer.param_groups:
            group['lr'] = 1e-4
        lr_phase += 1

    # stopping training when the loss is good enough
    if batch_loss < 0.03:
        break

plt.plot(batch_losses)
```

Training Loss. As noted in the implementation, the training process could be significantly improved to encourage better convergence, but this serves for our toy example.
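One possible improvement is to replace the hand-rolled learning-rate switch with PyTorch’s built-in ReduceLROnPlateau scheduler, which lowers the learning rate once the loss stops improving. This is a swapped-in technique rather than what the loop above does, and the model and loss values below are placeholders used only to show the scheduler’s behavior:

```python
import torch
from torch import nn

# Placeholder model; in the training loop this would be test_attention.
model = nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Cut the learning rate 10x once the loss has not improved for more than `patience` steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

# Placeholder per-sequence losses that improve and then plateau.
for batch_loss in [0.9, 0.5, 0.3, 0.3, 0.3, 0.3, 0.3]:
    scheduler.step(batch_loss)

print(optimizer.param_groups[0]['lr'])
```

In the training loop, this would mean calling scheduler.step(batch_loss) once per sequence instead of checking lr_phase; unlike recreating the optimizer, it also keeps SGD’s momentum buffers.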

Using the following code, we can generate a randomly shuffled sequence and ask our attention model to sort it.

```python
"""
Visualizing alignment
"""

from matplotlib.ticker import MaxNLocator

# generating a shuffled x, with [0, 1] kept first as the starting element
x = []
for j in range(1, random.randint(min_len, max_len)):
    x.append([j, j+1])
random.shuffle(x)

x = np.array([[0, 1]] + x).astype(np.float32)
x = torch.from_numpy(x)

# extracting the learned parameters for generating the alignment visual
W = test_attention.W
U = test_attention.U
v = test_attention.v

s = x[0]
y_hat = []
rows = []

# predicting the next element in the sequence,
# skipping over the trivial first element and not predicting one after the last
with torch.no_grad():
    for _ in range(len(x) - 1):

        # computing attention weights for this output, for visualization purposes
        rows.append(compute_attention_row(s, x, W=W, U=U, v=v).numpy())

        # predicting what should be in this location
        s = torch.round(test_attention(s, x))
        y_hat.append(s.tolist())

# converting to numpy arrays
y_hat = np.array(y_hat)
x_p = x.numpy()

# printing inputs and predicted outputs
print('input: ')
print(x_p)
print('output: ')
print(y_hat)

# generating the attention matrix plot
alignments = np.array(rows)
plt.pcolormesh(alignments, edgecolors='k', linewidth=2)
ax = plt.gca()
ax.set_aspect('equal')
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
plt.title('Alignment scores used in attention')
plt.ylabel('output index (each row is attention for an output)')
plt.xlabel('input index')
plt.show()
```

The shuffled input, the model output, and the alignment scores.

There are some artifacts based on the way the output is processed, but as you can see the attention mechanism is properly un-shuffling the input!

We created a general-purpose attention module which can be used within a larger network to focus on key information. It is directly inspired by the one used for English-to-French translation, but can be used in a variety of applications.

In future posts, I’ll also be describing several landmark papers in the ML space, with an emphasis on practical and intuitive explanations. The attention mechanism used in transformers is a bit different from this one, and I’ll cover it in a future post.

Please like, share, and follow. As an independent author, your support really makes a huge difference!

Attribution: All of the images in this document were created by Daniel Warfield, unless a source is otherwise provided. You can use any images in this post for your own non-commercial purposes, so long as you reference this article, https://danielwarfield.dev, or both.


