Source code for lmp.infer._base
"""Inference method base class."""
import abc
import argparse
from typing import Any, ClassVar
import torch
import lmp.util.validate
from lmp.model import BaseModel
from lmp.tknzr import BaseTknzr
[docs]class BaseInfer(abc.ABC):
"""Inference method abstract base class.
Implement basic functionalities for language model inference, including text generation and parsing inference
hyperparameters.
Parameters
----------
max_seq_len: str, default: 32
Maximum length constraint on generated token list.
One can use larger contraint compare to training.
kwargs: typing.Any, optional
Useless parameter.
Intently left for subclasses inheritance.
Attributes
----------
infer_name: ClassVar[str]
CLI name of the inference method.
Only used to parse CLI arguments.
max_seq_len: str
Maximum length constraint of generated token list.
See Also
--------
:doc:`lmp.infer </infer/index>`
All available inference methods.
"""
infer_name: ClassVar[str] = 'base'
def __init__(self, *, max_seq_len: int = 32, **kwargs: Any):
# `max_seq_len` validation.
lmp.util.validate.raise_if_not_instance(val=max_seq_len, val_name='max_seq_len', val_type=int)
lmp.util.validate.raise_if_wrong_ordered(vals=[1, max_seq_len, 1024], val_names=['1', 'max_seq_len', '1024'])
self.max_seq_len = max_seq_len
[docs] @classmethod
def add_CLI_args(cls, parser: argparse.ArgumentParser) -> None:
"""Add inference method hyperparameters to CLI argument parser.
Parameters
----------
parser: argparse.ArgumentParser
CLI argument parser.
Returns
-------
None
See Also
--------
:doc:`lmp.script.gen_txt </script/gen_txt>`
Use pre-trained language model checkpoint to generate continual text of given text segment.
"""
# `parser` validation.
lmp.util.validate.raise_if_not_instance(val=parser, val_name='parser', val_type=argparse.ArgumentParser)
[docs] @torch.no_grad()
@abc.abstractmethod
def gen(self, model: BaseModel, tknzr: BaseTknzr, txt: str) -> str:
"""Generate continual text conditioned on given text segment.
Parameters
----------
model: ~lmp.model.BaseModel
Pre-trained language model which will be used to generate text.
tknzr: ~lmp.tknzr.BaseTknzr
Pre-trained tokenizer which performs text encoding and decoding.
txt: str
Text segment which the generation process is conditioned on.
Returns
-------
str
Generated text.
"""
raise NotImplementedError