Source code for lmp.util.metric

"""Evaluation metrics."""

import torch

from lmp.vars import PAD_TKID


[docs]def nll(batch_tkids: torch.Tensor, batch_tkids_pd: torch.Tensor, use_log2: bool = True) -> torch.Tensor: r"""Calculate negative log-likelihood :math:`-\log p` on batch token ids. Let :math:`x = \pa{x^1, \dots, x^B}` be a batch of token sequences. Suppose that each token sequence has length :math:`S+1` and each token is defined in a vocabulary with size :math:`V`. Let :math:`x^b = \pa{x_1^b, \dots, x_{S+1}^b}` be the :math:`b`-th token sequence in the batch :math:`x`. Suppose that the probability :math:`\Pr\pa{x_{t+1}^b \vert x_1^b, \dots, x_t^b}` of the next token :math:`x_{t+1}^b` when seeing previous tokens :math:`x_1^b, \dots, x_t^b` is given. Let :math:`R` be the returned tensor with shape :math:`(B, S)`. Then the :math:`b`-th row, :math:`t`-th column of the returned tensor :math:`R` is defined as follow: .. math:: R_{b,t} = -\log_2 \Pr\pa{x_{t+1} \vert x_1^b, \dots, x_t^b}. If :math:`x_{t+1}^b` is padding token, then we assign :math:`R_{b,t}` to zero. Parameters ---------- batch_tkids: torch.Tensor Batch of token ids which represent prediction targets. ``batch_tkids`` has shape :math:`(B, S)` and ``dtype == torch.long``. batch_tkids_pd: torch.Tensor Batch of token id prediction probability distributions. ``batch_tkids_pd`` has shape :math:`(B, S, V)` and ``dtype == torch.float``. use_log2: bool, default: True Set to ``True`` to use :math:`\log_2`. Set to ``False`` to use :math:`\ln`. Returns ------- torch.Tensor :math:`-\log p` tensor. Returned tensor has shape :math:`(B, S)` and ``dtype == torch.float``. """ # Get target token id's probabilities. # Use `batch_tkids` as indices to gather values from probability distribution. # Since prediction has shape `(B, S, V)`, we need to gather along the `V` dimension. # shape: (B, S). batch_tkids_p = torch.gather(input=batch_tkids_pd, dim=2, index=batch_tkids.unsqueeze(2)).squeeze(2) # Mask `PAD_TKID`. mask = batch_tkids != PAD_TKID if use_log2: return mask * -batch_tkids_p.log2() return mask * -batch_tkids_p.log()