# Source code for lmp.dset._base

"""Dataset base class."""

import os
import re
import unicodedata
from typing import ClassVar, Iterator, List, Optional

import requests
import torch.utils.data

import lmp.util.validate
import lmp.vars


class BaseDset(torch.utils.data.Dataset):
    """Dataset base class.

    Most datasets need to be downloaded from the web.  Only some of them can be
    generated locally.  Datasets are downloaded / generated automatically if
    they are not on your local machine.  No downloading or generation is
    executed if dataset files already exist on your local machine.

    Parameters
    ----------
    ver: Optional[str], default: None
      Version of the dataset.  Set to ``None`` to use the default version
      ``self.__class__.df_ver``.

    Attributes
    ----------
    df_ver: typing.ClassVar[str]
      Default version of the dataset.
    dset_name: typing.ClassVar[str]
      CLI name of the dataset.  Only used to parse CLI arguments.
    spls: list[str]
      All samples in the dataset.
    ver: str
      Version of the dataset.
    vers: typing.ClassVar[list[str]]
      List of dataset supported versions.

    See Also
    --------
    :doc:`lmp.dset </dset/index>`
      All available datasets.
    """

    df_ver: ClassVar[str] = ''
    dset_name: ClassVar[str] = 'base'
    vers: ClassVar[List[str]] = []

    def __init__(self, *, ver: Optional[str] = None):
        super().__init__()

        # Use default version of the dataset.
        if ver is None:
            ver = self.__class__.df_ver

        # `ver` validation: must be a string listed in the class-level `vers`.
        lmp.util.validate.raise_if_not_instance(val=ver, val_name='ver', val_type=str)
        lmp.util.validate.raise_if_not_in(val=ver, val_name='ver', val_range=self.__class__.vers)

        self.ver = ver
        # Starts empty here; nothing in this class fills it.
        self.spls: List[str] = []

    def __getitem__(self, idx: int) -> str:
        """Sample text using index.

        Parameters
        ----------
        idx: int
          Sample index.

        Returns
        -------
        str
          The sample whose index equals to ``idx``.
        """
        # `idx` validation.
        lmp.util.validate.raise_if_not_instance(val=idx, val_name='idx', val_type=int)

        return self.spls[idx]

    def __iter__(self) -> Iterator[str]:
        """Iterate through each sample in the dataset.

        Yields
        ------
        str
          One sample in ``self.spls``, ordered by sample indices.
        """
        for spl in self.spls:
            yield spl

    def __len__(self) -> int:
        """Get dataset size.

        Returns
        -------
        int
          Number of samples in the dataset.
        """
        return len(self.spls)

    @staticmethod
    def download_file(mode: str, download_path: str, url: str) -> None:
        """Download file from ``url``.

        The parent directory of ``download_path`` is created when missing.
        Nothing is written to ``download_path`` unless the server answers with
        a successful HTTP status.

        Parameters
        ----------
        mode: str
          Can only be ``'binary'`` or ``'text'``.  ``'binary'`` writes the raw
          response bytes; ``'text'`` writes the decoded response text encoded
          as UTF-8.
        download_path: str
          File path of the downloaded file.
        url: str
          URL of the file to be downloaded.

        Returns
        -------
        None

        Raises
        ------
        requests.HTTPError
          If the server responds with a 4xx / 5xx status code.
        """
        # `mode` type guard.
        lmp.util.validate.raise_if_not_instance(val=mode, val_name='mode', val_type=str)
        lmp.util.validate.raise_if_not_in(val=mode, val_name='mode', val_range=['text', 'binary'])

        # `download_path` type guard.
        lmp.util.validate.raise_if_not_instance(val=download_path, val_name='download_path', val_type=str)
        lmp.util.validate.raise_if_empty_str(val=download_path, val_name='download_path')
        lmp.util.validate.raise_if_is_directory(path=download_path)

        # `url` type guard.
        lmp.util.validate.raise_if_not_instance(val=url, val_name='url', val_type=str)
        lmp.util.validate.raise_if_empty_str(val=url, val_name='url')

        # Create folder if not exists.  `exist_ok=True` avoids the race between
        # checking for the directory and creating it.
        download_dir = os.path.abspath(os.path.join(download_path, os.pardir))
        os.makedirs(download_dir, exist_ok=True)

        # Download and output file.
        with requests.get(url=url) as res:
            # Fail before touching the disk: otherwise an HTTP error page would
            # be saved as the dataset file and later be mistaken for a valid,
            # already-downloaded dataset.
            res.raise_for_status()

            if mode == 'binary':
                with open(download_path, 'wb') as binary_file:
                    binary_file.write(res.content)
            else:
                with open(download_path, 'w', encoding='utf-8') as text_file:
                    text_file.write(res.text)

    @staticmethod
    def norm(txt: str) -> str:
        """Text normalization.

        Text will be NFKC normalized.  Whitespaces are collapsed and stripped
        from both ends.

        Parameters
        ----------
        txt: str
          Text to be normalized.

        Returns
        -------
        str
          Normalized text.

        See Also
        --------
        unicodedata.normalize
          Python built-in unicode normalization.

        Examples
        --------
        >>> from lmp.dset import BaseDset
        >>> BaseDset.norm('１２３４５６７８９')
        '123456789'
        >>> BaseDset.norm('  hello \\t world  ')
        'hello world'
        """
        # Normalize first so that compatibility whitespace (e.g. U+3000) is
        # turned into ordinary spaces before stripping and collapsing.
        return re.sub(r'\s+', ' ', unicodedata.normalize('NFKC', txt).strip())