lmp.infer._top_k#

Top-K inference method.

class lmp.infer._top_k.TopKInfer(*, k: int = 5, max_seq_len: int = 32, **kwargs: Any)[source]#

Bases: BaseInfer

Top-K inference method.

At each inference step, this method samples the next token id from the K token ids with the highest probabilities in the next token id probability distribution over the tokenizer's vocabulary. It is a non-greedy algorithm, since the best prediction (the token id with the highest probability) is not guaranteed to be chosen. In exchange, it produces more diverse generation results compared to Top1Infer.
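
A minimal sketch of a single top-K sampling step is shown below (PyTorch-based; the function and tensor names are illustrative, not part of this module):

>>> import torch
>>> def top_k_sample(logits: torch.Tensor, k: int) -> int:
...     # Keep only the k highest-probability candidates.
...     probs, ids = logits.softmax(dim=-1).topk(k)
...     # Renormalize over the k candidates and sample one of them.
...     idx = torch.multinomial(probs / probs.sum(), num_samples=1)
...     return int(ids[idx])
>>> logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -2.0])
>>> assert top_k_sample(logits, k=3) in (0, 1, 2)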

Parameters
  • k (int, default: 5) – Number of top token ids to sample from at each step.

  • max_seq_len (int, default: 32) – Maximum length constraint on the generated token list. One can use a larger constraint than the one used during training.

  • kwargs (Any, optional) – Unused parameter. Intentionally left for subclass inheritance.
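
A quick construction example, mirroring the attributes documented below:

>>> from lmp.infer import TopKInfer
>>> infer = TopKInfer(k=10, max_seq_len=64)
>>> assert infer.k == 10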

infer_name#

CLI name of the top-K inference method is top-K.

Type

ClassVar[str]

k#

Number of top token ids to sample from at each step.

Type

int

See also

lmp.infer

All available inference methods.

lmp.script.gen_txt

Use a pre-trained language model checkpoint to generate the continuation of a given text segment.

Top1Infer

Top-1 inference method.

classmethod add_CLI_args(parser: ArgumentParser) → None[source]#

Add top-K inference method hyperparameters to CLI argument parser.

Parameters

parser (argparse.ArgumentParser) – CLI argument parser.

Return type

None

See also

lmp.script.gen_txt

Use a pre-trained language model checkpoint to generate the continuation of a given text segment.

Examples

>>> import argparse
>>> from lmp.infer import TopKInfer
>>> parser = argparse.ArgumentParser()
>>> TopKInfer.add_CLI_args(parser)
>>> args = parser.parse_args(['--k', '10'])
>>> assert args.k == 10
gen(model: BaseModel, tknzr: BaseTknzr, txt: str) → str[source]#

Generate continuation text conditioned on the given text segment.

The top-K inference algorithm is structured as follows (a minimal sketch appears after the list):

  1. Encode the input text as a batch containing a single sequence.

  2. Remove token ids after <eos>, since the model is not trained to predict tokens after seeing <eos>.

  3. Loop over the conditioning token ids to generate the conditional hidden states.

  4. Loop to generate token ids. In each iteration, the next token id is sampled from the K token ids with the highest probabilities in the next token id probability distribution. The generation loop stops when <eos> is generated or the maximum length constraint is reached.

  5. Decode the generated token ids into text and return it.
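
A minimal sketch of this loop is given below. The model call and the tokenizer methods enc and dec, as well as the eos_tkid attribute, are illustrative assumptions, not the documented BaseModel / BaseTknzr APIs:

>>> import torch
>>> def gen_sketch(model, tknzr, txt: str, k: int = 5, max_seq_len: int = 32) -> str:
...     # Hypothetical encode; returns a list of token ids.
...     tkids = tknzr.enc(txt)
...     while len(tkids) < max_seq_len:
...         # Next token id logits at the last position (interface assumed).
...         logits = model(torch.tensor([tkids]))[0, -1]
...         # Top-K sampling step, as sketched in the class description.
...         probs, ids = logits.softmax(dim=-1).topk(k)
...         nxt = int(ids[torch.multinomial(probs / probs.sum(), num_samples=1)])
...         tkids.append(nxt)
...         if nxt == tknzr.eos_tkid:  # hypothetical <eos> token id
...             break
...     # Hypothetical decode back to text.
...     return tknzr.dec(tkids)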

Parameters
  • model (BaseModel) – Pre-trained language model which will be used to generate text.

  • tknzr (BaseTknzr) – Pre-trained tokenizer which performs text encoding and decoding.

  • txt (str) – Text segment which the generation process is conditioned on.

Returns

Generated text.

Return type

str
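
Examples

A hypothetical usage sketch; here model and tknzr stand for pre-trained BaseModel and BaseTknzr instances assumed to be constructed elsewhere:

>>> from lmp.infer import TopKInfer
>>> infer = TopKInfer(k=5, max_seq_len=32)
>>> out = infer.gen(model=model, tknzr=tknzr, txt='Hello')
>>> assert isinstance(out, str)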