{ "cells": [ { "cell_type": "markdown", "source": [ "Self-Attention\r\n", "======================\r\n", "\r\n", "In this notebook, we look at how self-attention works. There is no RNN: the encoder applies self-attention to the source sequence, while the decoder applies self-attention to the **shifted target sequence** and cross-attention over the encoder states. \r\n", "\r\n", "![Self Attention](../images/encdec_self_simplified.png)" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Download Utility Files for Plotting and Data Generation\r\n", "\r\n", "Download these utility files first and place them in the same directory as this notebook. These files are the same as the ones in [Lab11 Sequence to Sequence Model](https://weiliu2k.github.io/CITS4012/LSTM/seq2seq.html)." ], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "from IPython.display import FileLink, FileLinks" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "FileLink('plots.py')" ], "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "plots.py
" ], "text/plain": [ "c:\\Users\\wei\\jupyter_book\\cits4012\\cits4012_natural_language_processing\\cits4012_natural_language_processing\\attention\\plots.py" ] }, "metadata": {}, "execution_count": 3 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 4, "source": [ "FileLink('plots_seq2seq.py')" ], "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "plots_seq2seq.py
" ], "text/plain": [ "c:\\Users\\wei\\jupyter_book\\cits4012\\cits4012_natural_language_processing\\cits4012_natural_language_processing\\attention\\plots_seq2seq.py" ] }, "metadata": {}, "execution_count": 4 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 5, "source": [ "FileLink('util.py')" ], "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "util.py
" ], "text/plain": [ "c:\\Users\\wei\\jupyter_book\\cits4012\\cits4012_natural_language_processing\\cits4012_natural_language_processing\\attention\\util.py" ] }, "metadata": {}, "execution_count": 5 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 6, "source": [ "FileLink('replay.py')" ], "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "replay.py
" ], "text/plain": [ "c:\\Users\\wei\\jupyter_book\\cits4012\\cits4012_natural_language_processing\\cits4012_natural_language_processing\\attention\\replay.py" ] }, "metadata": {}, "execution_count": 6 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Imports" ], "metadata": {} }, { "cell_type": "code", "execution_count": 7, "source": [ "import copy\r\n", "import numpy as np\r\n", "\r\n", "import torch\r\n", "import torch.optim as optim\r\n", "import torch.nn as nn\r\n", "import torch.nn.functional as F\r\n", "from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset\r\n", "from util import StepByStep\r\n", "from plots import *\r\n", "from plots_seq2seq import *" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Data Generation" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "We still make use of the Square Sequences, using the first two corners to predict the last two. Same method as before for generating noisy squares. 
" ], "metadata": {} }, { "cell_type": "code", "execution_count": 8, "source": [ "def generate_sequences(n=128, variable_len=False, seed=13):\r\n", "    basic_corners = np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]])\r\n", "    np.random.seed(seed)\r\n", "    bases = np.random.randint(4, size=n)\r\n", "    if variable_len:\r\n", "        lengths = np.random.randint(3, size=n) + 2\r\n", "    else:\r\n", "        lengths = [4] * n\r\n", "    directions = np.random.randint(2, size=n)\r\n", "    points = [basic_corners[[(b + i) % 4 for i in range(4)]][slice(None, None, d*2-1)][:l] + np.random.randn(l, 2) * 0.1 for b, d, l in zip(bases, directions, lengths)]\r\n", "    return points, directions" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Multi-Headed Attention" ], "metadata": {} }, { "cell_type": "code", "execution_count": 9, "source": [ "class Attention(nn.Module):\r\n", "    def __init__(self, hidden_dim, input_dim=None, proj_values=False):\r\n", "        super().__init__()\r\n", "        self.d_k = hidden_dim\r\n", "        self.input_dim = hidden_dim if input_dim is None else input_dim\r\n", "        self.proj_values = proj_values\r\n", "        # Affine transformations for Q, K, and V\r\n", "        self.linear_query = nn.Linear(self.input_dim, hidden_dim)\r\n", "        self.linear_key = nn.Linear(self.input_dim, hidden_dim)\r\n", "        self.linear_value = nn.Linear(self.input_dim, hidden_dim)\r\n", "        self.alphas = None\r\n", "    \r\n", "    def init_keys(self, keys):\r\n", "        self.keys = keys\r\n", "        self.proj_keys = self.linear_key(self.keys)\r\n", "        self.values = self.linear_value(self.keys) \\\r\n", "                      if self.proj_values else self.keys\r\n", "    \r\n", "    def score_function(self, query):\r\n", "        proj_query = self.linear_query(query)\r\n", "        # scaled dot product\r\n", "        # N, 1, H x N, H, L -> N, 1, L\r\n", "        dot_products = torch.bmm(proj_query, self.proj_keys.permute(0, 2, 1))\r\n", "        scores = dot_products / np.sqrt(self.d_k)\r\n", "        return scores\r\n", "    \r\n", "    def forward(self, query, mask=None):\r\n", "        # Query is batch-first N, 1, H\r\n", "        scores = self.score_function(query) # N, 1, L\r\n", "        if mask is not None:\r\n", "            scores = scores.masked_fill(mask == 0, -1e9)\r\n", "        alphas = F.softmax(scores, dim=-1) # N, 1, L\r\n", "        self.alphas = alphas.detach()\r\n", "        \r\n", "        # N, 1, L x N, L, H -> N, 1, H\r\n", "        context = torch.bmm(alphas, self.values)\r\n", "        return context" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 10, "source": [ "class MultiHeadAttention(nn.Module):\r\n", "    def __init__(self, n_heads, d_model, input_dim=None, proj_values=True):\r\n", "        super().__init__()\r\n", "        self.linear_out = nn.Linear(n_heads * d_model, d_model)\r\n", "        self.attn_heads = nn.ModuleList([Attention(d_model, \r\n", "                                                   input_dim=input_dim, \r\n", "                                                   proj_values=proj_values) \r\n", "                                         for _ in range(n_heads)])\r\n", "    \r\n", "    def init_keys(self, key):\r\n", "        for attn in self.attn_heads:\r\n", "            attn.init_keys(key)\r\n", "    \r\n", "    @property\r\n", "    def alphas(self):\r\n", "        # Shape: n_heads, N, 1, L (source)\r\n", "        return torch.stack([attn.alphas for attn in self.attn_heads], dim=0)\r\n", "    \r\n", "    def output_function(self, contexts):\r\n", "        # N, 1, n_heads * D\r\n", "        concatenated = torch.cat(contexts, dim=-1)\r\n", "        # Linear transf. to go back to original dimension\r\n", "        out = self.linear_out(concatenated) # N, 1, D\r\n", "        return out\r\n", "    \r\n", "    def forward(self, query, mask=None):\r\n", "        contexts = [attn(query, mask=mask) for attn in self.attn_heads]\r\n", "        out = self.output_function(contexts)\r\n", "        return out" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Self-Attention\r\n", "\r\n", "Instead of using an RNN to generate the hidden states, we use an attention mechanism to generate hidden states for the inputs. This is called self-attention, i.e. attention on the input data points themselves. Each input will have its own context vector. 
All context vectors are concatenated and transformed through a linear layer, as an input to the feedforward network as hidden state for the corresponding input. " ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Attention Scores from Input\r\n", "\r\n", "Inputs or their affine transformations are used as Keys, Values and Queries. " ], "metadata": {} }, { "cell_type": "markdown", "source": [ "$$\r\n", "\\alpha_{\\color{blue}{0}\\color{red}0}, \\alpha_{{\\color{blue}{0}\\color{red}1}} = softmax(\\frac{\\color{red}{Q_0}\\color{black}\\cdot K_0}{\\sqrt{2}}, \\frac{\\color{red}{Q_0}\\color{black}\\cdot K_1}{\\sqrt{2}})\r\n", "\\\\\r\n", "\\color{blue}{context\\ vector_0}\\color{black}= \\alpha_{\\color{blue}{0}\\color{red}0}V_0 + \\alpha_{{\\color{blue}{0}\\color{red}1}}V_1\r\n", "$$" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "$$\r\n", "\\alpha_{\\color{blue}{1}\\color{red}0}, \\alpha_{{\\color{blue}{1}\\color{red}1}} = softmax(\\frac{\\color{red}{Q_1}\\color{black}\\cdot K_0}{\\sqrt{2}}, \\frac{\\color{red}{Q_1}\\color{black}\\cdot K_1}{\\sqrt{2}})\r\n", "\\\\\r\n", "\\color{blue}{context\\ vector_1}\\color{black}= \\alpha_{\\color{blue}{1}\\color{red}0}V_0 + \\alpha_{{\\color{blue}{1}\\color{red}1}}V_1\r\n", "$$" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "$$\r\n", "\\begin{array}{c|cc}\r\n", "& source\\\\\r\n", "target& \\color{red}{x_0} & \\color{red}{x_1} \\\\\r\n", "\\hline\r\n", " \\color{blue}{h_0} & \\alpha_{\\color{blue}{0}\\color{red}0} & \\alpha_{{\\color{blue}{0}\\color{red}1}} \\\\\r\n", " \\color{blue}{h_1} & \\alpha_{\\color{blue}{1}\\color{red}0} & \\alpha_{{\\color{blue}{1}\\color{red}1}}\r\n", "\\end{array}\r\n", "$$" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "![](../images/transf_encself.png)" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Encoder with Self-Attention\r\n", "\r\n", "As you can see from the code below, each input has its own attention head, so it 
makes use of the multi-headed attention above. Check the code against the diagram for better understanding." ], "metadata": {} }, { "cell_type": "code", "execution_count": 11, "source": [ "class EncoderSelfAttn(nn.Module):\r\n", "    def __init__(self, n_heads, d_model, ff_units, n_features=None):\r\n", "        super().__init__()\r\n", "        self.n_heads = n_heads\r\n", "        self.d_model = d_model\r\n", "        self.ff_units = ff_units\r\n", "        self.n_features = n_features\r\n", "        self.self_attn_heads = MultiHeadAttention(n_heads, d_model, input_dim=n_features)\r\n", "        self.ffn = nn.Sequential(\r\n", "            nn.Linear(d_model, ff_units),\r\n", "            nn.ReLU(),\r\n", "            nn.Linear(ff_units, d_model),\r\n", "        )\r\n", "    \r\n", "    def forward(self, query, mask=None):\r\n", "        self.self_attn_heads.init_keys(query)\r\n", "        att = self.self_attn_heads(query, mask)\r\n", "        out = self.ffn(att)\r\n", "        return out" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Encoding the Square Sequence\r\n", "\r\n", "The perfect square, with its four corners, is split into a source sequence of two corners and a target sequence of two corners." 
], "metadata": {} }, { "cell_type": "code", "execution_count": 12, "source": [ "full_seq = torch.tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]]).float().view(1, 4, 2)\r\n", "source_seq = full_seq[:, :2]\r\n", "target_seq = full_seq[:, 2:]" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 13, "source": [ "torch.manual_seed(11)\r\n", "encself = EncoderSelfAttn(n_heads=3, d_model=2, ff_units=10, n_features=2)\r\n", "query = source_seq\r\n", "encoder_states = encself(query)\r\n", "encoder_states" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[[-0.0498, 0.2193],\n", " [-0.0642, 0.2258]]], grad_fn=)" ] }, "metadata": {}, "execution_count": 13 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Decoder" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "This decoder has both self-attention and cross-attention. The cross-attention uses the output of the decoder's self-attention as its query and the encoder states as its keys and values. " ], "metadata": {} }, { "cell_type": "code", "execution_count": 14, "source": [ "class DecoderSelfAttn(nn.Module):\r\n", "    def __init__(self, n_heads, d_model, ff_units, n_features=None):\r\n", "        super().__init__()\r\n", "        self.n_heads = n_heads\r\n", "        self.d_model = d_model\r\n", "        self.ff_units = ff_units\r\n", "        self.n_features = d_model if n_features is None else n_features\r\n", "        self.self_attn_heads = MultiHeadAttention(n_heads, d_model, input_dim=self.n_features)\r\n", "        self.cross_attn_heads = MultiHeadAttention(n_heads, d_model)\r\n", "        self.ffn = nn.Sequential(\r\n", "            nn.Linear(d_model, ff_units),\r\n", "            nn.ReLU(),\r\n", "            nn.Linear(ff_units, self.n_features),\r\n", "        )\r\n", "    \r\n", "    def init_keys(self, states):\r\n", "        self.cross_attn_heads.init_keys(states)\r\n", "    \r\n", "    def forward(self, query, source_mask=None, target_mask=None):\r\n", "        self.self_attn_heads.init_keys(query)\r\n", "        att1 = self.self_attn_heads(query, target_mask)\r\n", "        att2 = self.cross_attn_heads(att1, source_mask)\r\n", "        out = self.ffn(att2)\r\n", "        return out" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "#### Subsequent Inputs and Teacher Forcing" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The **shifted target sequence** contains the last element of the source sequence, followed by everything but the last element of the target sequence. " ], "metadata": {} }, { "cell_type": "code", "execution_count": 15, "source": [ "shifted_seq = torch.cat([source_seq[:, -1:], target_seq[:, :-1]], dim=1)" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "#### Attention Scores\r\n", "\r\n", "Because we feed the entire shifted target sequence to the decoder, we need to be careful not to allow the model to peek into the future. \r\n", "\r\n", "The attention score $\\alpha_{22}$ is problematic because computing it uses $K_2$ and $V_2$, which the model should not have seen yet. So we need a **target mask** to ensure that it does not contribute to the context vector. " ], "metadata": {} }, { "cell_type": "markdown", "source": [ "$$\r\n", "\\alpha_{\\color{green}{2}\\color{red}1}, \\alpha_{{\\color{green}{2}\\color{red}2}} = softmax(\\frac{\\color{red}{Q_1}\\color{black}\\cdot K_1}{\\sqrt{2}}, \\frac{\\color{red}{Q_1}\\color{black}\\cdot K_2}{\\sqrt{2}})\r\n", "\\\\\r\n", "\\color{green}{context\\ vector_2}\\color{black}= \\alpha_{\\color{green}{2}\\color{red}1}V_1 + \\alpha_{{\\color{green}{2}\\color{red}2}}V_2\r\n", "$$" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Below, it is fine for us to use $K_2$ and $V_2$ to calculate $\\alpha_{32}$." 
], "metadata": {} }, { "cell_type": "markdown", "source": [ "$$\r\n", "\\alpha_{\\color{green}{3}\\color{red}1}, \\alpha_{{\\color{green}{3}\\color{red}2}} = softmax(\\frac{\\color{red}{Q_2}\\color{black}\\cdot K_1}{\\sqrt{2}}, \\frac{\\color{red}{Q_2}\\color{black}\\cdot K_2}{\\sqrt{2}})\r\n", "\\\\\r\n", "\\color{green}{context\\ vector_3}\\color{black}= \\alpha_{\\color{green}{3}\\color{red}1}V_1 + \\alpha_{{\\color{green}{3}\\color{red}2}}V_2\r\n", "$$" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Any value in the attention score matrix that lies above the diagonal needs to be masked. Here we need to mask $\\alpha_{22}$ to 0.\r\n", "$$\r\n", "\\begin{array}{c|cc}\r\n", "& source\\\\\r\n", "target& \\color{red}{x_1} & \\color{red}{x_2} \\\\\r\n", "\\hline\r\n", " \\color{green}{h_2} & \\alpha_{\\color{green}{2}\\color{red}1} & \\alpha_{{\\color{green}{2}\\color{red}2}} \\\\\r\n", " \\color{green}{h_3} & \\alpha_{\\color{green}{3}\\color{red}1} & \\alpha_{{\\color{green}{3}\\color{red}2}}\r\n", "\\end{array}\r\n", "$$" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "#### Target Mask (Training)" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The masked score matrix should look like this:" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "$$\r\n", "\\begin{array}{c|cc}\r\n", "& source\\\\\r\n", "target& \\color{red}{x_1} & \\color{red}{x_2} \\\\\r\n", "\\hline\r\n", " \\color{green}{h_2} & \\alpha_{\\color{green}{2}\\color{red}1} & 0 \\\\\r\n", " \\color{green}{h_3} & \\alpha_{\\color{green}{3}\\color{red}1} & \\alpha_{{\\color{green}{3}\\color{red}2}}\r\n", "\\end{array}\r\n", "$$" ], "metadata": {} }, { "cell_type": "code", "execution_count": 16, "source": [ "def subsequent_mask(size):\r\n", "    attn_shape = (1, size, size)\r\n", "    subsequent_mask = (1 - torch.triu(torch.ones(attn_shape), diagonal=1)).bool()\r\n", "    return subsequent_mask" ], "outputs": [], "metadata": {} }, { "cell_type": "code", 
"execution_count": 17, "source": [ "subsequent_mask(2) # 1, L, L" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[[ True, False],\n", " [ True, True]]])" ] }, "metadata": {}, "execution_count": 17 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ ":::{important} Subsequent Mask\r\n", "We must use this mask while querying the decoder to prevent it\r\n", "from cheating. You can choose to use an additional mask to\r\n", "\"hide\" more data from the decoder if you wish, but the\r\n", "subsequent mask is a strong requirement of the self-attention\r\n", "decoder.\r\n", ":::" ], "metadata": {} }, { "cell_type": "code", "execution_count": 18, "source": [ "torch.manual_seed(13)\r\n", "decself = DecoderSelfAttn(n_heads=3, d_model=2, ff_units=10, n_features=2)\r\n", "decself.init_keys(encoder_states)\r\n", "\r\n", "query = shifted_seq\r\n", "out = decself(query, target_mask=subsequent_mask(2))\r\n", "\r\n", "decself.self_attn_heads.alphas" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[[[1.0000, 0.0000],\n", " [0.4011, 0.5989]]],\n", "\n", "\n", " [[[1.0000, 0.0000],\n", " [0.4264, 0.5736]]],\n", "\n", "\n", " [[[1.0000, 0.0000],\n", " [0.6304, 0.3696]]]])" ] }, "metadata": {}, "execution_count": 18 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "#### Target Mask (Evaluation/Prediction)" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The only difference between training and evaluation, concerning the target mask,\r\n", "is that we'll be using larger masks as we go. 
The very first mask is actually trivial\r\n", "since there are no elements above the diagonal:" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "$$\r\n", "1^{st}\\ Step\r\n", "\\begin{cases}\r\n", "\\begin{array}{c|cc}\r\n", "target& source\\\\\r\n", "& \\color{red}{x_1} & \\\\\r\n", "\\hline\r\n", " \\color{green}{h_2} & \\alpha_{\\color{green}{2}\\color{red}1}\r\n", "\\end{array}\r\n", "\\end{cases}\r\n", "$$" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "In evaluation/prediction time we only have the source sequence and, in our\r\n", "example, we use its last element as input for the decoder:" ], "metadata": {} }, { "cell_type": "code", "execution_count": 19, "source": [ "inputs = source_seq[:, -1:]\r\n", "trg_masks = subsequent_mask(1)\r\n", "out = decself(inputs, trg_masks)\r\n", "out" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[[0.4132, 0.3728]]], grad_fn=)" ] }, "metadata": {}, "execution_count": 19 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The mask is not actually masking anything in this case, and we get a prediction for\r\n", "the coordinates of $x_2$ as expected. In RNN based seq2seq models, this prediction would be used directly as the next input, but the self-attention decoder expects the full sequence as \"query\", so we concatenate the prediction to the previous \"query\"." ], "metadata": {} }, { "cell_type": "code", "execution_count": 20, "source": [ "inputs = torch.cat([inputs, out[:, -1:, :]], dim=-2)\r\n", "inputs" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[[-1.0000, 1.0000],\n", " [ 0.4132, 0.3728]]], grad_fn=)" ] }, "metadata": {}, "execution_count": 20 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Now there are two data points for querying the decoder, so we adjust the mask\r\n", "accordingly." 
], "metadata": {} }, { "cell_type": "markdown", "source": [ "$$\r\n", "2^{nd}\\ Step\r\n", "\\begin{cases}\r\n", "\\begin{array}{c|cc}\r\n", "target& source\\\\\r\n", "& \\color{red}{x_1} & \\color{green}{x_2} \\\\\r\n", "\\hline\r\n", " \\color{green}{h_2} & \\alpha_{\\color{green}{2}\\color{red}1} & 0 \\\\\r\n", " \\color{green}{h_3} & \\alpha_{\\color{green}{3}\\color{red}1} & \\alpha_{{\\color{green}{3}\\color{green}2}}\r\n", "\\end{array}\r\n", "\\end{cases}\r\n", "$$" ], "metadata": {} }, { "cell_type": "code", "execution_count": 21, "source": [ "trg_masks = subsequent_mask(2)\r\n", "out = decself(inputs, trg_masks)\r\n", "out" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[[0.4137, 0.3727],\n", " [0.4132, 0.3728]]], grad_fn=)" ] }, "metadata": {}, "execution_count": 21 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ ":::{important}\r\n", "The mask guarantees that the predicted $x_2$ (in the first step)\r\n", "won't change the predicted $x_2$ (in the second step) because\r\n", "predictions are made based on past data points only\r\n", ":::\r\n", "\r\n", "The last prediction is, once again, concatenated to the previous \"query\"." 
], "metadata": {} }, { "cell_type": "code", "execution_count": 22, "source": [ "inputs = torch.cat([inputs, out[:, -1:, :]], dim=-2)\r\n", "inputs" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[[-1.0000, 1.0000],\n", " [ 0.4132, 0.3728],\n", " [ 0.4132, 0.3728]]], grad_fn=)" ] }, "metadata": {}, "execution_count": 22 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "But, since we're actually done with the predictions (the desired target sequence has\r\n", "a length of two), we simply exclude the first data point in the query (the one coming\r\n", "from the source sequence), and what remains is the predicted target sequence:" ], "metadata": {} }, { "cell_type": "code", "execution_count": 23, "source": [ "inputs[:, 1:]" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[[0.4132, 0.3728],\n", " [0.4132, 0.3728]]], grad_fn=)" ] }, "metadata": {}, "execution_count": 23 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Encoder + Decoder + Self-Attention" ], "metadata": {} }, { "cell_type": "code", "execution_count": 24, "source": [ "class EncoderDecoderSelfAttn(nn.Module):\r\n", "    def __init__(self, encoder, decoder, input_len, target_len):\r\n", "        super().__init__()\r\n", "        self.encoder = encoder\r\n", "        self.decoder = decoder\r\n", "        self.input_len = input_len\r\n", "        self.target_len = target_len\r\n", "        self.trg_masks = self.subsequent_mask(self.target_len)\r\n", "    \r\n", "    @staticmethod\r\n", "    def subsequent_mask(size):\r\n", "        attn_shape = (1, size, size)\r\n", "        subsequent_mask = (1 - torch.triu(torch.ones(attn_shape), diagonal=1))\r\n", "        return subsequent_mask\r\n", "    \r\n", "    def encode(self, source_seq, source_mask):\r\n", "        # Encodes the source sequence and uses the result\r\n", "        # to initialize the decoder\r\n", "        encoder_states = self.encoder(source_seq, source_mask)\r\n", "        self.decoder.init_keys(encoder_states)\r\n", "    \r\n", "    def decode(self, shifted_target_seq, source_mask=None, target_mask=None):\r\n", "        # Decodes/generates a sequence using the shifted (masked)\r\n", "        # target sequence - used in TRAIN mode\r\n", "        outputs = self.decoder(shifted_target_seq, \r\n", "                               source_mask=source_mask,\r\n", "                               target_mask=target_mask)\r\n", "        return outputs\r\n", "    \r\n", "    def predict(self, source_seq, source_mask):\r\n", "        # Decodes/generates a sequence using one input\r\n", "        # at a time - used in EVAL mode\r\n", "        inputs = source_seq[:, -1:]\r\n", "        for i in range(self.target_len):\r\n", "            out = self.decode(inputs, source_mask, self.trg_masks[:, :i+1, :i+1])\r\n", "            out = torch.cat([inputs, out[:, -1:, :]], dim=-2)\r\n", "            inputs = out.detach()\r\n", "        outputs = inputs[:, 1:, :]\r\n", "        return outputs\r\n", "    \r\n", "    def forward(self, X, source_mask=None):\r\n", "        # Sends the mask to the same device as the inputs\r\n", "        self.trg_masks = self.trg_masks.type_as(X).bool()\r\n", "        # Slices the input to get source sequence\r\n", "        source_seq = X[:, :self.input_len, :]\r\n", "        # Encodes source sequence AND initializes decoder\r\n", "        self.encode(source_seq, source_mask)\r\n", "        if self.training:\r\n", "            # Slices the input to get the shifted target seq\r\n", "            shifted_target_seq = X[:, self.input_len-1:-1, :]\r\n", "            # Decodes using the mask to prevent cheating\r\n", "            outputs = self.decode(shifted_target_seq, source_mask, self.trg_masks)\r\n", "        else:\r\n", "            # Decodes using its own predictions\r\n", "            outputs = self.predict(source_seq, source_mask)\r\n", "        \r\n", "        return outputs" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Data Preparation" ], "metadata": {} }, { "cell_type": "code", "execution_count": 25, "source": [ "points, directions = generate_sequences()\r\n", "full_train = torch.as_tensor(points).float()\r\n", "target_train = full_train[:, 2:]" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 26, "source": [ "test_points, test_directions = 
generate_sequences(seed=19)\r\n", "full_test = torch.as_tensor(test_points).float()\r\n", "source_test = full_test[:, :2]\r\n", "target_test = full_test[:, 2:]" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 27, "source": [ "train_data = TensorDataset(full_train, target_train)\r\n", "test_data = TensorDataset(source_test, target_test)\r\n", "\r\n", "generator = torch.Generator()\r\n", "train_loader = DataLoader(train_data, batch_size=16, shuffle=True, generator=generator)\r\n", "test_loader = DataLoader(test_data, batch_size=16)" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Model Configuration & Training" ], "metadata": {} }, { "cell_type": "code", "execution_count": 28, "source": [ "torch.manual_seed(23)\r\n", "encself = EncoderSelfAttn(n_heads=3, d_model=2, ff_units=10, n_features=2)\r\n", "decself = DecoderSelfAttn(n_heads=3, d_model=2, ff_units=10, n_features=2)\r\n", "model = EncoderDecoderSelfAttn(encself, decself, input_len=2, target_len=2)\r\n", "loss = nn.MSELoss()\r\n", "optimizer = optim.Adam(model.parameters(), lr=0.01)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 29, "source": [ "sbs_seq_selfattn = StepByStep(model, loss, optimizer)\r\n", "sbs_seq_selfattn.set_loaders(train_loader, test_loader)\r\n", "sbs_seq_selfattn.train(100)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 30, "source": [ "fig = sbs_seq_selfattn.plot_losses()" ], "outputs": [ { "output_type": "display_data", "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {} } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Visualizing Predictions" ], "metadata": {} }, { "cell_type": "code", "execution_count": 31, "source": [ "fig = sequence_pred(sbs_seq_selfattn, full_test, test_directions)" ], "outputs": [ { "output_type": "display_data", "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {} } ], "metadata": { "scrolled": false } } ], "metadata": { "orig_nbformat": 4, "language_info": { "name": "python", "version": "3.8.10", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "kernelspec": { "name": "python3", "display_name": "Python 3.8.10 64-bit ('cits4012': conda)" }, "interpreter": { "hash": "d990147e05fc0cc60dd3871899a6233eb6a5324c1885ded43d013dc915f7e535" } }, "nbformat": 4, "nbformat_minor": 2 }