# Source code for lmp.script.gen_txt

r"""Use a pre-trained language model checkpoint to generate continual text of a given text segment.

One must first run the script :doc:`lmp.script.train_model </script/train_model>` before running this script.
This script uses a pre-trained language model checkpoint to generate text that continues a given text segment.
Most inference (generation) methods are stochastic processes; only a few are deterministic.

See Also
--------
:doc:`lmp.infer </infer/index>`
  All available inference methods.
:doc:`lmp.model </model/index>`
  All available language models.
:doc:`lmp.script.train_model </script/train_model>`
  Train language model.

Examples
--------
The following example uses ``"Hello world"`` as the conditioning text segment to generate continual text with the
pre-trained language model experiment ``my_model_exp``.
It uses the ``top-1`` inference method to generate continual text.

.. code-block:: shell

  python -m lmp.script.gen_txt top-1 \
    --ckpt 5000 \
    --exp_name my_model_exp \
    --max_seq_len 128 \
    --txt "Hello world"

The following example uses the same conditioning text segment but generates with the ``top-k`` inference method.

.. code-block:: shell

  python -m lmp.script.gen_txt top-k \
    --ckpt 5000 \
    --exp_name my_model_exp \
    --k 10 \
    --max_seq_len 128 \
    --txt "Hello world"

You can use the ``-h`` or ``--help`` option to get a list of available inference methods.

.. code-block:: shell

  python -m lmp.script.gen_txt -h

You can use the ``-h`` or ``--help`` option on a specific inference method to get a list of its supported CLI arguments.

.. code-block:: shell

  python -m lmp.script.gen_txt top-k -h
"""

import argparse
import sys
from typing import List

import torch

import lmp.infer
import lmp.model
import lmp.util.cfg
import lmp.util.infer
import lmp.util.model
import lmp.util.rand
import lmp.util.tknzr
import lmp.util.validate


def parse_args(argv: List[str]) -> argparse.Namespace:
    """Parse CLI arguments for the text generation script.

    One sub-command is registered for every entry of ``lmp.infer.INFER_OPTS``.
    Each sub-command accepts the shared generation arguments ``--ckpt``, ``--exp_name``,
    ``--max_seq_len``, ``--txt`` and ``--seed``. It also accepts whatever arguments the
    inference method registers through its ``add_CLI_args`` hook.

    Parameters
    ----------
    argv: list[str]
        List of CLI arguments.

    See Also
    --------
    sys.argv
        Python CLI arguments interface.

    Returns
    -------
    argparse.Namespace
        Parsed CLI arguments. The selected sub-command name is stored in ``infer_name``.
    """
    cli_parser = argparse.ArgumentParser(
        'python -m lmp.script.gen_txt',
        description='Use pre-trained language model checkpoint to generate continual text of given text segment.',
    )

    # A sub-command must be given; its name lands in `infer_name`.
    infer_parsers = cli_parser.add_subparsers(dest='infer_name', required=True)

    # Arguments shared by every inference method, as (flag, default, help text, value type).
    # The tuple order is the order in which they are registered (and listed by `--help`).
    shared_opts = (
        (
            '--ckpt',
            -1,
            ''' Pre-trained language model checkpoint. Set to ``-1`` to use the last checkpoint. Default is ``-1``. ''',
            int,
        ),
        (
            '--exp_name',
            'my_model_exp',
            ''' Pre-trained language model experiment name. Default is ``my_model_exp``. ''',
            str,
        ),
        (
            '--max_seq_len',
            32,
            ''' Maximum sequence length constraint. Default is ``32``. ''',
            int,
        ),
        (
            '--txt',
            '',
            ''' Text segment which the generation process is condition on. Default is empty string. ''',
            str,
        ),
        (
            '--seed',
            42,
            ''' Random seed. Default is ``42``. ''',
            int,
        ),
    )

    for method_name, method_cls in lmp.infer.INFER_OPTS.items():
        method_parser = infer_parsers.add_parser(
            method_name,
            description=f'Use {method_cls.__name__} as inference method.',
        )
        hparam_group = method_parser.add_argument_group('language model inference hyperparameters')
        for flag, default_val, help_txt, val_type in shared_opts:
            hparam_group.add_argument(flag, default=default_val, help=help_txt, type=val_type)

        # Let the inference method append its own hyperparameters to its sub-command.
        method_cls.add_CLI_args(parser=method_parser)

    parsed = cli_parser.parse_args(argv)

    # `--ckpt` must not be ordered before -1 (-1 selects the last checkpoint, per the help text).
    lmp.util.validate.raise_if_wrong_ordered(vals=[-1, parsed.ckpt], val_names=['-1', 'args.ckpt'])
    # `--txt` must be a string.
    lmp.util.validate.raise_if_not_instance(val=parsed.txt, val_name='args.txt', val_type=str)

    return parsed
def main(argv: List[str]) -> None:
    """Script entry point: generate continual text and print it to stdout.

    The steps are:
    1. Parse CLI arguments.
    2. Seed the random number generators.
    3. Load the experiment's configuration, tokenizer and model checkpoint.
    4. Run the selected inference method on ``--txt``.
    5. Print the generated text.

    Parameters
    ----------
    argv: list[str]
        List of CLI arguments.

    Returns
    -------
    None
    """
    args = parse_args(argv=argv)

    # Seed first so that stochastic inference methods are reproducible.
    lmp.util.rand.set_seed(seed=args.seed)

    # The tokenizer experiment name is read from the model experiment's saved configuration.
    model_cfg = lmp.util.cfg.load(exp_name=args.exp_name)
    tknzr = lmp.util.tknzr.load(exp_name=model_cfg.tknzr_exp_name)

    # Switch to evaluation mode, which turns off dropout layers in the model.
    model = lmp.util.model.load(ckpt=args.ckpt, exp_name=args.exp_name).eval()

    # Run on GPU when one is available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Every parsed CLI argument (including `infer_name`) is forwarded to the inference factory.
    infer = lmp.util.infer.create(**vars(args))

    generated = infer.gen(model=model, tknzr=tknzr, txt=args.txt)
    print(generated)
if __name__ == '__main__':
    # Forward every CLI argument after the program name to the script entry point.
    main(argv=sys.argv[1:])