r"""Use pre-trained language model to calculate perplexity on given text.
One must first run the script :doc:`lmp.script.train_model </script/train_model>` before running this script.
See Also
--------
:doc:`lmp.model </model/index>`
All available language models.
:doc:`lmp.script.eval_dset_ppl </script/eval_dset_ppl>`
Use pre-trained language model to calculate average perplexity on a particular dataset.
:doc:`lmp.script.train_model </script/train_model>`
Train language model.
Examples
--------
The following example used pre-trained language model under experiment ``my_model_exp`` to calculate perplexity of the
given text ``"Hello world"``.
It use checkpoint number ``5000`` to perform evaluation.
.. code-block::
python -m lmp.script.eval_txt_ppl \
--ckpt 5000 \
--exp_name my_model_exp \
--txt "Hello world"
The following example calculate perplexity using the last checkpoint of experiment ``my_model_exp``.
.. code-block::
python -m lmp.script.eval_txt_ppl \
--ckpt -1 \
--exp_name my_model_exp \
--txt "Hello world"
You can use ``-h`` or ``--help`` options to get a list of supported CLI arguments.
.. code-block:: shell
python -m lmp.script.eval_txt_ppl -h
"""
import argparse
import sys
from typing import List
import torch
import lmp.model
import lmp.util.cfg
import lmp.util.metric
import lmp.util.model
import lmp.util.rand
import lmp.util.tknzr
import lmp.util.validate
[docs]def parse_args(argv: List[str]) -> argparse.Namespace:
"""Parse CLI arguments.
Parameters
----------
argv: list[str]
List of CLI arguments.
See Also
--------
sys.argv
Python CLI arguments interface.
Returns
-------
argparse.Namespace
Parsed CLI arguments.
"""
# Create parser.
parser = argparse.ArgumentParser(
'python -m lmp.script.eval_txt_ppl',
description='Use pre-trained language model to calculate perplexity on given text.',
)
parser.add_argument(
'--ckpt',
default=-1,
help='''
Pre-trained language model checkpoint.
Set to ``-1`` to use the last checkpoint.
Default is ``-1``.
''',
type=int,
)
parser.add_argument(
'--exp_name',
default='my_model_exp',
help='''
Pre-trained language model experiment name.
Default is ``my_model_exp``.
''',
type=str,
)
parser.add_argument(
'--txt',
default='hello world',
help='''
Text to calculate perplexity.
Default is ``hello world``.
''',
type=str,
)
parser.add_argument(
'--seed',
default=42,
help='''
Random seed.
Default is ``42``.
''',
type=int,
)
args = parser.parse_args(argv)
# `args.ckpt` validation.
lmp.util.validate.raise_if_wrong_ordered(vals=[-1, args.ckpt], val_names=['-1', 'args.ckpt'])
# `args.txt` validation.
lmp.util.validate.raise_if_empty_str(val=args.txt, val_name='args.txt')
return args
[docs]def main(argv: List[str]) -> None:
"""Script entry point.
Parameters
----------
argv: list[str]
List of CLI arguments.
Returns
-------
None
"""
# Parse CLI arguments.
args = parse_args(argv=argv)
# Set random seed for reproducibility.
lmp.util.rand.set_seed(seed=args.seed)
# Get model running device.
device = torch.device('cpu')
if torch.cuda.is_available():
device = torch.device('cuda')
# Load pre-trained model configuration.
model_cfg = lmp.util.cfg.load(exp_name=args.exp_name)
# Load pre-trained tokenizer instance.
tknzr = lmp.util.tknzr.load(exp_name=model_cfg.tknzr_exp_name)
# Load pre-trained model instance.
model = lmp.util.model.load(ckpt=args.ckpt, exp_name=args.exp_name)
# Set model to evaluation mode.
# This turn off dropout layers in model.
model = model.eval()
# Move model to running device.
model = model.to(device)
# Encode text into token ids.
# We convert token ids into tensor and move to the same running device as model.
# Shape: (1, S)
batch_tkids = torch.LongTensor([tknzr.enc(txt=args.txt)]).to(device)
S = batch_tkids.size(1)
# Record BPC and loop through mini-batch by context windows.
# In practice we can use word or subword as token, so we shall call it bit-per-token instead of BPC.
# Naming it as BPC is simply because the convention.
batch_prev_states = None
bpc = 0.0
for ctx_idx in range(0, S, model_cfg.max_seq_len):
# Fetch context window.
ctx_batch_tkids = batch_tkids[..., ctx_idx:ctx_idx + model_cfg.max_seq_len + 1]
# Drop the remaining sequence-length-1 context window.
if ctx_batch_tkids.size(1) == 1:
break
# Construct language model evaluation format.
batch_cur_tkids = ctx_batch_tkids[..., :-1]
batch_next_tkids = ctx_batch_tkids[..., 1:]
# Get next token id probability distribution.
batch_tkids_pd, batch_cur_states = model.pred(
batch_cur_tkids=batch_cur_tkids,
batch_prev_states=batch_prev_states,
)
# Calculate negative log-likelihood -log(p).
nll = lmp.util.metric.nll(batch_tkids=batch_next_tkids, batch_tkids_pd=batch_tkids_pd, use_log2=True)
# Record BPC.
bpc += (nll / S).sum().item()
# Update hidden states.
batch_prev_states = batch_cur_states
# Convert BPC to perplexity.
ppl = pow(2, bpc)
# Output perplexity on given sample.
print(ppl)
if __name__ == '__main__':
main(argv=sys.argv[1:])