Source code for lmp.util.model

"""Model utilities."""

import os
import re
from typing import Any, List

import torch

import lmp.util.validate
import lmp.vars
from lmp.model import MODEL_OPTS, BaseModel


def create(model_name: str, **kwargs: Any) -> BaseModel:
  """Instantiate a language model looked up by its registered name.

  The model class is fetched from ``lmp.model.MODEL_OPTS`` using ``model_name`` as key, and every keyword
  argument collected in ``**kwargs`` is forwarded unchanged to that class's constructor.

  Parameters
  ----------
  model_name: str
    Registered name of the language model to instantiate.
  kwargs: typing.Any, optional
    Hyperparameters forwarded to the model's constructor.

  Returns
  -------
  ~lmp.model.BaseModel
    Newly constructed language model instance.

  See Also
  --------
  :doc:`lmp.model </model/index>`
    All available language models.
  :doc:`lmp.tknzr </tknzr/index>`
    All available tokenizers.

  Examples
  --------
  >>> from lmp.model import ElmanNet
  >>> from lmp.tknzr import CharTknzr
  >>> import lmp.util.model
  >>> tknzr = CharTknzr()
  >>> model = lmp.util.model.create(model_name=ElmanNet.model_name, tknzr=tknzr)
  >>> assert isinstance(model, ElmanNet)
  """
  validate = lmp.util.validate

  # The name must be a string before it can be used as a registry key.
  validate.raise_if_not_instance(val=model_name, val_name='model_name', val_type=str)

  # Only names registered in `MODEL_OPTS` are accepted.
  registered_names = list(MODEL_OPTS.keys())
  validate.raise_if_not_in(val=model_name, val_name='model_name', val_range=registered_names)

  model_cls = MODEL_OPTS[model_name]
  return model_cls(**kwargs)
def save(ckpt: int, exp_name: str, model: BaseModel) -> None:
  """Save model checkpoint.

  The model is written to ``lmp.vars.EXP_PATH/exp_name/model-{ckpt}.pt``.  The experiment directory is created
  when it does not exist yet.

  .. danger::

     This method overwrites existing files.
     Make sure you know what you are doing before calling this method.

  Parameters
  ----------
  ckpt: int
    Saving checkpoint number.  Must be non-negative.
  exp_name: str
    Language model training experiment name.  Must be a non-empty string.
  model: lmp.model.BaseModel
    Model to be saved.

  Returns
  -------
  None

  See Also
  --------
  ~load
    Load pre-trained language model instance by checkpoint and experiment name.

  Examples
  --------
  >>> from lmp.model import ElmanNet
  >>> from lmp.tknzr import CharTknzr
  >>> import lmp.util.model
  >>> tknzr = CharTknzr()
  >>> model = ElmanNet(tknzr=tknzr)
  >>> lmp.util.model.save(ckpt=0, exp_name='test', model=model)
  """
  # `ckpt` validation.
  lmp.util.validate.raise_if_not_instance(val=ckpt, val_name='ckpt', val_type=int)
  lmp.util.validate.raise_if_wrong_ordered(vals=[0, ckpt], val_names=['0', 'ckpt'])

  # `exp_name` validation.
  lmp.util.validate.raise_if_not_instance(val=exp_name, val_name='exp_name', val_type=str)
  lmp.util.validate.raise_if_empty_str(val=exp_name, val_name='exp_name')

  # `save_dir_path` validation.
  save_dir_path = os.path.join(lmp.vars.EXP_PATH, exp_name)
  lmp.util.validate.raise_if_is_file(path=save_dir_path)
  # `exist_ok=True` avoids the race between an existence check and directory creation.
  os.makedirs(save_dir_path, exist_ok=True)

  # `save_path` validation.
  save_file_path = os.path.join(save_dir_path, f'model-{ckpt}.pt')
  lmp.util.validate.raise_if_is_directory(path=save_file_path)

  # Save model.  The whole model object is pickled (not just its `state_dict`), so loading it back requires the
  # model class to be importable.
  torch.save(model, save_file_path)
def load(ckpt: int, exp_name: str) -> BaseModel:
  """Load pre-trained language model instance by checkpoint and experiment name.

  Load pre-trained language model from path ``lmp.vars.EXP_PATH/exp_name/model-{ckpt}.pt``.

  Parameters
  ----------
  ckpt: int
    Saving checkpoint number.  Set to ``-1`` to load the last checkpoint, i.e., the largest checkpoint number
    among files named ``model-{number}.pt`` in the experiment directory.
  exp_name: str
    Pre-trained language model experiment name.

  Returns
  -------
  ~lmp.model.BaseModel
    Pre-trained language model instance.

  Raises
  ------
  FileNotFoundError
    When ``ckpt == -1`` but the experiment directory contains no checkpoint file.

  See Also
  --------
  :doc:`lmp.model </model/index>`
    All available language models.

  Examples
  --------
  >>> from lmp.model import ElmanNet
  >>> from lmp.tknzr import CharTknzr
  >>> import lmp.util.model
  >>> tknzr = CharTknzr()
  >>> model = ElmanNet(tknzr=tknzr)
  >>> lmp.util.model.save(ckpt=0, exp_name='test', model=model)
  >>> load_model = lmp.util.model.load(ckpt=0, exp_name='test')
  >>> assert torch.all(load_model.emb.weight == model.emb.weight)
  """
  # `ckpt` validation.
  lmp.util.validate.raise_if_not_instance(val=ckpt, val_name='ckpt', val_type=int)
  lmp.util.validate.raise_if_wrong_ordered(vals=[-1, ckpt], val_names=['-1', 'ckpt'])

  # `exp_name` validation.
  lmp.util.validate.raise_if_not_instance(val=exp_name, val_name='exp_name', val_type=str)
  lmp.util.validate.raise_if_empty_str(val=exp_name, val_name='exp_name')

  # `ckpt_dir_path` validation.
  ckpt_dir_path = os.path.join(lmp.vars.EXP_PATH, exp_name)
  lmp.util.validate.raise_if_is_file(path=ckpt_dir_path)

  # Load the last checkpoint if `ckpt == -1`.
  if ckpt == -1:
    ckpt_ids = []
    for ckpt_file_name in os.listdir(ckpt_dir_path):
      # `fullmatch` with an escaped `.` so that names such as `model-3.pt.bak` or `model-3xpt` are ignored.
      match = re.fullmatch(r'model-(\d+)\.pt', ckpt_file_name)
      if match is not None:
        ckpt_ids.append(int(match.group(1)))

    # Without this check we would silently try to open `model--1.pt`.
    if not ckpt_ids:
      raise FileNotFoundError(f'No model checkpoint found in {ckpt_dir_path}.')
    ckpt = max(ckpt_ids)

  # `ckpt_file_path` validation.
  ckpt_file_path = os.path.join(ckpt_dir_path, f'model-{ckpt}.pt')
  lmp.util.validate.raise_if_is_directory(path=ckpt_file_path)

  # NOTE(review): `torch.load` unpickles arbitrary objects; only load checkpoints from trusted sources.  Newer
  # PyTorch releases default to `weights_only=True`, which may reject whole-model pickles like the ones `save`
  # writes -- confirm against the installed torch version.
  return torch.load(ckpt_file_path)
def list_ckpts(exp_name: str, first_ckpt: int, last_ckpt: int) -> List[int]:
  r"""List all pre-trained model checkpoints from ``first_ckpt`` to ``last_ckpt``.

  Both ``first_ckpt`` and ``last_ckpt`` are included.  Checkpoints are discovered from files named
  ``model-{number}.pt`` inside ``lmp.vars.EXP_PATH/exp_name``.

  Parameters
  ----------
  exp_name: str
    Pre-trained language model experiment name.
  first_ckpt: int
    First checkpoint to include.  Set to ``-1`` to include only the last checkpoint.
  last_ckpt: int
    Last checkpoint to include.  Set to ``-1`` to include all checkpoints whose number is greater than or equal
    to ``first_ckpt``.

  Returns
  -------
  list[int]
    Available checkpoints of the experiment within the requested range, sorted in ascending order.  Empty when
    no checkpoint falls within the range.

  Raises
  ------
  ValueError
    When ``first_ckpt == -1`` but the experiment directory contains no checkpoint file.
  """
  # `exp_name` validation.
  lmp.util.validate.raise_if_not_instance(val=exp_name, val_name='exp_name', val_type=str)
  lmp.util.validate.raise_if_empty_str(val=exp_name, val_name='exp_name')

  # `first_ckpt` validation.
  lmp.util.validate.raise_if_not_instance(val=first_ckpt, val_name='first_ckpt', val_type=int)
  lmp.util.validate.raise_if_wrong_ordered(vals=[-1, first_ckpt], val_names=['-1', 'first_ckpt'])

  # `last_ckpt` validation.
  lmp.util.validate.raise_if_not_instance(val=last_ckpt, val_name='last_ckpt', val_type=int)
  lmp.util.validate.raise_if_wrong_ordered(vals=[-1, last_ckpt], val_names=['-1', 'last_ckpt'])

  # `ckpt_dir_path` validation.
  ckpt_dir_path = os.path.join(lmp.vars.EXP_PATH, exp_name)
  lmp.util.validate.raise_if_is_file(path=ckpt_dir_path)

  ckpt_list = []
  for ckpt_file_name in os.listdir(ckpt_dir_path):
    # `fullmatch` with an escaped `.` so that names such as `model-3.pt.bak` or `model-3xpt` are ignored.
    match = re.fullmatch(r'model-(\d+)\.pt', ckpt_file_name)
    if match is not None:
      ckpt_list.append(int(match.group(1)))
  ckpt_list.sort()

  if first_ckpt == -1:
    # "Only the last checkpoint" is undefined when there is none; fail with a clear message instead of the
    # opaque `max()` error on an empty sequence.
    if not ckpt_list:
      raise ValueError(f'No model checkpoint found in {ckpt_dir_path}.')
    return [ckpt_list[-1]]

  if last_ckpt == -1:
    return [ckpt for ckpt in ckpt_list if first_ckpt <= ckpt]

  return [ckpt for ckpt in ckpt_list if first_ckpt <= ckpt <= last_ckpt]