r"""WNLI dataset."""
import io
import os
import re
import zipfile
from typing import ClassVar, List, Optional
# Typeshed stubs for `pandas` are still under development, so we ignore type checking on `pandas` until the stubs are
# finished and a stable version is released.
import pandas as pd # type: ignore
import lmp.vars
from lmp.dset._base import BaseDset
class WNLIDset(BaseDset):
    """Winograd NLI dataset.

    Winograd NLI is a relaxation of the Winograd Schema Challenge :footcite:`levesque2012winograd` proposed as part of
    the GLUE :footcite:`wang2018glue` benchmark.
    This dataset only extracts sentences from WNLI; no NLI labels are used.

    Here are the statistics of each supported version.
    Tokens are separated by whitespaces.

    +-----------+-------------------+--------------------------+--------------------------+
    | Version   | Number of samples | Maximum number of tokens | Minimum number of tokens |
    +===========+===================+==========================+==========================+
    | ``dev``   | 142               | 63                       | 4                        |
    +-----------+-------------------+--------------------------+--------------------------+
    | ``test``  | 292               | 60                       | 4                        |
    +-----------+-------------------+--------------------------+--------------------------+
    | ``train`` | 1270              | 63                       | 3                        |
    +-----------+-------------------+--------------------------+--------------------------+

    Parameters
    ----------
    ver: Optional[str], default: None
        Version of the dataset.
        Set to ``None`` to use the default version ``self.__class__.df_ver``.

    Attributes
    ----------
    df_ver: typing.ClassVar[str]
        Default version is ``train``.
    dset_name: typing.ClassVar[str]
        CLI name of WNLI dataset is ``WNLI``.
    spls: list[str]
        All samples in the dataset: every normalized ``sentence1`` entry of the chosen version, followed by every
        normalized ``sentence2`` entry.
    ver: str
        Version of the dataset.
    vers: typing.ClassVar[list[str]]
        Supported versions including ``'train'``, ``'dev'`` and ``'test'``.

    Examples
    --------
    >>> from lmp.dset import WNLIDset
    >>> dset = WNLIDset(ver='test')
    >>> dset[0]
    Mark was timid .
    """

    df_ver: ClassVar[str] = 'train'
    dset_name: ClassVar[str] = 'WNLI'
    vers: ClassVar[List[str]] = ['dev', 'test', 'train']

    def __init__(self, *, ver: Optional[str] = None):
        super().__init__(ver=ver)

        # Make sure dataset files exist.
        self.download_dataset()

        # Read text from the WNLI tsv file of the selected version.
        df = pd.read_csv(os.path.join(lmp.vars.DATA_PATH, f'wnli.{self.ver}.tsv'), sep='\t', encoding='utf-8')

        # Extract all sentences and perform text normalization.  All `sentence1` entries come first, then all
        # `sentence2` entries.
        spls = df['sentence1'].apply(self.norm).tolist() + df['sentence2'].apply(self.norm).tolist()

        # Insert a space between a word character and a following punctuation mark (including `'` and `-`).
        spls = [re.sub(r'(\w)([,.!?:;"\'-])', r'\1 \2', spl) for spl in spls]
        # Insert a space between a double quote and a following word character.
        spls = [re.sub(r'(["])(\w)', r'\1 \2', spl) for spl in spls]
        # Separate a `'x` suffix (e.g. `'s`) from the preceding word when followed by whitespace.
        # NOTE(review): the first substitution already splits `'` from any preceding word character, so this pattern
        # appears never to match; kept to preserve existing behavior.
        spls = [re.sub(r'(\w)(\'\w)\s+', r'\1 \2 ', spl) for spl in spls]

        self.spls.extend(spls)

    @classmethod
    def download_dataset(cls) -> None:
        """Download WNLI dataset.

        Download zip file from https://dl.fbaipublicfiles.com/glue/data/WNLI.zip and extract raw files from zip file.
        Raw files are named as ``wnli.ver.tsv``, where ``ver`` is the version of the dataset.
        Download is skipped when the raw files of all versions already exist.
        After extracting raw files the downloaded zip file will be deleted (also when extraction fails).

        Returns
        -------
        None
        """
        # Avoid duplicated download by checking whether all raw files exist.  The names checked here must match the
        # names written during extraction below (`wnli.{ver}.tsv`); otherwise the check never succeeds and the dataset
        # is downloaded again on every instantiation.
        if all(os.path.exists(os.path.join(lmp.vars.DATA_PATH, f'wnli.{ver}.tsv')) for ver in cls.vers):
            return

        # Download zip file path.
        zip_file_path = os.path.join(lmp.vars.DATA_PATH, 'WNLI.zip')

        # Original source is no longer available.
        url = 'https://dl.fbaipublicfiles.com/glue/data/WNLI.zip'

        # Download dataset.
        BaseDset.download_file(mode='binary', download_path=zip_file_path, url=url)

        try:
            # Extract dataset from zip file.  Decode and encode explicitly as UTF-8 so the extracted files do not
            # depend on the platform's default locale encoding.
            with zipfile.ZipFile(zip_file_path, 'r') as input_zipfile:
                for ver in cls.vers:
                    with io.TextIOWrapper(
                        input_zipfile.open(f'WNLI/{ver}.tsv', 'r'),
                        encoding='utf-8',
                    ) as input_text_file:
                        data = input_text_file.read()

                    output_file_path = os.path.join(lmp.vars.DATA_PATH, f'wnli.{ver}.tsv')
                    with open(output_file_path, 'w', encoding='utf-8') as output_text_file:
                        output_text_file.write(data)
        finally:
            # Remove downloaded zip file even if extraction failed, so a corrupted archive is not left behind.  Guard
            # the removal so a missing file does not mask the original exception.
            if os.path.exists(zip_file_path):
                os.remove(zip_file_path)