lmp.infer._top_p
Top-P inference method.
- class lmp.infer._top_p.TopPInfer(*, max_seq_len: int = 32, p: float = 0.9, **kwargs: Any)[source]
Bases: BaseInfer
Top-P inference method.
Top-P sampling, also called nucleus sampling [1], is similar to top-K sampling except that \(k\) changes at each inference step. \(p\) is used as a cumulative probability threshold, and \(k\) is chosen so that the top-\(k\) highest probabilities have cumulative probability less than or equal to \(p\) (see the sketch below). Top-P sampling is a non-greedy algorithm.
- Parameters
  - max_seq_len (int) – Maximum length constraint on generated token ids. Defaults to 32.
  - p (float) – Cumulative probability threshold. Defaults to 0.9.
  - kwargs (Any, optional) – Additional keyword arguments, accepted for subclass compatibility and otherwise unused.
See also
- lmp.infer
All available inference methods.
- lmp.script.gen_txt
Use a pre-trained language model checkpoint to generate continual text of a given text segment.
- lmp.infer.TopKInfer
Top-K inference method.
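For illustration, here is a minimal sketch of a single nucleus truncation-and-sampling step, assuming PyTorch tensors; the nucleus_sample helper is hypothetical and not part of lmp:

import torch

def nucleus_sample(probs: torch.Tensor, p: float = 0.9) -> int:
    # Sort probabilities in descending order, remembering original token ids.
    sorted_probs, sorted_ids = probs.sort(descending=True)
    # Cumulative probability over the sorted distribution.
    cum_probs = sorted_probs.cumsum(dim=-1)
    # k = number of token ids whose cumulative probability is <= p.
    # Keep at least one id so the nucleus is never empty.
    k = max(int((cum_probs <= p).sum().item()), 1)
    # Renormalize the top-k nucleus and sample one token id from it.
    nucleus = sorted_probs[:k] / sorted_probs[:k].sum()
    choice = int(torch.multinomial(nucleus, num_samples=1).item())
    return int(sorted_ids[choice].item())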
- classmethod add_CLI_args(parser: ArgumentParser) → None [source]
Add top-P inference method hyperparameters to CLI argument parser.
- Parameters
parser (argparse.ArgumentParser) – CLI argument parser.
- Return type
None
See also
- lmp.script.gen_txt
Use a pre-trained language model checkpoint to generate continual text of a given text segment.
Examples
>>> import argparse
>>> import math
>>> from lmp.infer import TopPInfer
>>> parser = argparse.ArgumentParser()
>>> TopPInfer.add_CLI_args(parser)
>>> args = parser.parse_args(['--p', '0.9'])
>>> assert math.isclose(args.p, 0.9)
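Continuing the doctest, the parsed value can be passed to the constructor documented at the top of this page; gen() reads the resulting p attribute:

>>> infer = TopPInfer(p=args.p)
>>> assert math.isclose(infer.p, 0.9)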
- gen(model: BaseModel, tknzr: BaseTknzr, txt: str) → str [source]
Generate continual text conditioned on a given text segment.
The top-P inference algorithm is structured as follows:

1. Encode the input text as a batch of one sequence.
2. Remove token ids after <eos>, since the model is not trained to predict tokens that follow <eos>.
3. Loop over the conditional token ids to generate conditional hidden states.
4. Loop to generate token ids. In each iteration, the generated token id is chosen to be one of the top-\(k\) highest-probability ids in the next-token-id distribution, where \(k\) is the number of token ids whose cumulative probabilities (with probabilities sorted in descending order) are less than or equal to self.p. The generation loop stops when <eos> is generated or the maximum length constraint is reached (a rough sketch of this loop follows the list).
5. Decode the generated token ids into text and return it.
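As a rough illustration only, here is a minimal sketch of the loop above. The enc, dec, eos_id, and next_token_probs names are hypothetical stand-ins rather than lmp's actual BaseModel and BaseTknzr API, and nucleus_sample is the helper sketched earlier on this page:

import torch

def top_p_gen(model, tknzr, txt: str, p: float = 0.9, max_seq_len: int = 32) -> str:
    # Hypothetical interface: tknzr.enc / tknzr.dec map text <-> token ids,
    # model.next_token_probs returns the next-token probability distribution.
    ids = tknzr.enc(txt)                    # encode input text as one sequence
    gen_ids = []
    while len(gen_ids) < max_seq_len:       # maximum length constraint
        probs = model.next_token_probs(torch.tensor(ids + gen_ids))
        tid = nucleus_sample(probs, p)      # nucleus step sketched above
        if tid == tknzr.eos_id:             # stop once <eos> is generated
            break
        gen_ids.append(tid)
    return tknzr.dec(gen_ids)               # decode generated ids into text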
[1] Ari Holtzman, Jan Buys, Li Du, Maxwell Forbes, and Yejin Choi. The curious case of neural text degeneration. In International Conference on Learning Representations, 2020. URL: https://openreview.net/forum?id=rygGQyrFvH.