# Source code for lmp.model._base

"""Language model base class."""

import abc
import argparse
from typing import Any, ClassVar, Tuple

import torch


class BaseModel(abc.ABC, torch.nn.Module):
    r"""Language model abstract base class.

    Declare the interface shared by every language model: training loss
    calculation, next token id prediction and CLI argument registration.
    Every method except ``__init__`` is abstract and must be implemented by
    subclasses.

    Let :math:`X = \set{x^1, x^2, \dots, x^B}` be a mini-batch of token id
    lists with batch size :math:`B`.  A token id list :math:`x \in X` is
    defined as follows:

    .. math::

        x = \pa{x_1, x_2, \dots, x_S, x_{S+1}}.

    - :math:`x` has length :math:`S+1`.
    - :math:`x_t` is the :math:`t`\-th time step of :math:`x`, the range of
      :math:`t` is :math:`\set{1, \dots, S+1}`.
    - :math:`x_1` is the token id of ``<bos>``.
    - :math:`x_{S+1}` is the token id of ``<eos>``.
    - Each language model will be paired with one tokenizer.  Let :math:`V`
      be the size of the paired tokenizer's vocabulary.  Then
      :math:`x_t \in \set{1, \dots, V}`.

    The training goal of a language model with parameter :math:`\theta` is to
    find an optimal parameter :math:`\theta^{\star}`, such that when replacing
    :math:`\theta` with :math:`\theta^{\star}`, it maximizes the prediction
    probability of next token id :math:`x_{t+1}` given
    :math:`x_1, \dots, x_t`:

    .. math::

        \theta^{\star} = \arg\max_{\theta} \prod_{x \in X} \prod_{t = 1}^S
        P(x_{t+1} \vert x_1, \dots, x_t ; \theta)

    Note that all token id lists in :math:`X` have the same length
    :math:`S+1` and :math:`t` starts with :math:`1`.  Thus for each token id
    list :math:`x \in X`, the first :math:`S` token ids serve as input, and
    the last :math:`S` token ids serve as prediction target.  Only :math:`S`
    positions contribute to the loss.

    Parameters
    ----------
    kwargs: typing.Any, optional
      Unused.  Intentionally accepted so that subclasses can forward extra
      keyword arguments.

    Attributes
    ----------
    model_name: typing.ClassVar[str]
      CLI name of the model.  Only used to parse CLI arguments.
    """

    model_name: ClassVar[str] = 'base'

    def __init__(self, **kwargs: Any):
        # ``kwargs`` is deliberately ignored; only ``torch.nn.Module`` state is
        # initialized here.
        super().__init__()

    @classmethod
    @abc.abstractmethod
    def add_CLI_args(cls, parser: argparse.ArgumentParser) -> None:
        """Add language model hyperparameters to CLI argument parser.

        Parameters
        ----------
        parser: argparse.ArgumentParser
          CLI argument parser.

        Returns
        -------
        None

        Raises
        ------
        NotImplementedError
          Always, unless overridden by a subclass.

        See Also
        --------
        :doc:`lmp.script.train_model </script/train_model>`
          Language model training script.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def cal_loss(
        self,
        batch_cur_tkids: torch.Tensor,
        batch_next_tkids: torch.Tensor,
        batch_prev_states: Any = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Calculate language model prediction loss.

        Loss is defined as the **next token id distribution difference**
        between model output and the answer.  Predicting next token is treated
        as a classification problem, where the number of classes equals the
        tokenizer's vocabulary size.  This method is only used for training.

        Parameters
        ----------
        batch_cur_tkids: torch.Tensor
          Batch of current input token ids.  ``batch_cur_tkids`` has shape
          :math:`(B, S)` and ``dtype == torch.long``.
        batch_next_tkids: torch.Tensor
          Prediction target of each sample in the batch.  ``batch_next_tkids``
          has shape :math:`(B, S)` and ``dtype == torch.long``.
        batch_prev_states: typing.Any, default: None
          Batch of previous calculation results.  Set to ``None`` to use
          initial hidden states.  Different models may have different hidden
          states structure.

        Returns
        -------
        tuple[torch.Tensor, typing.Any]
          The first item in the tuple is the mini-batch loss with shape
          :math:`(1)` and ``dtype == torch.float``.  The second item in the
          tuple represents the current hidden states.  Different models may
          have different hidden states structure.

        Raises
        ------
        NotImplementedError
          Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(
        self,
        batch_cur_tkids: torch.Tensor,
        batch_prev_states: Any = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Calculate next token id logits.

        Logits are calculated based on previous hidden states and current
        input token id.  Use :py:meth:`~pred` to convert logits into next
        token id probability distribution over tokenizer's vocabulary.  Use
        :py:meth:`~cal_loss` to convert logits into next token id prediction
        loss.

        Parameters
        ----------
        batch_cur_tkids: torch.Tensor
          Batch of current input token ids.  ``batch_cur_tkids`` has shape
          :math:`(B, S)` and ``dtype == torch.long``.
        batch_prev_states: typing.Any, default: None
          Batch of previous calculation results.  Set to ``None`` to use
          initial hidden states.  Different models may have different hidden
          states structure.

        Returns
        -------
        tuple[torch.Tensor, typing.Any]
          The first item in the tuple is the batch of next token id logits
          with shape :math:`(B, S, V)` and ``dtype == torch.float``.  The
          second item in the tuple represents the current hidden states.
          Different models may have different hidden states structure.

        Raises
        ------
        NotImplementedError
          Always, unless overridden by a subclass.

        See Also
        --------
        ~lmp.tknzr.BaseTknzr.enc
          Source of token ids.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def params_init(self) -> None:
        """Initialize model parameters.

        The ways and values to initialize a model are considered
        hyperparameters.  Different models may have different initialization
        schemes.

        Returns
        -------
        None

        Raises
        ------
        NotImplementedError
          Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    # ``abc.abstractmethod`` stays innermost, as the ``abc`` documentation
    # requires when it is combined with other decorators.
    @torch.no_grad()
    @abc.abstractmethod
    def pred(
        self,
        batch_cur_tkids: torch.Tensor,
        batch_prev_states: Any = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Calculate next token id probability distribution.

        The probability distribution is over the tokenizer's vocabulary and is
        calculated based on previous hidden states and current input token id.
        This method is only used for inference.  For training use
        :py:meth:`~cal_loss` instead.  No tensor graphs are constructed and no
        gradients are calculated.

        Parameters
        ----------
        batch_cur_tkids: torch.Tensor
          Batch of current input token ids.  ``batch_cur_tkids`` has shape
          :math:`(B, S)` and ``dtype == torch.long``.
        batch_prev_states: typing.Any, default: None
          Batch of previous calculation results.  Set to ``None`` to use
          initial hidden states.  Different models may have different hidden
          states structure.

        Returns
        -------
        tuple[torch.Tensor, typing.Any]
          The first item in the tuple is the batch of next token id
          probability distributions over the tokenizer's vocabulary.
          Probability tensor has shape :math:`(B, S, V)` and
          ``dtype == torch.float``.  The second item in the tuple represents
          the current hidden states.  Different models may have different
          hidden states structure.

        Raises
        ------
        NotImplementedError
          Always, unless overridden by a subclass.

        Notes
        -----
        ``torch.no_grad()`` decorates this abstract declaration only.  An
        overriding method replaces the decorated function, so subclasses must
        apply ``torch.no_grad()`` to their own ``pred`` for the no-gradient
        guarantee above to hold.
        """
        raise NotImplementedError