Source code for lmp.infer._top_k

"""Top-K inference method."""

import argparse
from typing import Any, ClassVar, List

import torch

import lmp.util.validate
from lmp.infer._base import BaseInfer
from lmp.model import BaseModel
from lmp.tknzr._base import EOS_TKID, PAD_TKID, BaseTknzr


class TopKInfer(BaseInfer):
    """Top-K inference method.

    For each inference step, this method samples the next token id from the tokens with the **top-K highest
    probability** in the next token id probability distribution over tokenizer's vocabulary.
    It is a non-greedy algorithm since the best prediction (which corresponds to the highest probability) is not
    guaranteed to be chosen.
    In exchange, it has higher diversity on generation results compared to :py:class:`~lmp.infer.Top1Infer`.

    Parameters
    ----------
    k: int, default: 5
      Number of token ids to be sampled.
      Must be an integer no less than ``1``.
    max_seq_len: int, default: 32
      Maximum length constraint on generated token list.
      One can use larger constraint compared to training.
    kwargs: typing.Any, optional
      Useless parameter.
      Intentionally left for subclasses inheritance.

    Attributes
    ----------
    infer_name: ClassVar[str]
      CLI name of top-K inference method is ``top-K``.
    k: int
      Number of token ids to be sampled.

    See Also
    --------
    :doc:`lmp.infer </infer/index>`
      All available inference methods.
    :doc:`lmp.script.gen_txt </script/gen_txt>`
      Use pre-trained language model checkpoint to generate continual text of given text segment.
    :py:class:`~lmp.infer.Top1Infer`
      Top-1 inference method.
    """

    infer_name: ClassVar[str] = 'top-K'

    def __init__(self, *, k: int = 5, max_seq_len: int = 32, **kwargs: Any):
        # `kwargs` is accepted but deliberately not forwarded to the parent constructor.
        super().__init__(max_seq_len=max_seq_len)

        # `k` validation: must be an `int` and satisfy `1 <= k`.
        lmp.util.validate.raise_if_not_instance(val=k, val_name='k', val_type=int)
        lmp.util.validate.raise_if_wrong_ordered(vals=[1, k], val_names=['1', 'k'])
        self.k = k

    @classmethod
    def add_CLI_args(cls, parser: argparse.ArgumentParser) -> None:
        """Add top-K inference method hyperparameters to CLI argument parser.

        Parameters
        ----------
        parser: argparse.ArgumentParser
          CLI argument parser.

        Returns
        -------
        None

        See Also
        --------
        :doc:`lmp.script.gen_txt </script/gen_txt>`
          Use pre-trained language model checkpoint to generate continual text of given text segment.

        Examples
        --------
        >>> import argparse
        >>> from lmp.infer import TopKInfer
        >>> parser = argparse.ArgumentParser()
        >>> TopKInfer.add_CLI_args(parser)
        >>> args = parser.parse_args(['--k', '10'])
        >>> assert args.k == 10
        """
        # Let the parent class register its own arguments first.
        super().add_CLI_args(parser=parser)

        # Optional arguments.
        group = parser.add_argument_group('top-K inference method arguments')
        group.add_argument(
            '--k',
            default=5,
            help='''
            Number of token ids to be sampled.
            Default is ``5``.
            ''',
            type=int,
        )

    @torch.no_grad()
    def gen(self, model: BaseModel, tknzr: BaseTknzr, txt: str) -> str:
        """Generate continual text conditioned on given text segment.

        Top-K inference algorithm is structured as follow:

        #. Encode input text as 1 sequence batch.
        #. Remove token ids after ``<eos>`` since model is not trained to predict tokens after seeing ``<eos>``.
        #. Loop over conditional token ids to generate conditional hidden states.
        #. Loop to generate token ids.
           In each iteration, generated token id is chosen so that it is one of the top-K highest probabilities from
           next token id probability distribution.
           Generation loop stops when ``<eos>`` is generated or maximum length constraint is violated.
        #. Decode generated token ids into text and return.

        Parameters
        ----------
        model: ~lmp.model.BaseModel
          Pre-trained language model which will be used to generate text.
        tknzr: ~lmp.tknzr.BaseTknzr
          Pre-trained tokenizer which performs text encoding and decoding.
        txt: str
          Text segment which the generation process is conditioned on.

        Returns
        -------
        str
          Generated text.
        """
        # Get model running device.
        device = next(model.parameters()).device

        # Encode as 1 sequence batch.
        # We convert token ids to tensor and move tensor to the same running device as model.
        # shape: (1, S).
        batch_cur_tkids = torch.LongTensor([tknzr.enc(txt=txt)]).to(device)

        # Remove token ids after `<eos>` since model is not trained to predict tokens after seeing `<eos>`.
        # The count of `<eos>` / `<pad>` ids is converted to a plain `int` once so that later slicing and `range`
        # calls do not rely on implicit tensor-to-index conversion.
        mask = (batch_cur_tkids == EOS_TKID) | (batch_cur_tkids == PAD_TKID)
        seq_len = batch_cur_tkids.size(1) - int(mask.sum().item())
        batch_cur_tkids = batch_cur_tkids[:, :seq_len]

        # Loop over conditioned token ids to generate conditioned hidden states.
        # The last conditioned token id is skipped here; it is fed as the first input of the generation loop.
        batch_prev_states = None
        for i in range(seq_len - 1):
            _, batch_cur_states = model.pred(
                batch_cur_tkids=batch_cur_tkids[:, i].unsqueeze(1),
                batch_prev_states=batch_prev_states,
            )

            # Update hidden states.
            batch_prev_states = batch_cur_states

        # Calculate how many token at most can be generated.
        # A non-positive value makes the generation loop below a no-op.
        out_seq_len = self.max_seq_len - seq_len + 1

        # Generate token ids.
        # NOTE(review): assumes at least one token id survives the mask above (e.g. a leading `<bos>`) — confirm
        # against the tokenizer; indexing with `-1` fails on an empty sequence.
        # shape: (1, 1).
        batch_cur_tkids = batch_cur_tkids[:, -1].unsqueeze(1)
        gen_tkids: List[int] = []
        for _ in range(out_seq_len):
            # Get next token id probability distribution.
            # shape: (1, 1, V).
            batch_next_tkids_pd, batch_cur_states = model.pred(
                batch_cur_tkids=batch_cur_tkids,
                batch_prev_states=batch_prev_states,
            )

            # `topk` raises when `k` exceeds the size of the selected dimension, so never ask for more candidates
            # than the distribution provides.
            k = min(self.k, batch_next_tkids_pd.size(2))

            # Get top-K highest probabilities from next token id probability distribution.
            # shape: (1, 1, k).
            batch_next_tkids_topk_p, batch_next_tkids_topk = batch_next_tkids_pd.topk(k=k, dim=2)

            # Reshape probability tensor to perform sampling.
            # shape: (1, k).
            batch_next_tkids_topk_p = batch_next_tkids_topk_p.reshape(-1, k)

            # Use the top-K highest probabilities to construct multinomial distribution.
            # Then sample token id from multinomial distribution as the next token id prediction.
            # `batch_next_tkids_topk_sample` shape: (1, 1).
            batch_next_tkids_topk_sample = torch.multinomial(batch_next_tkids_topk_p, num_samples=1)

            # Use sampled result to fetch next token id prediction.
            # shape: (1, 1).
            batch_next_tkids = torch.gather(
                input=batch_next_tkids_topk,
                dim=2,
                index=batch_next_tkids_topk_sample.unsqueeze(2),
            ).squeeze(1)
            gen_tkid = int(batch_next_tkids[0, 0].item())
            gen_tkids.append(gen_tkid)

            # Update input token ids.
            batch_cur_tkids = batch_next_tkids

            # Update hidden states.
            batch_prev_states = batch_cur_states

            # If the prediction token id is `<eos>`, then stop generation immediately.
            if gen_tkid == EOS_TKID:
                break

        # Output generated text.
        return tknzr.dec(tkids=gen_tkids, rm_sp_tks=True)