Source code for lmp.model._trans_enc

r"""Transformer language model."""

import argparse
import math
from typing import Any, ClassVar, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import lmp.util.metric
import lmp.util.validate
from lmp.model._base import BaseModel
from lmp.tknzr._base import PAD_TKID, BaseTknzr


[docs]class MultiHeadAttnLayer(nn.Module): r"""Multi-head attention :footcite:`vaswani2017attention` layer. - Let :math:`B` be input mini-batch size. - Let :math:`S_q` be the length of each query sequence. - Let :math:`S_k` be the length of each key sequence. - Let :math:`\dMdl` be the number of features per time step in each sequence. - Let :math:`q` be a batch of query sequence with shape :math:`(B, S_q, \dMdl)`. - Let :math:`k` be a batch of key sequence with shape :math:`(B, S_k, \dMdl)`. - Let :math:`v` be a batch of value sequence with shape :math:`(B, S_k, \dMdl)`. - Let :math:`\msk` be a batch of attention mask with shape :math:`(B, S_q, S_k)`. - Let :math:`\nHd` be the number of attention heads. - Let :math:`d_k` be the number of key features in each attention head. - Let :math:`d_v` be the number of value features in each attention head. Multi-head attention layer is defined as follow: .. math:: \begin{align*} & \algoProc{\MultiHeadAttnLayer}(k, \msk, q, v) \\ & \indent{1} S_q \algoEq q.\sz{1} \\ & \indent{1} S_k \algoEq k.\sz{1} \\ & \indent{1} \algoFor{h \in \set{1, \dots, \nHd}} \\ & \indent{2} \algoCmt{Get query vector for each head.} \\ & \indent{2} \algoFor{t \in \set{1, \dots, S_q}} \\ & \indent{3} q_t^h \algoEq W_{Q, h} \cdot q_t \\ & \indent{2} \algoEndFor \\ & \indent{2} Q^h \algoEq \cat{q_1^h, \dots, q_{S_q}^h} \\ & \indent{2} \algoCmt{Get key-value vectors for each head.} \\ & \indent{2} \algoFor{t \in \set{1, \dots, S_k}} \\ & \indent{3} k_t^h \algoEq W_{K, h} \cdot k_t \\ & \indent{3} v_t^h \algoEq W_{V, h} \cdot v_t \\ & \indent{2} \algoEndFor \\ & \indent{2} K^h \algoEq \cat{k_1^h, \dots, k_{S_k}^h} \\ & \indent{2} V^h \algoEq \cat{v_1^h, \dots, v_{S_k}^h} \\ & \indent{2} \algoCmt{Apply attention mask on similarity scores.} \\ & \indent{2} \Sim^h \algoEq \dfrac{Q^h \cdot \pa{K^h}^\top}{\sqrt{d_k}} \\ & \indent{2} \algoFor{i \in \set{1, \dots, S_q}} \\ & \indent{3} \algoFor{j \in \set{1, \dots, S_k}} \\ & \indent{4} \algoIf{\msk_{i,j} \algoIs \algoTrue} \\ & \indent{5} \Sim_{i,j}^h \algoEq -10^9 \\ & \indent{4} \algoEndIf \\ & \indent{3} \algoEndFor \\ & \indent{3} \attn_i^h \algoEq \sof{\Sim_{i,1}^h, \dots, \Sim_{i,S_k}^h} \\ & \indent{2} \algoEndFor \\ & \indent{2} \algoCmt{Get attention scores.} \\ & \indent{2} \attn^h \algoEq \cat{\attn_1^h, \dots, \attn_{S_q}^h} \\ & \indent{2} F^h \algoEq \attn^h \cdot V^h \\ & \indent{1} \algoEndFor \\ & \indent{1} F \algoEq \fla{F^1, \dots, F^{\nHd}} \\ & \indent{1} O \algoEq W_O \cdot F \\ & \indent{1} \algoReturn O \\ & \algoEndProc \end{align*} +-----------------------------------------------------+----------------------------------------------------------+ | Trainable Parameters | Nodes | +------------------+----------------------------------+----------------------+-----------------------------------+ | Parameter | Shape | Symbol | Shape | +==================+==================================+======================+===================================+ | :math:`W_{K,h}` | :math:`(d_k, \dMdl)` | :math:`F` | :math:`(B, S_q, \nHd \times d_v)` | +------------------+----------------------------------+----------------------+-----------------------------------+ | :math:`W_O` | :math:`(\dMdl, \nHd \times d_v)` | :math:`F^h` | :math:`(B, S_q, d_v)` | +------------------+----------------------------------+----------------------+-----------------------------------+ | :math:`W_{Q,h}` | :math:`(d_k, \dMdl)` | :math:`K^h` | :math:`(B, S_k, d_k)` | +------------------+----------------------------------+----------------------+-----------------------------------+ | :math:`W_{V,h}` | :math:`(d_v, \dMdl)` | :math:`O` | :math:`(B, S_q, \dMdl)` | +------------------+----------------------------------+----------------------+-----------------------------------+ | | :math:`Q^h` | :math:`(B, S_q, d_k)` | | +----------------------+-----------------------------------+ | | :math:`V^h` | :math:`(B, S_k, d_v)` | | +----------------------+-----------------------------------+ | | :math:`\attn^h` | :math:`(B, S_q, S_k)` | | +----------------------+-----------------------------------+ | | :math:`\attn_i^h` | :math:`(B, S_k)` | | +----------------------+-----------------------------------+ | | :math:`k` | :math:`(B, S_k, \dMdl)` | | +----------------------+-----------------------------------+ | | :math:`k_t` | :math:`(B, \dMdl)` | | +----------------------+-----------------------------------+ | | :math:`k_t^h` | :math:`(B, d_k)` | | +----------------------+-----------------------------------+ | | :math:`\msk` | :math:`(B, S_q, S_k)` | | +----------------------+-----------------------------------+ | | :math:`\msk_{i,j}` | :math:`(B)` | | +----------------------+-----------------------------------+ | | :math:`q` | :math:`(B, S_q, \dMdl)` | | +----------------------+-----------------------------------+ | | :math:`q_t` | :math:`(B, \dMdl)` | | +----------------------+-----------------------------------+ | | :math:`q_t^h` | :math:`(B, d_k)` | | +----------------------+-----------------------------------+ | | :math:`\Sim^h` | :math:`(B, S_q, S_k)` | | +----------------------+-----------------------------------+ | | :math:`\Sim_{i,j}^h` | :math:`(B)` | | +----------------------+-----------------------------------+ | | :math:`v` | :math:`(B, S_k, \dMdl)` | | +----------------------+-----------------------------------+ | | :math:`v_t` | :math:`(B, \dMdl)` | | +----------------------+-----------------------------------+ | | :math:`v_t^h` | :math:`(B, d_v)` | +-----------------------------------------------------+----------------------+-----------------------------------+ Model parameters in Multi-head attention layer are initialized with uniform distribution :math:`\mathcal{U}(\init_l, \init_u)`. The lower bound :math:`\init_l` and upper bound :math:`\init_u` are given as hyperparameters. Parameters ---------- d_k: int, default: 1 Number of key features :math:`d_k` in each head. d_model: int, default: 1 Number of input / output features :math:`\dMdl`. d_v: int, default: 1 Number of value features :math:`d_v` in each head. init_lower: float, default: -0.1 Uniform distribution lower bound :math:`\init_l` used to initialize model parameters. init_upper: float, default: 0.1 Uniform distribution upper bound :math:`\init_u` used to initialize model parameters. kwargs: typing.Any, optional Useless parameter. Intently left for subclasses inheritance. n_head: int, default: 1 Number of attention heads :math:`\nHd`. Attributes ---------- d_k: int Number of key features :math:`d_k` in each head. d_model: int Number of input / output features :math:`\dMdl`. d_v: int Number of value features :math:`d_v` in each head. fc_ff_f2o: torch.nn.Linear Fully connected feed-forward layer :math:`W_O` which transform features to output. No biases are used. Input shape: :math:`(B, S_q, \nHd \times d_v)`. Output shape: :math:`(B, S_q, \dMdl)`. fc_ff_k2hk: torch.nn.Linear Fully connected feed-forward layer :math:`\pa{W_{K,1}, \dots, W_{K,\nHd}}` which transform key vectors to heads. No biases are used. Input shape: :math:`(B, S_k, \dMdl)`. Output shape: :math:`(B, S_k, \nHd \times d_k)`. fc_ff_q2hq: torch.nn.Linear Fully connected feed-forward layer :math:`\pa{W_{Q,1}, \dots, W_{Q,\nHd}}` which transform query vectors to heads. No biases are used. Input shape: :math:`(B, S_q, \dMdl)`. Output shape: :math:`(B, S_q, \nHd \times d_k)`. fc_ff_v2hv: torch.nn.Linear Fully connected feed-forward layer :math:`\pa{W_{V,1}, \dots, W_{V,\nHd}}` which transform value vectors to heads. No biases are used. Input shape: :math:`(B, S_k, \dMdl)`. Output shape: :math:`(B, S_k, \nHd \times d_v)`. init_lower: float Uniform distribution lower bound :math:`\init_l` used to initialize model parameters. init_upper: float Uniform distribution upper bound :math:`\init_u` used to initialize model parameters. n_head: int Number of attention heads :math:`\nHd`. scaler: float Dot product scaler :math:`\dfrac{1}{\sqrt{d_k}}`. """ def __init__( self, *, d_k: int = 1, d_model: int = 1, d_v: int = 1, init_lower: float = -0.1, init_upper: float = 0.1, n_head: int = 1, **kwargs: Any, ): super().__init__() # `d_k` validation. lmp.util.validate.raise_if_not_instance(val=d_k, val_name='d_k', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_k], val_names=['1', 'd_k']) self.d_k = d_k self.scaler = 1 / math.sqrt(d_k) # `d_model` validation. lmp.util.validate.raise_if_not_instance(val=d_model, val_name='d_model', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_model], val_names=['1', 'd_model']) self.d_model = d_model # `d_v` validation. lmp.util.validate.raise_if_not_instance(val=d_v, val_name='d_v', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_v], val_names=['1', 'd_v']) self.d_v = d_v # `init_lower` and `init_upper` validation. lmp.util.validate.raise_if_not_instance(val=init_lower, val_name='init_lower', val_type=float) lmp.util.validate.raise_if_not_instance(val=init_upper, val_name='init_upper', val_type=float) lmp.util.validate.raise_if_wrong_ordered(vals=[init_lower, init_upper], val_names=['init_lower', 'init_upper']) self.init_upper = init_upper self.init_lower = init_lower # `n_head` validation. lmp.util.validate.raise_if_not_instance(val=n_head, val_name='n_head', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, n_head], val_names=['1', 'n_head']) self.n_head = n_head # Fully connected Fully connected feed-forward layers which transform query, key and value vectors to heads. # No biases are used. self.fc_ff_q2hq = nn.Linear(in_features=d_model, out_features=n_head * d_k, bias=False) self.fc_ff_k2hk = nn.Linear(in_features=d_model, out_features=n_head * d_k, bias=False) self.fc_ff_v2hv = nn.Linear(in_features=d_model, out_features=n_head * d_v, bias=False) # Fully connected feed-forward layer which transform features to output. # No biases are used. self.fc_ff_f2o = nn.Linear(in_features=n_head * d_v, out_features=d_model, bias=False)
[docs] def forward( self, k: torch.Tensor, mask: torch.Tensor, q: torch.Tensor, v: torch.Tensor, ) -> torch.Tensor: r"""Perform multi-head attention on query, key, value. Below we describe the forward pass algorithm of multi-head attention layer. #. Let ``q`` be a batch of sequences of query vectors :math:`q`. #. Let ``k`` be a batch of sequences of key vectors :math:`k`. #. Let ``v`` be a batch of sequences of value vectors :math:`v`. #. Let ``mask`` be a batch of attention mask :math:`\msk`. #. Let ``q.size(1)`` be sequence length :math:`S_q`. #. Let ``k.size(1)`` be sequence length :math:`S_k`. #. Use ``self.fc_ff_q2hq`` to transform query vectors into multi-head query vectors :math:`Q^1, \dots, Q^{\nHd}`. #. Use ``self.fc_ff_k2hk`` to transform query vectors into multi-head query vectors :math:`K^1, \dots, K^{\nHd}`. #. Use ``self.fc_ff_v2hv`` to transform query vectors into multi-head query vectors :math:`V^1, \dots, V^{\nHd}`. #. Use :math:`Q^1, \dots, Q^{\nHd}` and :math:`K^1, \dots, K^{\nHd}` to calculate similarity scores :math:`\Sim^1, \dots, \Sim^{\nHd}`. #. Use ``mask`` to mask similarity scores :math:`\Sim^1, \dots, \Sim^{\nHd}`. #. Use softmax to transform similarity scores :math:`\Sim^1, \dots, \Sim^{\nHd}` into attention scores :math:`\attn^1, \dots, \attn^{\nHd}`. #. Use attention scores :math:`\attn^1, \dots, \attn^{\nHd}` and :math:`V^1, \dots, V^{\nHd}` to calculate hidden features :math:`F^1, \dots, F^{\nHd}`. #. Use :math:`W_O` and hidden features :math:`F^1, \dots, F^{\nHd}` to calculate output :math:`O`. #. Return :math:`O`. Parameters ---------- k: torch.Tensor Batch of sequences of key vectors with shape :math:`(B, S_k, \dMdl)` and ``dtype == torch.float``. mask: torch.Tensor Batch of attention mask with shape :math:`(B, S_q, S_k)` and ``dtype == torch.bool``. Set to true to mask attention at corresponding position. q: torch.Tensor Batch of sequences of query vectors with shape :math:`(B, S_q, \dMdl)` and ``dtype == torch.float``. v: torch.Tensor Batch of sequences of key vectors with shape :math:`(B, S_k, \dMdl)` and ``dtype == torch.float``. Returns ------- torch.Tensor Batch output features :math:`O` with shape :math:`(B, S_q, \dMdl)` and ``dtype == torch.float``. """ B = q.size(0) S_q = q.size(1) S_k = k.size(1) # Shape: (B, n_head, S_q, d_k). head_q = self.fc_ff_q2hq(q).reshape(B, S_q, self.n_head, self.d_k).transpose(1, 2) # Shape: (B, n_head, d_k, S_k). head_k_T = self.fc_ff_k2hk(k).reshape(B, S_k, self.d_k, self.n_head).transpose(1, 3) # Shape: (B, n_head, S_k, d_v). head_v = self.fc_ff_v2hv(v).reshape(B, S_k, self.n_head, self.d_v).transpose(1, 2) # Shape: (B, n_head, S_q, S_k). sim = self.scaler * (head_q @ head_k_T) # Shape: (B, n_head, S_q, S_k). sim.masked_fill_(mask.unsqueeze(1), -1e9) # Shape: (B, n_head, S_q, S_k). attn = F.softmax(sim, dim=3) # Shape: (B, n_head, S_q, d_v). weighted_feat = attn @ head_v # Shape: (B, S_q, d_model). return self.fc_ff_f2o(weighted_feat.transpose(1, 2).reshape(B, S_q, self.n_head * self.d_v))
[docs] def params_init(self) -> None: r"""Initialize model parameters. All weights are initialized with uniform distribution :math:`\mathcal{U}\pa{\init_l, \init_u}`. Returns ------- None """ nn.init.uniform_(self.fc_ff_q2hq.weight, self.init_lower, self.init_upper) nn.init.uniform_(self.fc_ff_k2hk.weight, self.init_lower, self.init_upper) nn.init.uniform_(self.fc_ff_v2hv.weight, self.init_lower, self.init_upper) nn.init.uniform_(self.fc_ff_f2o.weight, self.init_lower, self.init_upper)
[docs]class PosEncLayer(nn.Module): r"""Positional Encoding :footcite:`vaswani2017attention`. - Let :math:`S` be the lookup sequence length. - Let :math:`\dEmb` be the dimension of positional encodings. Positional encodings is defined as follow: .. math:: \begin{align*} & \algoProc{\PosEncLayer}\pa{S} \\ & \indent{1} \algoFor{\pos \in \set{1, \dots, S}} \\ & \indent{2} \algoFor{i \in \set{1, \dots, \dEmb}} \\ & \indent{3} \algoIf{i \text{ is even}} \\ & \indent{4} \PE_{(\pos,i)} \algoEq \sin\pa{\dfrac{\pos}{10000^{i / \dEmb}}} \\ & \indent{3} \algoElse \\ & \indent{4} \PE_{(\pos,i)} \algoEq \cos\pa{\dfrac{\pos}{10000^{i / \dEmb}}} \\ & \indent{3} \algoEndIf \\ & \indent{2} \algoEndFor \\ & \indent{1} \algoEndFor \\ & \indent{1} \algoReturn \PE \\ & \algoEndProc \end{align*} +----------------------+------------------------------------------------+ | Trainable Parameters | Nodes | +-------------+--------+------------------------+-----------------------+ | Parameter | Shape | Symbol | Shape | +=============+========+========================+=======================+ | | :math:`\PE_{(\pos,i)}` | :math:`(1)` | | +------------------------+-----------------------+ | | :math:`\PE` | :math:`(1, S, \dEmb)` | +----------------------+------------------------+-----------------------+ Parameters ---------- d_emb: int, default: 1 Positional encoding dimension :math:`\dEmb`. kwargs: typing.Any, optional Useless parameter. Intently left for subclasses inheritance. max_seq_len: int, default: 512 Maximum length constraint on the input sequence. Attributes ---------- d_emb: int Positional encoding dimension :math:`\dEmb`. max_seq_len: int Maximum length constraint on the input sequence. pe: torch.Tensor Positional encoding lookup table. """ def __init__( self, *, d_emb: int = 1, max_seq_len: int = 512, **kwargs: Any, ): super().__init__() # `d_emb` validation. lmp.util.validate.raise_if_not_instance(val=d_emb, val_name='d_emb', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_emb], val_names=['1', 'd_emb']) self.d_emb = d_emb # `max_seq_len` validation. lmp.util.validate.raise_if_not_instance(val=max_seq_len, val_name='max_seq_len', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, max_seq_len], val_names=['1', 'max_seq_len']) self.max_seq_len = max_seq_len # Create positional encoding lookup table. # Shape: `(S, d_emb)`. pe = torch.zeros((max_seq_len, d_emb)) # Position order from `0` to `S - 1`. # Shape: `(S, 1)`. pos = torch.arange(0, max_seq_len).unsqueeze(1) # Compute the positional encodings. # Shape: `(S, d_emb)`. div_term = torch.exp(torch.arange(0, d_emb, 2) * (-math.log(10000) / d_emb)) pe[:, 0::2] = torch.sin(pos * div_term) pe[:, 1::2] = torch.cos(pos * div_term) # First dimension is set to `1` to so that ``self.pe`` can broadcast along batch dimension. pe = pe.unsqueeze(0) self.register_buffer(name='pe', tensor=pe, persistent=True)
[docs] def forward(self, seq_len: int) -> torch.Tensor: r"""Lookup positional encodings. Lookup is starting from position ``0`` and end at position ``seq_len - 1`` (exclusive). Parameters ---------- seq_len: int Sequence length :math:`S`. Returns ------- torch.Tensor Positional encodings with shape :math:`(1, S, \dEmb)` and ``dtype == torch.float``. """ # `seq_len` validation. lmp.util.validate.raise_if_wrong_ordered( vals=[seq_len, self.max_seq_len], val_names=['seq_len', 'self.max_seq_len'], ) # Lookup positional encodings. # Shape: `(1, S, d_emb)`. return self.pe[:, :seq_len, :]
[docs] def params_init(self) -> None: r"""Do nothing. Returns ------- None """ pass
[docs]class TransEnc(BaseModel): r"""Transformer encoder :footcite:`vaswani2017attention` language model. - Let :math:`x` be batch of token ids with batch size :math:`B` and per sequence length :math:`S`. - Let :math:`c` be previous batch of token ids (previous context window) with shape :math:`(B, S')`. Note that :math:`c` can be empty. - Let :math:`S_\max` be the maximum sequence length a model can deal with. - When :math:`c` is empty, the constraint :math:`S \leq S_\max` must be satisfied. - When :math:`c` is not empty, the constraint :math:`S + S' \leq S_\max` must be satisfied. - Let :math:`V` be the vocabulary size of the paired tokenizer. Each token id represents an unique token, i.e., :math:`x_t \in \set{1, \dots, V}`. - Let :math:`E` be the token embedding lookup table. - Let :math:`\dMdl` be the dimension of token embeddings. - Let :math:`e_t` be the token embedding correspond to token id :math:`x_t`. - Token embeddings have dropout probability :math:`p`. - Let :math:`\PE` be positional encoding layer. - Let :math:`\PE_t` be the positional encoding at the :math:`t` th position. - The dimension of positional encodings is :math:`\dMdl`. - Let :math:`\nLyr` be the number of transformer encoder layers. - Let :math:`h^\ell` be the output of the :math:`\ell` th transformer encoder layer. Transformer encoder language model is defined as follow: .. math:: \begin{align*} & \algoProc{\TransEnc}\pa{x, c} \\ & \indent{1} \algoIf{c \text{ is not empty}} \\ & \indent{2} x \algoEq \cat{x, c} \\ & \indent{1} \algoEndIf \\ & \indent{1} S \algoEq x.\sz{1} \\ & \indent{1} \algoCmt{Create attention mask.} \\ & \indent{1} \algoFor{i \in \set{1, \dots, S}} \\ & \indent{2} \algoFor{j \in \set{1, \dots, S}} \\ & \indent{3} \algoIf{x_i \algoIs \text{padding}} \\ & \indent{4} \msk_{i,j} \algoEq \algoTrue \\ & \indent{3} \algoElseIf{i \leq j} \\ & \indent{4} \msk_{i,j} \algoEq \algoFalse \\ & \indent{3} \algoElse \\ & \indent{4} \msk_{i,j} \algoEq \algoTrue \\ & \indent{3} \algoEndIf \\ & \indent{2} \algoEndFor \\ & \indent{1} \algoEndFor \\ & \indent{1} \algoCmt{Lookup token embedding and positional encoding.} \\ & \indent{1} \algoFor{t \in \set{1, \dots, S}} \\ & \indent{2} e_t \algoEq (x_t)\text{-th row of } E \text{ but treated as column vector} \\ & \indent{2} h_t^0 \algoEq \drop{e_t + \PE_t}{p} \\ & \indent{1} \algoEndFor \\ & \indent{1} h^0 \algoEq \cat{h_1^0, \dots, h_S^0} \\ & \indent{1} \algoCmt{Perform forward pass on stacking Transformer encoder layers} \\ & \indent{1} \algoFor{\ell \in \set{1, \dots, \nLyr}} \\ & \indent{2} h^\ell \algoEq \TransEncLayer\pa{ k \algoEq h^{\ell-1}, \msk \algoEq \msk, q \algoEq h^{\ell-1}, v \algoEq h^{\ell-1} } \\ & \indent{1} \algoEndFor \\ & \indent{1} \algoFor{t \in \set{1, \dots, S}} \\ & \indent{2} y_t \algoEq \sof{E \cdot h_t^{\nLyr}} \\ & \indent{1} \algoEndFor \\ & \indent{1} y \algoEq \cat{y_1, \dots, y_S} \\ & \indent{1} c' \algoEq \cat{x_{\max\pa{1, S - (S_\max-2)}}, \dots, x_S} \\ & \indent{1} \algoReturn \pa{y, c'} \\ & \algoEndProc \end{align*} +-------------------------------------------+-------------------------------------------------------+ | Trainable Parameters | Nodes | +------------------+------------------------+--------------------------+----------------------------+ | Parameter | Shape | Symbol | Shape | +==================+========================+==========================+============================+ | :math:`E` | :math:`(V, \dMdl)` | :math:`\PE` | :math:`(B, S_\max, \dMdl)` | +------------------+------------------------+--------------------------+----------------------------+ | :math:`\TransEncLayer` | :math:`\PE_t` | :math:`(B, \dMdl)` | +------------------+------------------------+--------------------------+----------------------------+ | | :math:`c` | :math:`(B, S')` | | +--------------------------+----------------------------+ | | :math:`c'` | :math:`(B, S_\max-1)` | | +--------------------------+----------------------------+ | | :math:`e_t` | :math:`(B, S, \dMdl)` | | +--------------------------+----------------------------+ | | :math:`h^\ell` | :math:`(B, S, \dMdl)` | | +--------------------------+----------------------------+ | | :math:`h_t^0` | :math:`(B, \dMdl)` | | +--------------------------+----------------------------+ | | :math:`\msk` | :math:`(B, S, S)` | | +--------------------------+----------------------------+ | | :math:`\msk_{i,j}` | :math:`(B)` | | +--------------------------+----------------------------+ | | :math:`x` | :math:`(B, S)` | | +--------------------------+----------------------------+ | | :math:`x_t` | :math:`(B)` | | +--------------------------+----------------------------+ | | :math:`y` | :math:`(B, S, V)` | | +--------------------------+----------------------------+ | | :math:`y_t` | :math:`(B, V)` | +-------------------------------------------+--------------------------+----------------------------+ The goal of optimization is to minimize the negative logliklihood of next token id :math:`x_{t+1}` given :math:`y_t`. The prediction loss is defined to be the average negative logliklihood over :math:`x` given :math:`y`. .. math:: \loss = \dfrac{-1}{S} \sum_{t = 1}^S \log \Pr(x_{t+1} \vert y_t). - :math:`y_t` is the next token id prediction probability distribution over tokenizer's vocabulary. We use inner product to calculate similarity scores over all token ids, and then use softmax to normalize similarity scores into probability range :math:`[0, 1]`. - Model parameters in Transformer encoder language model are initialized with uniform distribution :math:`\mathcal{U}(\init_l, \init_u)`. The lower bound :math:`\init_l` and upper bound :math:`\init_u` of uniform distribution are given as hyperparameters. Parameters ---------- d_ff: int, default: 1 Number of hidden units :math:`\dFf` in the 2-layer fully connected feed-forward network. d_k: int, default: 1 Number of key features :math:`d_k` in each head. d_model: int, default: 1 Number of input / output features :math:`\dMdl`. d_v: int, default: 1 Number of value features :math:`d_v` in each head. init_lower: float, default: -0.1 Uniform distribution lower bound :math:`\init_l` used to initialize model parameters. init_upper: float, default: 0.1 Uniform distribution upper bound :math:`\init_u` used to initialize model parameters. kwargs: typing.Any, optional Useless parameter. Intently left for subclasses inheritance. label_smoothing: float, default: 0.0 Smoothing applied on prediction target :math:`x_{t+1}`. max_seq_len: int, default: 512 Maximum length of the input sequence. n_lyr: int, default: 1 Number of Transformer encoder layers :math:`\nLyr`. n_head: int, default: 1 Number of attention heads :math:`\nHd`. p: float, default: 0.0 Dropout probability :math:`p`. tknzr: ~lmp.tknzr.BaseTknzr Tokenizer instance. Attributes ---------- d_ff: int Number of hidden units :math:`\dFf` in the 2-layer fully connected feed-forward network. d_k: int Number of key features :math:`d_k` in each head. d_model: int Number of input / output features :math:`\dMdl`. d_v: int Number of value features :math:`d_v` in each head. emb: torch.nn.Embedding Token embedding lookup matrix. Use token ids to lookup token embeddings. init_lower: float Uniform distribution lower bound :math:`\init_l` used to initialize model parameters. init_upper: float Uniform distribution upper bound :math:`\init_u` used to initialize model parameters. input_dp: torch.nn.Dropout Dropout with probability :math:`p` applied on the sum of token embeddings and position encodings. label_smoothing: float Smoothing applied on prediction target :math:`x_{t+1}`. loss_fn: torch.nn.CrossEntropyLoss Loss function to be optimized. model_name: ClassVar[str] CLI name of Transformer encoder is ``Transformer-encoder``. p: float Dropout probability :math:`p`. pos_enc: lmp.model.PosEncLayer Positional Encoding. stack_trans_enc: torch.nn.ModuleList :py:class:`~TransEncLayer` stacking layers. The number of stacking layers is equal to :math:`\nLyr`. Input shape: :math:`(B, S, \dMdl)`. Output shape: :math:`(B, S, \dMdl)`. """ model_name: ClassVar[str] = 'Transformer-encoder' def __init__( self, *, d_ff: int = 1, d_k: int = 1, d_model: int = 1, d_v: int = 1, init_lower: float = -0.1, init_upper: float = 0.1, label_smoothing: float = 0.0, max_seq_len: int = 512, n_head: int = 1, n_lyr: int = 1, p: float = 0.0, tknzr: BaseTknzr, **kwargs: Any, ): super().__init__(**kwargs) # `d_ff` validation. lmp.util.validate.raise_if_not_instance(val=d_ff, val_name='d_ff', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_ff], val_names=['1', 'd_ff']) self.d_ff = d_ff # `d_k` validation. lmp.util.validate.raise_if_not_instance(val=d_k, val_name='d_k', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_k], val_names=['1', 'd_k']) self.d_k = d_k # `d_model` validation. lmp.util.validate.raise_if_not_instance(val=d_model, val_name='d_model', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_model], val_names=['1', 'd_model']) self.d_model = d_model # `d_v` validation. lmp.util.validate.raise_if_not_instance(val=d_v, val_name='d_v', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_v], val_names=['1', 'd_v']) self.d_v = d_v # `init_upper` and `init_lower` validation. lmp.util.validate.raise_if_not_instance(val=init_upper, val_name='init_upper', val_type=float) lmp.util.validate.raise_if_not_instance(val=init_lower, val_name='init_lower', val_type=float) lmp.util.validate.raise_if_wrong_ordered(vals=[init_lower, init_upper], val_names=['init_lower', 'init_upper']) self.init_upper = init_upper self.init_lower = init_lower # `label_smoothing` validation. lmp.util.validate.raise_if_not_instance(val=label_smoothing, val_name='label_smoothing', val_type=float) lmp.util.validate.raise_if_wrong_ordered( vals=[0.0, label_smoothing, 1.0], val_names=['0.0', 'label_smoothing', '1.0'], ) self.label_smoothing = label_smoothing # `max_seq_len` validation. lmp.util.validate.raise_if_not_instance(val=max_seq_len, val_name='max_seq_len', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, max_seq_len], val_names=['1', 'max_seq_len']) self.max_seq_len = max_seq_len # `n_head` validation. lmp.util.validate.raise_if_not_instance(val=n_head, val_name='n_head', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, n_head], val_names=['1', 'n_head']) self.n_head = n_head # `n_lyr` validation. lmp.util.validate.raise_if_not_instance(val=n_lyr, val_name='n_lyr', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, n_lyr], val_names=['1', 'n_lyr']) self.n_lyr = n_lyr # `p` validation. lmp.util.validate.raise_if_not_instance(val=p, val_name='p', val_type=float) lmp.util.validate.raise_if_wrong_ordered(vals=[0.0, p, 1.0], val_names=['0.0', 'p', '1.0']) self.p = p # Token embedding layer. # Use token ids to perform token embeddings lookup. self.emb = nn.Embedding(num_embeddings=tknzr.vocab_size, embedding_dim=d_model, padding_idx=PAD_TKID) # Positional encoding layer. # Use token ids to perform positional encoding lookup. self.pos_enc = PosEncLayer(d_emb=d_model, max_seq_len=max_seq_len, **kwargs) # Token embedding and positional encoding dropout layer. self.input_dp = nn.Dropout(p=p) # Stacking transformer encoder layers. self.stack_trans_enc = nn.ModuleList([]) for _ in range(n_lyr): self.stack_trans_enc.append( TransEncLayer( d_ff=d_ff, d_k=d_k, d_model=d_model, d_v=d_v, init_lower=init_lower, init_upper=init_upper, n_head=n_head, p=p, **kwargs, ) ) # Loss function used to optimize language model. self.loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TKID, label_smoothing=label_smoothing)
[docs] @classmethod def add_CLI_args(cls, parser: argparse.ArgumentParser) -> None: r"""Add transformer encoder language model hyperparameters to CLI arguments parser. Parameters ---------- parser: argparse.ArgumentParser CLI argument parser. Returns ------- None See Also -------- :doc:`lmp.script.train_model </script/train_model>` Language model training script. Examples -------- >>> import argparse >>> import math >>> from lmp.model import TransEnc >>> parser = argparse.ArgumentParser() >>> TransEnc.add_CLI_args(parser) >>> args = parser.parse_args([ ... '--d_ff', '2', ... '--d_k', '4', ... '--d_model', '6', ... '--d_v', '8', ... '--init_lower', '-0.01', ... '--init_upper', '0.01', ... '--label_smoothing', '0.1', ... '--n_head', '10', ... '--n_lyr', '2', ... '--p', '0.1', ... ]) >>> assert args.d_ff == 2 >>> assert args.d_k == 4 >>> assert args.d_model == 6 >>> assert args.d_v == 8 >>> assert math.isclose(args.init_lower, -0.01) >>> assert math.isclose(args.init_upper, 0.01) >>> assert math.isclose(args.label_smoothing, 0.1) >>> assert args.n_head == 10 >>> assert args.n_lyr == 2 >>> assert math.isclose(args.p, 0.1) """ # `parser` validation. lmp.util.validate.raise_if_not_instance(val=parser, val_name='parser', val_type=argparse.ArgumentParser) # Add hyperparameters to CLI arguments. group = parser.add_argument_group('Transformer encoder hyperparameters') group.add_argument( '--d_ff', default=1, help=''' Number of hidden units in the 2-layer fully connected feed-forward network. Default is ``1``. ''', type=int, ) group.add_argument( '--d_k', default=1, help=''' Number of key features in each head. Default is ``1``. ''', type=int, ) group.add_argument( '--d_model', default=1, help=''' Number of input / output features. Default is ``1``. ''', type=int, ) group.add_argument( '--d_v', default=1, help=''' Number of value features in each head. Default is ``1``. ''', type=int, ) group.add_argument( '--init_lower', default=-0.1, help=''' Uniform distribution lower bound used to initialize model parameters. Default is ``-0.1``. ''', type=float, ) group.add_argument( '--init_upper', default=0.1, help=''' Uniform distribution lower bound used to initialize model parameters. Default is ``0.1``. ''', type=float, ) group.add_argument( '--label_smoothing', default=0.0, help=''' Label smoothing applied on cross entropy loss. Default is ``0.0``. ''', type=float, ) group.add_argument( '--n_head', default=1, help=''' Number of attention heads. Default is ``1``. ''', type=int, ) group.add_argument( '--n_lyr', default=1, help=''' Number of transformer encoder layers. Default is ``1``. ''', type=int, ) group.add_argument( '--p', default=0.0, help=''' Dropout probability for all layers. Default is ``0.0``. ''', type=float, )
[docs] def cal_loss( self, batch_cur_tkids: torch.Tensor, batch_next_tkids: torch.Tensor, batch_prev_states: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Calculate language model prediction loss. We use cross entropy loss as our training objective. This method is only used for training. Parameters ---------- batch_cur_tkids: torch.Tensor Batch current input token ids. ``batch_cur_tkids`` has shape :math:`(B, S)` and ``dtype == torch.long``. batch_next_tkids: torch.Tensor Prediction target of each sample in the batch. ``batch_next_tkids`` has shape :math:`(B, S)` and ``dtype == torch.long``. batch_prev_states: typing.Optional[torch.Tensor], default: None Batch of previous token ids :math:`c`. The tensor represent the batch of token ids used in the previous context. It has shape :math:`(B, S')` and ``dtype == torch.long``. If given, it will be concatenated with ``batch_cur_tkids``. Set to ``None`` to do nothing. Returns ------- tuple[torch.Tensor, torch.Tensor] The first tensor in the tuple is the mini-batch cross-entropy loss. Loss tensor has shape :math:`(1)` and ``dtype == torch.float``. The second tensor in the tuple is a batch of the token ids used in forward pass (we denoted it as :math:`c'` in our definition). The second tensor has shape :math:`(B, \min(S, S_\max-1))` and ``dtype == torch.long``. """ # Get next token id logits and last hidden states. # Logits shape: (B, S, V). # Each tensor in `batch_cur_states` has shape: (B, d_model). logits, batch_cur_states = self(batch_cur_tkids=batch_cur_tkids, batch_prev_states=batch_prev_states) # Calculate cross-entropy loss. # We reshape `logits` to (B x S, V) and `batch_next_tkids` to (B x S). # This is needed since this is how PyTorch design its API. # shape: (1). loss = self.loss_fn(logits.reshape(-1, self.emb.num_embeddings), batch_next_tkids.reshape(-1)) # Return loss and last hidden states. return (loss, batch_cur_states)
[docs] def create_mask(self, batch_tkids: torch.Tensor) -> torch.Tensor: r"""Create self-attention mask for ``batch_tkids``. Self-attention mask is created as follow: #. Create auto-regressive mask by masking everything above diagnoal. This is needed since input token at each time step can only see input tokens at previous time steps and itself. #. Create padding masks by masking every positions correspond to padding tokens. This is needed since paddings are meaningless. Parameters ---------- batch_tkids: torch.Tensor Batch of token ids with shape ``(B, S)`` and ``dtype == torch.long``. Returns ------- torch.Tensor: Auto-regressive self attention mask and padding mask. Returned tensor has shape ``(B, S, S)`` and ``dtype == torch.bool``. """ # Get batch size. B = batch_tkids.size(0) S = batch_tkids.size(1) # Create auto-regressive mask. # Shape: `(B, S, S)`. reg_mask = torch.ones((B, S, S), dtype=torch.bool) reg_mask = torch.triu(reg_mask, diagonal=1).to(batch_tkids.device) # Create padding mask. # Shape: `(B, S, S)`. pad_mask = (batch_tkids == PAD_TKID) pad_mask = torch.stack([pad_mask] * S, dim=1) pad_mask = pad_mask | pad_mask.transpose(1, 2) # Attention mask is consist of auto-regressive mask and padding mask. # Shape: `(B, S, S)`. return reg_mask | pad_mask
[docs] def forward( self, batch_cur_tkids: torch.Tensor, batch_prev_states: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Calculate next token id logits. Logits were calculated based on previous hidden states and current input token ids. Use :py:meth:`~pred` to convert logits into next token id probability distribution over tokenizer's vocabulary. Use :py:meth:`~cal_loss` to convert logits into next token id prediction loss. Below we describe the forward pass algorithm of Transformer encoder language model. #. Use token ids to lookup token embeddings with ``self.emb``. #. Use sequence length to lookup positional encodings with ``self.pos_enc``. #. Apply dropout to the sum of token embeddings and positional encodings. #. Feed the result into transformer encoder layer. We use teacher forcing in this step when perform training, i.e., inputs are directly given instead of generated by model. #. Feed the output of previous transformer encoder layer into next transformer encoder layer until all layers have been used once. #. Perform inner product on the output of the last transformer encoder layer and token embeddings to get similarity scores. #. Return similarity scores (logits). Parameters ---------- batch_cur_tkids: torch.Tensor Batch current input token ids. ``batch_cur_tkids`` has shape :math:`(B, S)` and ``dtype == torch.long``. batch_prev_states: typing.Optional[torch.Tensor], default: None Batch of previous token ids :math:`c`. The tensor represent the batch of token ids used in the previous context. It has shape :math:`(B, S')` and ``dtype == torch.long``. If given, it will be concatenated with ``batch_cur_tkids``. Set to ``None`` to do nothing. Returns ------- tuple[torch.Tensor, torch.Tensor] The first tensor in the tuple is the batch of next token id logits with shape :math:`(B, S, V)` and ``dtype == torch.float``. The second tensor in the tuple is a batch of the token ids used in forward pass (we denoted it as :math:`c'` in our definition). The second tensor has shape :math:`(B, \min(S, S_\max-1))` and ``dtype == torch.long``. """ # Concate token ids if ``batch_prev_state is not None``. if batch_prev_states is None: x = batch_cur_tkids else: x = torch.hstack([batch_prev_states, batch_cur_tkids]) # Token embedding lookup. # Shape: `(B, S, d_model)`. e = self.emb(x) # Positional encoding lookup. # Shape: `(B, S, d_model)`. pos = self.pos_enc(seq_len=x.size(1)) # Create attention mask. # Shape: `(B, S, S)` mask = self.create_mask(batch_tkids=x) # Loop through each layer. trans_enc_lyr_in = self.input_dp(e + pos) for lyr in range(self.n_lyr): # Get `lyr`-th transformer encoder layer. trans_enc_lyr = self.stack_trans_enc[lyr] # Feed previous transformer encoder layer output to next transformer encoder layer. # Shape: `(B, S, d_model)`. trans_enc_lyr_out = trans_enc_lyr(mask=mask, x=trans_enc_lyr_in) # Update Transformer encoder layer's input. trans_enc_lyr_in = trans_enc_lyr_out # Calculate similarity scores by calculating inner product over all token embeddings. # Shape: (B, S, V). sim = trans_enc_lyr_out @ self.emb.weight.transpose(0, 1) sim = sim[:, -batch_cur_tkids.size(1):, :] # Record token ids participated in the forward pass. # Maximum recording length is equal to ``self.max_seq_len - 1``. batch_cur_states = x[:, -self.max_seq_len + 1:] return (sim, batch_cur_states)
[docs] def params_init(self) -> None: r"""Initialize model parameters. All weights and biases are initialized with uniform distribution :math:`\mathcal{U}(\init_l, \init_u)`. Returns ------- None See Also -------- ~TransEncLayer.params_init Transformer encoder layer parameter initialization. """ nn.init.uniform_(self.emb.weight, self.init_lower, self.init_upper) for lyr in range(self.n_lyr): self.stack_trans_enc[lyr].params_init()
[docs] @torch.no_grad() def pred( self, batch_cur_tkids: torch.Tensor, batch_prev_states: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Calculate next token id probability distribution over tokenizer's vocabulary. Probabilities were calculated based on previous hidden states and current input token id. This method must only be used for inference. No tensor graphs will be constructed and no gradients will be calculated. Parameters ---------- batch_cur_tkids: torch.Tensor Batch current input token ids. ``batch_cur_tkids`` has shape :math:`(B, S)` and ``dtype == torch.long``. batch_prev_states: typing.Optional[torch.Tensor], default: None Batch of previous token ids :math:`c`. The tensor represent the batch of token ids used in the previous context. It has shape :math:`(B, S')` and ``dtype == torch.long``. If given, it will be concatenated with ``batch_cur_tkids``. Set to ``None`` to do nothing. Returns ------- tuple[torch.Tensor, torch.Tensor] The first tensor in the tuple is the batch of next token id probability distribution over the paired tokenizer's vocabulary. Probability tensor has shape :math:`(B, S, V)` and ``dtype == torch.float``. The second tensor in the tuple is a batch of the token ids used in forward pass (we denoted it as :math:`c'` in our definition). The second tensor has shape :math:`(B, \min(S, S_\max-1))` and ``dtype == torch.long``. """ # Get next token id logits and the last hidden states. # Logits shape: (B, S, V). # Each tensor in `batch_cur_states` has shape: (B, d_model). logits, batch_cur_states = self(batch_cur_tkids=batch_cur_tkids, batch_prev_states=batch_prev_states) # Calculate next token id probability distribution using softmax. # shape: (B, S, V). return (F.softmax(logits, dim=-1), batch_cur_states)
[docs]class TransEncLayer(nn.Module): r"""Transformer encoder layer :footcite:`vaswani2017attention`. - Let :math:`B` be mini-batch size. - Let :math:`S` be the length of each sequence in a mini-batch. - Let :math:`\dMdl` be the number of features per time step in each sequence. - Let :math:`x` be a batch of sequences of features with shape :math:`(B, S, \dMdl)`. - Let :math:`\msk` be a batch of attention mask with shape :math:`(B, S, S)`. - Let :math:`\nHd` be the number of attention heads. - Let :math:`d_k` be the number of key features in each attention head. - Let :math:`d_v` be the number of value features in each attention head. - Let :math:`\dFf` be the number of hidden units in the 2-layer fully connected feed-forward network. - Let :math:`p` be the dropout probability. Transformer encoder layer is defined as follow: .. math:: \begin{align*} & \algoProc{\TransEncLayer}(\msk, x) \\ & \indent{1} y_1 \algoEq \MultiHeadAttnLayer\pa{k \algoEq x, \msk \algoEq \msk, q \algoEq x, v \algoEq x} \\ & \indent{1} y_2 \algoEq \LayerNorm_1\pa{x + \drop{y_1}{p}} \\ & \indent{1} y_3 \algoEq W_2 \cdot \max\pa{\mathbf{0}, W_1 \cdot y_2 + b_1} + b_2 \\ & \indent{1} y_4 \algoEq \LayerNorm_2\pa{y_2 + \drop{y_3}{p}} \\ & \indent{1} \algoReturn y_4 \\ & \algoEndProc \end{align*} +-------------------------------------+--------------------------------------------+ | Trainable Parameters | Nodes | +-------------+-----------------------+--------------------+-----------------------+ | Parameter | Shape | Symbol | Shape | +=============+=======================+====================+=======================+ | :math:`W_1` | :math:`(\dFf, \dMdl)` | :math:`\mathbf{0}` | :math:`(B, S, \dFf)` | +-------------+-----------------------+--------------------+-----------------------+ | :math:`W_2` | :math:`(\dMdl, \dFf)` | :math:`\msk` | :math:`(B, S, S)` | +-------------+-----------------------+--------------------+-----------------------+ | :math:`b_1` | :math:`(\dFf)` | :math:`x` | :math:`(B, S, \dMdl)` | +-------------+-----------------------+--------------------+-----------------------+ | :math:`b_2` | :math:`(\dMdl)` | :math:`y_1` | :math:`(B, S, \dMdl)` | +-------------+-----------------------+--------------------+-----------------------+ | :math:`\MultiHeadAttnLayer` | :math:`y_2` | :math:`(B, S, \dMdl)` | +-------------------------------------+--------------------+-----------------------+ | :math:`\LayerNorm_1` | :math:`y_3` | :math:`(B, S, \dMdl)` | +-------------------------------------+--------------------+-----------------------+ | :math:`\LayerNorm_2` | :math:`y_4` | :math:`(B, S, \dMdl)` | +-------------------------------------+--------------------+-----------------------+ Model parameters in Transformer encoder layer are initialized with uniform distribution :math:`\mathcal{U}(\init_l, \init_u)`. The lower bound :math:`\init_l` and upper bound :math:`\init_u` are given as hyperparameters. Parameters ---------- d_ff: int Number of hidden units :math:`\dFf` in the 2-layer fully connected feed-forward network. d_k: int, default: 1 Number of key features :math:`d_k` in each head. d_model: int, default: 1 Number of input / output features :math:`\dMdl`. d_v: int, default: 1 Number of value features :math:`d_v` in each head. init_lower: float, default: -0.1 Uniform distribution lower bound :math:`\init_l` used to initialize model parameters. init_upper: float, default: 0.1 Uniform distribution upper bound :math:`\init_u` used to initialize model parameters. kwargs: typing.Any, optional Useless parameter. Intently left for subclasses inheritance. n_head: int, default: 1 Number of attention heads :math:`\nHd`. p: float, default: 0.0 Dropout probability :math:`p`. Attributes ---------- d_ff: int Number of hidden units :math:`\dFf` in the 2-layer fully connected feed-forward network. d_k: int Number of key features :math:`d_k` in each head. d_model: int Number of input / output features :math:`\dMdl`. d_v: int Number of value features :math:`d_v` in each head. ffn: torch.nn.Sequential 2-layer fully connected feed-forward network with parameters :math:`W_1, W_2, b_1, b_2`. Dropout with probability :math:`p` is applied to output. Input shape: :math:`(B, S, \dMdl)`. Output shape: :math:`(B, S, \dMdl)`. init_lower: float Uniform distribution lower bound :math:`\init_l` used to initialize model parameters. init_upper: float Uniform distribution upper bound :math:`\init_u` used to initialize model parameters. ln_1: torch.nn.LayerNorm Correspond to :math:`\LayerNorm_1`. Input shape: :math:`(B, S, \dMdl)`. Output shape: :math:`(B, S, \dMdl)`. ln_2: torch.nn.LayerNorm Correspond to :math:`\LayerNorm_2`. Input shape: :math:`(B, S, \dMdl)`. Output shape: :math:`(B, S, \dMdl)`. mha: ~MultiHeadAttnLayer Multi-head self attention layer. Multi-head attention is calculated through :math:`\MultiHeadAttnLayer` and self-attention is achieved by giving identical input to query, key and vector. Input shape: :math:`(B, S, \dMdl)`. Output shape: :math:`(B, S, \dMdl)`. mha_dp: torch.nn.Dropout Perform dropout with probability :math:`p` on the output of multi-head self attention. Input shape: :math:`(B, S, \dMdl)`. Output shape: :math:`(B, S, \dMdl)`. n_head: int Number of attention heads :math:`\nHd`. p: float Dropout probability :math:`p`. See Also -------- ~MultiHeadAttnLayer Multi-head attention layer. """ def __init__( self, *, d_ff: int = 1, d_k: int = 1, d_model: int = 1, d_v: int = 1, init_lower: float = -0.1, init_upper: float = 0.1, n_head: int = 1, p: float = 0.0, **kwargs: Any, ): super().__init__() # `d_ff` validation. lmp.util.validate.raise_if_not_instance(val=d_ff, val_name='d_ff', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_ff], val_names=['1', 'd_ff']) self.d_ff = d_ff # `d_k` validation. lmp.util.validate.raise_if_not_instance(val=d_k, val_name='d_k', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_k], val_names=['1', 'd_k']) self.d_k = d_k # `d_model` validation. lmp.util.validate.raise_if_not_instance(val=d_model, val_name='d_model', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_model], val_names=['1', 'd_model']) self.d_model = d_model # `d_v` validation. lmp.util.validate.raise_if_not_instance(val=d_v, val_name='d_v', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, d_v], val_names=['1', 'd_v']) self.d_v = d_v # `init_lower` and `init_upper` validation. lmp.util.validate.raise_if_not_instance(val=init_lower, val_name='init_lower', val_type=float) lmp.util.validate.raise_if_not_instance(val=init_upper, val_name='init_upper', val_type=float) lmp.util.validate.raise_if_wrong_ordered(vals=[init_lower, init_upper], val_names=['init_lower', 'init_upper']) self.init_upper = init_upper self.init_lower = init_lower # `n_head` validation. lmp.util.validate.raise_if_not_instance(val=n_head, val_name='n_head', val_type=int) lmp.util.validate.raise_if_wrong_ordered(vals=[1, n_head], val_names=['1', 'n_head']) self.n_head = n_head # `p` validation. lmp.util.validate.raise_if_not_instance(val=p, val_name='p', val_type=float) lmp.util.validate.raise_if_wrong_ordered(vals=[0.0, p, 1.0], val_names=['0.0', 'p', '1.0']) self.p = p # Multi-head attention layer. self.mha = MultiHeadAttnLayer( d_k=d_k, d_model=d_model, d_v=d_v, init_lower=init_lower, init_upper=init_upper, n_head=n_head, **kwargs, ) # Dropout is applied to the output of multi-head attention layer. self.mha_dp = nn.Dropout(p=p) # 2-layer fully connected feed-forward network. # Dropout is applied to the output. self.ffn = nn.Sequential( nn.Linear(in_features=d_model, out_features=d_ff), nn.ReLU(), nn.Linear(in_features=d_ff, out_features=d_model), nn.Dropout(p=p), ) # 2 different Layer Norm layer. self.ln_1 = nn.LayerNorm(normalized_shape=[d_model]) self.ln_2 = nn.LayerNorm(normalized_shape=[d_model])
[docs] def forward(self, mask: torch.Tensor, x: torch.Tensor) -> torch.Tensor: r"""Calculate batch of hidden states for ``x``. Below we describe the forward pass algorithm of transformer encoder layer. #. Let ``x`` be a batch of sequences of features :math:`x`. #. Let ``mask`` be a batch of attention mask :math:`\msk`. #. Use ``self.mha`` to perform multi-head self attention on ``x`` and get :math:`y_1`. #. Use ``self.mha_dp`` to perform dropout on :math:`y_1`. #. Add :math:`x` and :math:`y_1` (with dropout applied) and use ``self.ln_1`` to perform layer normalization on the addition result to get :math:`y_2`. #. Use ``self.ffn`` to perform 2-layer fully connected feed-forward network forward pass and get :math:`y_3`. #. Add :math:`y_2` and :math:`y_3` (with dropout applied) and use ``self.ln_2`` to perform layer normalization on the addition result to get :math:`y_4`. #. Return :math:`y_4`. Parameters ---------- x: torch.Tensor Batch of sequences of features with shape :math:`(B, S, \dMdl)` and ``dtype == torch.float``. mask: torch.Tensor Batch of attention mask with shape :math:`(B, S, S)` and ``dtype == torch.bool``. Set to true to mask attention at corresponding position. Returns ------- torch.Tensor Batch of sequences of output features :math:`y_4` with shape :math:`(B, S, \dMdl)` and ``dtype == torch.float``. """ # Perform multi-head self attention. # Shape: (B, S, d_model). mha_out = self.mha(q=x, k=x, v=x, mask=mask) # Apply dropout and residual connection, then perform layer normalization. # Shape: (B, S, d_model). x = self.ln_1(x + self.mha_dp(mha_out)) # Feed to 2-layer fully connected feed-forward network. # Shape: (B, S, d_model). ffn_out = self.ffn(x) # Apply dropout and residual connection, then perform layer normalization. # Shape: (B, S, d_model). return self.ln_2(x + ffn_out)
[docs] def params_init(self) -> None: r"""Initialize model parameters. All weights and biases are initialized with uniform distribution :math:`\mathcal{U}\pa{\init_l, \init_u}`. Returns ------- None See Also -------- ~MultiHeadAttnLayer.params_init Multi-head attention layer parameter initialization. """ # Initialize multi-head attention layer. self.mha.params_init() # Initialize 2-layer fully connected feed-forward network. nn.init.uniform_(self.ffn[0].weight, self.init_lower, self.init_upper) nn.init.uniform_(self.ffn[0].bias, self.init_lower, self.init_upper) nn.init.uniform_(self.ffn[2].weight, self.init_lower, self.init_upper) nn.init.uniform_(self.ffn[2].bias, self.init_lower, self.init_upper)