Source code for lmp.dset._wiki_text_2

"""Wiki-Text-2 dataset."""

import io
import os
import zipfile
from typing import ClassVar, List, Optional

import lmp.vars
from lmp.dset._base import BaseDset


class WikiText2Dset(BaseDset):
    """Wiki-Text-2 dataset.

    Wiki-Text-2 :footcite:`merity2017pointer` is part of the WikiText Long Term Dependency Language Modeling
    Dataset.
    See `Wiki-Text`_ for more details.

    .. _`Wiki-Text`:
       https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/

    Here are the statistics of each supported version.
    Tokens are separated by whitespaces.

    +-----------+-------------------+--------------------------+--------------------------+
    | Version   | Number of samples | Maximum number of tokens | Minimum number of tokens |
    +===========+===================+==========================+==========================+
    | ``test``  | 60                | 14299                    | 461                      |
    +-----------+-------------------+--------------------------+--------------------------+
    | ``train`` | 600               | 17706                    | 281                      |
    +-----------+-------------------+--------------------------+--------------------------+
    | ``valid`` | 60                | 18855                    | 778                      |
    +-----------+-------------------+--------------------------+--------------------------+

    Parameters
    ----------
    ver: Optional[str], default: None
      Version of the dataset.
      Set to ``None`` to use the default version ``self.__class__.df_ver``.

    Attributes
    ----------
    df_ver: typing.ClassVar[str]
      Default version is ``'train'``.
    dset_name: typing.ClassVar[str]
      CLI name of Wiki-Text-2 dataset is ``wiki-text-2``.
    spls: list[str]
      All samples in the dataset.
    ver: str
      Version of the dataset.
    vers: typing.ClassVar[list[str]]
      Supported versions including ``'train'``, ``'test'`` and ``'valid'``.

    Examples
    --------
    >>> from lmp.dset import WikiText2Dset
    >>> dset = WikiText2Dset(ver='test')
    >>> dset[0][:31]
    'Robert <unk> is an English film'
    """

    df_ver: ClassVar[str] = 'train'
    dset_name: ClassVar[str] = 'wiki-text-2'
    vers: ClassVar[List[str]] = ['test', 'train', 'valid']

    def __init__(self, *, ver: Optional[str] = None):
        # NOTE(review): `self.ver`, `self.spls` and `self.norm` are not defined in this class; they are presumably
        # provided by `BaseDset` -- confirm against the base class.
        super().__init__(ver=ver)

        # Make sure dataset files exist.
        self.download_dataset()

        # Read dataset from the specified version.
        # Each line is normalized.
        # The encoding is given explicitly so that decoding does not depend on the platform's default locale.
        raw_file_path = os.path.join(lmp.vars.DATA_PATH, f'wiki.{self.ver}.tokens')
        with open(raw_file_path, 'r', encoding='utf-8') as text_file:
            lines = [self.norm(line) for line in text_file]

        # Each article is treated as one text passage (one sample).
        self.spls.extend(self._split_articles(lines))

    @staticmethod
    def _split_articles(lines: List[str]) -> List[str]:
        """Group normalized lines into articles.

        Wiki-Text-2 consists of Wiki articles.
        Each article consists of one main section, many subsections and nested subsections.
        A main section title begins with a single ``=`` and ends with a single ``=``.
        A subsection title begins with ``= =`` and ends with ``= =``.
        A nested subsection title begins with more than 2 ``=`` and ends with the same amount of ``=``.
        All lines between two main section titles are joined by a single space into one text passage.

        Parameters
        ----------
        lines: list[str]
          Normalized lines of a raw file.
          Empty strings mark blank lines.

        Returns
        -------
        list[str]
          One stripped text passage per article.
          Empty list if ``lines`` has no non-empty line.
        """
        articles: List[str] = []
        parts: List[str] = []
        last_idx = len(lines) - 1

        for line_idx, line in enumerate(lines):
            # Discard empty lines.
            if not line:
                continue

            # A main section title marks the start of a new article.
            # `parts` must be non-empty: the very first title has no previous article to flush.  This also guards
            # `lines[line_idx - 1]` against wrapping around to the last line when `line_idx == 0`.
            # Some lines start and end with a single `=` but are not section titles, thus the line must also be
            # surrounded by blank lines.  The end of file is treated as a blank line so that a title on the very
            # last line does not index past the end of `lines`.
            starts_new_article = (
                bool(parts)
                and line.startswith('=')
                and not line.startswith('= =')
                and line.endswith('=')
                and not lines[line_idx - 1]
                and (line_idx == last_idx or not lines[line_idx + 1])
            )

            if starts_new_article:
                # Flush previous article and start recording new article.
                articles.append(' '.join(parts).strip())
                parts = [line]
            else:
                # Record article.
                parts.append(line)

        # Flush the last remaining article.
        # It is stripped like every other article, and nothing is emitted when no text was found.
        if parts:
            articles.append(' '.join(parts).strip())

        return articles

    @classmethod
    def download_dataset(cls) -> None:
        """Download Wiki-text-2 dataset.

        Download zip file from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip and extract
        raw files from zip file.
        Raw files are named as ``wiki.ver.tokens``, where ``ver`` is the version of the dataset.
        The downloaded zip file is deleted after extraction, whether or not extraction succeeded.
        Nothing is downloaded if the raw files of all supported versions already exist.

        Returns
        -------
        None
        """
        # Download zip file path.
        zip_file_path = os.path.join(lmp.vars.DATA_PATH, 'wiki-text-2.zip')

        # Original source is no longer available.
        url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'

        # Avoid duplicated download by checking whether all raw files exist.
        raw_file_paths = {ver: os.path.join(lmp.vars.DATA_PATH, f'wiki.{ver}.tokens') for ver in cls.vers}
        if all(os.path.exists(raw_file_path) for raw_file_path in raw_file_paths.values()):
            return

        # Download dataset.
        BaseDset.download_file(mode='binary', download_path=zip_file_path, url=url)

        try:
            # Extract dataset from zip file.
            # Both sides use an explicit UTF-8 encoding so that the copy does not depend on the platform's locale.
            with zipfile.ZipFile(zip_file_path, 'r') as input_zipfile:
                for ver, raw_file_path in raw_file_paths.items():
                    zip_member = input_zipfile.open(f'wikitext-2/wiki.{ver}.tokens', 'r')
                    with io.TextIOWrapper(zip_member, encoding='utf-8') as input_text_file:
                        data = input_text_file.read()

                    with open(raw_file_path, 'w', encoding='utf-8') as output_text_file:
                        output_text_file.write(data)
        finally:
            # Remove downloaded zip file, even when extraction failed, so no stale archive is left behind.
            os.remove(zip_file_path)