"""A vocabulary class for easy management of categorical data element options."""
from __future__ import annotations
import copy
import dataclasses
import math
from collections.abc import Sequence
from functools import cached_property
from io import TextIOBase
from textwrap import shorten, wrap
from typing import Generic, TypeVar, Union
import numpy as np
from sparklines import sparklines
from ..utils import COUNT_OR_PROPORTION, num_initial_spaces
# The declared TypeVar name matches the variable it is bound to, as static type checkers expect.
VOCAB_ELEMENT = TypeVar("VOCAB_ELEMENT")
# Either a single vocabulary element or an arbitrarily nested sequence of them.
NESTED_VOCAB_SEQUENCE = Union[VOCAB_ELEMENT, Sequence["NESTED_VOCAB_SEQUENCE"]]


@dataclasses.dataclass
class Vocabulary(Generic[VOCAB_ELEMENT]):
    """Stores a vocabulary of observed elements of type `VOCAB_ELEMENT` ordered by frequency.

    This class represents a vocabulary of observed elements of specifiable type `VOCAB_ELEMENT`. All
    vocabularies include an "unknown" option, codified as the string `'UNK'`. Upon construction, the
    vocabulary is sorted in order of decreasing frequency. The vocabulary can also be described for a
    text-based visual representation of the contained elements and their relative frequency distribution.

    Vocabulary elements can be arbitrary types _except_ for integers.

    Attributes:
        vocabulary: The vocabulary, stored as a plain list, beginning with 'UNK' and subsequently proceeding
            in order of most frequently observed to least frequently observed.
        obs_frequencies: The observed frequencies of elements of the vocabulary, stored as a plain list of
            Python floats that sums to one.
        element_types: A set of the types of elements that are allowed in this vocabulary.

    Raises:
        ValueError: If an empty (or missing) vocabulary is passed, no observation frequencies are passed, a
            vocabulary with duplicates is passed, a vocabulary with integer elements is passed, or a
            vocabulary whose length differs from the passed observation frequencies.

    Examples:
        >>> vocab = Vocabulary(vocabulary=['apple', 'banana', 'UNK'], obs_frequencies=[3, 5, 2])
        >>> vocab.vocabulary
        ['UNK', 'banana', 'apple']
        >>> vocab.obs_frequencies
        [0.2, 0.5, 0.3]
        >>> len(vocab)
        3
        >>> vocab = Vocabulary(vocabulary=[], obs_frequencies=[])
        Traceback (most recent call last):
            ...
        ValueError: Empty vocabularies are not supported.
        >>> vocab = Vocabulary(vocabulary=['apple'], obs_frequencies=[1, 2])
        Traceback (most recent call last):
            ...
        ValueError: self.vocabulary and self.obs_frequencies must have the same length. Got 1 and 2.
        >>> vocab = Vocabulary(vocabulary=['apple', 'apple'], obs_frequencies=[1, 2])
        Traceback (most recent call last):
            ...
        ValueError: Vocabulary has duplicates. len(self.vocabulary) = 2, but len(set(self.vocabulary)) = 1.
        >>> vocab = Vocabulary(vocabulary=['apple', 1], obs_frequencies=[1, 2])
        Traceback (most recent call last):
            ...
        ValueError: Integer elements in the vocabulary are not supported.
    """

    # The vocabulary, beginning with 'UNK' and subsequently proceeding in order of most frequently observed to
    # least frequently observed.
    vocabulary: list[str | VOCAB_ELEMENT] | None = None

    # The observed frequencies of elements of the vocabulary.
    obs_frequencies: np.ndarray | list[float] | None = None

    @cached_property
    def idxmap(self) -> dict[VOCAB_ELEMENT, int]:
        """Returns a mapping from vocab element to vocabulary integer index.

        The mapping is computed on first access and cached on the instance; `filter` drops the cached
        value so that it is rebuilt against the filtered vocabulary.

        Returns:
            Dictionary mapping vocabulary elements to their index.

        Example:
            >>> vocab = Vocabulary(vocabulary=['apple', 'banana', 'UNK'], obs_frequencies=[3, 5, 2])
            >>> vocab.idxmap
            {'UNK': 0, 'banana': 1, 'apple': 2}
        """
        return {v: i for i, v in enumerate(self.vocabulary)}

    def __getitem__(self, q: int | VOCAB_ELEMENT) -> int | VOCAB_ELEMENT:
        """Gets vocabulary element or corresponding integer index for `q`.

        If `q` is an integer index, returns the vocabulary element at that index. If it is a valid type to be
        a member of the vocabulary, returns the integer index associated with that element, or 0 if that
        element is not in the vocabulary (0 corresponds to the UNK index, so this is appropriate).

        Args:
            q: Query to fetch either the vocabulary element or its index.

        Returns:
            Vocabulary element at index q if q is an integer.
            Index of the vocabulary element if q is a vocabulary element (or 0 if it is unrecognized).

        Raises:
            TypeError: if the query element is not an integer, the UNK sentinel value, or a member of the
                allowed types for this vocabulary (`self.element_types`).

        Example:
            >>> vocab = Vocabulary(vocabulary=['apple', 'banana', 'UNK'], obs_frequencies=[3, 5, 2])
            >>> vocab[1]
            'banana'
            >>> vocab['apple']
            2
            >>> vocab[3.4]
            Traceback (most recent call last):
                ...
            TypeError: Type <class 'float'> is not a valid type for this vocabulary.
        """
        # An exact type check (not isinstance) so that int subclasses such as bool are treated as candidate
        # vocabulary elements rather than as positional indices.
        if type(q) is int:
            return self.vocabulary[q]

        if (type(q) not in self.element_types) and (q != "UNK"):
            raise TypeError(f"Type {type(q)} is not a valid type for this vocabulary.")

        return self.idxmap.get(q, 0)

    def __len__(self):
        """Returns the length of the vocabulary, including UNK."""
        return len(self.vocabulary)

    def __eq__(self, other: object) -> bool:
        """Returns True if other is an identical vocabulary.

        Returns:
            True if the type of self and other match, if their vocabulary lists are identical, and if their
            observed frequencies list are identical up to a precision of 3 decimal points.
        """
        if type(self) is not type(other):
            return False
        if self.vocabulary != other.vocabulary:
            return False

        # np.array_equal also copes with mismatched shapes, and the bool() wrapper keeps the result a plain
        # Python bool rather than a numpy scalar.
        own_freqs = np.array(self.obs_frequencies).round(3)
        other_freqs = np.array(other.obs_frequencies).round(3)
        return bool(np.array_equal(own_freqs, other_freqs))

    def __post_init__(self):
        """Validates the inputs, normalizes the frequencies, and sorts the vocabulary.

        After this runs, `self.vocabulary` is a list starting with `'UNK'` followed by the remaining elements
        in decreasing order of frequency (ties broken by decreasing element order), `self.obs_frequencies` is
        the aligned list of proportions, and `self.element_types` holds the types of the non-UNK elements.

        Raises:
            ValueError: See the class docstring.
        """
        if self.vocabulary is None or len(self.vocabulary) == 0:
            raise ValueError("Empty vocabularies are not supported.")
        if self.obs_frequencies is None:
            raise ValueError("self.obs_frequencies must be provided alongside self.vocabulary.")
        if len(self.vocabulary) != len(self.obs_frequencies):
            raise ValueError(
                "self.vocabulary and self.obs_frequencies must have the same length. Got "
                f"{len(self.vocabulary)} and {len(self.obs_frequencies)}."
            )

        vocab_set = set(self.vocabulary)
        if len(self.vocabulary) != len(vocab_set):
            raise ValueError(
                f"Vocabulary has duplicates. len(self.vocabulary) = {len(self.vocabulary)}, but "
                f"len(set(self.vocabulary)) = {len(vocab_set)}."
            )

        self.element_types = {type(v) for v in self.vocabulary if v != "UNK"}
        # Integers are reserved for positional lookups in __getitem__, so they cannot also be elements.
        if int in self.element_types:
            raise ValueError("Integer elements in the vocabulary are not supported.")

        # Convert raw counts (or unnormalized weights) into proportions.
        obs_frequencies = np.array(self.obs_frequencies)
        obs_frequencies = obs_frequencies / obs_frequencies.sum()

        # Work on a private list copy so the caller's sequence is untouched and tuples are accepted too.
        vocab = list(copy.deepcopy(self.vocabulary))

        # UNK is pulled out before sorting because it always occupies index 0, whatever its frequency.
        if "UNK" in vocab_set:
            unk_index = vocab.index("UNK")
            unk_freq = obs_frequencies[unk_index]
            obs_frequencies = np.delete(obs_frequencies, unk_index)
            del vocab[unk_index]
        else:
            unk_freq = 0

        # lexsort treats its last key as the primary one: order by frequency, then by element; reversing
        # the result gives decreasing order.
        idx = np.lexsort((vocab, obs_frequencies))[::-1]

        self.vocabulary = ["UNK"] + [vocab[i] for i in idx]
        # .tolist() yields plain Python floats rather than numpy scalars.
        self.obs_frequencies = np.concatenate(([unk_freq], obs_frequencies[idx])).tolist()

    def filter(
        self, total_observations: int | None, min_valid_element_freq: COUNT_OR_PROPORTION | None
    ) -> None:
        """Filters the vocabulary elements to only those occurring sufficiently often.

        Filters out infrequent elements from the vocabulary, pushing the dropped elements into the UNK
        element. The cutoff frequency can be specified either as an integral count or as a floating point
        proportion. If specified as a count, it will be converted to a proportion via `total_observations`, as
        the internal observed frequency list is stored in terms of frequencies, not counts. Even if UNK occurs
        in the original vocabulary with frequency below this cut off, it will be retained as it is the
        destination element for filtered elements, and its output frequency will be updated accordingly.

        The vocabulary is modified in place and nothing is returned. If `min_valid_element_freq` is `None`,
        the vocabulary is left untouched.

        Args:
            total_observations: How many total observations were there of vocabulary elements. Only needed
                when `min_valid_element_freq` is given as a count.
            min_valid_element_freq: How frequently must an element have been observed to be retained?

        Raises:
            ValueError: If `min_valid_element_freq` is not a positive integer or a floating point number
                between 0 and 1, or if it is a count and `total_observations` is `None`.

        Example:
            >>> vocab = Vocabulary(vocabulary=['apple', 'banana', 'UNK'], obs_frequencies=[5, 3, 2])
            >>> vocab.filter(total_observations=10, min_valid_element_freq=0.4)
            >>> vocab.vocabulary
            ['UNK', 'apple']
            >>> vocab.obs_frequencies
            [0.5, 0.5]
            >>> vocab = Vocabulary(vocabulary=['apple', 'banana', 'UNK'], obs_frequencies=[5, 3, 2])
            >>> vocab.filter(total_observations=10, min_valid_element_freq=4)
            >>> vocab.vocabulary
            ['UNK', 'apple']
            >>> vocab.obs_frequencies
            [0.5, 0.5]
            >>> vocab = Vocabulary(vocabulary=['apple', 'banana', 'UNK'], obs_frequencies=[5, 3, 2])
            >>> vocab.filter(total_observations=10, min_valid_element_freq=None)
            >>> vocab.vocabulary
            ['UNK', 'apple', 'banana']
            >>> vocab.filter(total_observations=10, min_valid_element_freq=1.02)
            Traceback (most recent call last):
                ...
            ValueError: Can only filter vocabularies by floats in (0, 1) or ints > 1; got <class 'float'> 1.02
            >>> vocab.filter(total_observations=10, min_valid_element_freq="0.02")
            Traceback (most recent call last):
                ...
            ValueError: Can only filter vocabularies by floats in (0, 1) or ints > 1; got <class 'str'> 0.02
            >>> vocab.filter(total_observations=10, min_valid_element_freq=0)
            Traceback (most recent call last):
                ...
            ValueError: Can only filter vocabularies by floats in (0, 1) or ints > 1; got <class 'int'> 0
        """
        if min_valid_element_freq is None:
            return

        invalid_cutoff_msg = (
            "Can only filter vocabularies by floats in (0, 1) or ints > 1; got "
            f"{type(min_valid_element_freq)} {min_valid_element_freq}"
        )

        # Only the comparisons can raise TypeError (e.g., for a string cutoff), so only they are guarded.
        try:
            is_proportion = 0 < min_valid_element_freq < 1
            is_count = (
                (not is_proportion)
                and min_valid_element_freq >= 1
                and min_valid_element_freq == round(min_valid_element_freq)
            )
        except TypeError as e:
            raise ValueError(invalid_cutoff_msg) from e

        if is_proportion:
            min_freq = min_valid_element_freq
        elif is_count:
            if total_observations is None:
                raise ValueError(
                    "total_observations must be provided to filter by a count; got "
                    f"min_valid_element_freq={min_valid_element_freq} with total_observations=None."
                )
            min_freq = min_valid_element_freq / total_observations
        else:
            raise ValueError(invalid_cutoff_msg)

        # np.searchsorted(a, v, side='right') returns i such that
        #   a[i-1] <= v < a[i]
        # So, np.searchsorted(-self.obs_frequencies[1:], -min_freq, side='right') returns i s.t.
        #   -self.obs_frequencies[i+1-1] <= -min_freq < -self.obs_frequencies[i+1]
        #   <=>
        #   self.obs_frequencies[i] >= min_freq > self.obs_frequencies[i+1]
        # which is precisely the index i such that self.obs_frequencies[:i+1] are >= min_freq
        # and self.obs_frequencies[i+1:] are < min_freq
        freqs = np.array(self.obs_frequencies)
        idx = np.searchsorted(-freqs[1:], -min_freq, side="right")

        # Now, we need to filter the vocabulary elements, but also put anything dropped in the UNK bucket.
        freqs[0] += freqs[idx + 1 :].sum()

        self.vocabulary = self.vocabulary[: idx + 1]
        # .tolist() yields plain Python floats rather than numpy scalars.
        self.obs_frequencies = freqs[: idx + 1].tolist()

        # Drop any cached element -> index mapping so it is rebuilt lazily against the filtered vocabulary.
        # (hasattr() would compute the cached property just to throw it away.)
        self.__dict__.pop("idxmap", None)

    def describe(
        self,
        line_width: int = 60,
        wrap_lines: bool = True,
        n_head: int = 3,
        n_tail: int = 2,
        stream: TextIOBase | None = None,
    ) -> int | None:
        """Prints or outputs to a stream a text-based visual representation of the vocabulary.

        This both lists the head and tail of the vocabulary but also produces a sparklines representation of
        the relative frequency distribution of vocabulary elements observed. In the printed head and tail
        elements, UNK is skipped. If more elements are in the vocabulary than the printed elements, ellipsis
        will denote the skipped elements. If `line_width` leaves no room for a sparkline after the
        "Frequencies:" prefix, the sparkline is omitted.

        Args:
            line_width: The maximum width of each line in the description.
            wrap_lines: Whether to wrap lines that exceed the `line_width`. If False, over-long lines are
                truncated instead.
            n_head: The number of high-frequency elements to include in the description.
            n_tail: The number of low-frequency elements to include in the description.
            stream: The stream to write the description to. If `None`, the description is printed to stdout.

        Returns:
            The number of characters written to the stream if a stream was provided, otherwise `None`.

        Example:
            >>> vocab = Vocabulary(
            ...     vocabulary=['apple', 'banana', 'pear', 'UNK'],
            ...     obs_frequencies=[3, 4, 1, 2],
            ... )
            >>> vocab.describe(n_head=2, n_tail=1, wrap_lines=False)
            4 elements, 20.0% UNKs
            Frequencies: █▆▁
            Elements:
             (40.0%) banana
             (30.0%) apple
             (10.0%) pear
            >>> vocab.describe(n_head=1, n_tail=0, wrap_lines=False)
            4 elements, 20.0% UNKs
            Frequencies: █▆▁
            Examples:
             (40.0%) banana
             ...
            >>> vocab.describe(n_head=1, n_tail=0, wrap_lines=False, line_width=10)
            4 [...]
            [...]
            Examples:
             [...]
             ...
            >>> vocab.describe(n_head=1, n_tail=0, wrap_lines=True, line_width=10)
            4
            elements,
            20.0% UNKs
            Frequencie
            s:
            Examples:
             (40.0%)
             banana
             ...
        """
        lines = [f"{len(self)} elements, {self.obs_frequencies[0]*100:.1f}% UNKs"]

        sparkline_prefix = "Frequencies:"

        # W is the number of characters left for the sparkline itself. It can be zero or negative for small
        # line widths, in which case there is nothing to draw (and dividing by it would be an error).
        W = line_width - len(sparkline_prefix) - 2
        if W <= 0:
            freqs = []
        elif W > len(self):
            freqs = self.obs_frequencies[1:]
        else:
            # Subsample the non-UNK frequencies with a fixed stride so the sparkline fits within W.
            freqs = self.obs_frequencies[1 : len(self) : math.ceil(len(self) / W)]

        sparkline = sparklines(freqs)[0] if freqs else ""
        lines.append(f"{sparkline_prefix} {sparkline}")

        if len(self) - 1 <= (n_head + n_tail):
            # Everything fits: list every non-UNK element.
            lines.append("Elements:")
            for v, f in zip(self.vocabulary[1:], self.obs_frequencies[1:]):
                lines.append(f" ({f*100:.1f}%) {v}")
        else:
            lines.append("Examples:")
            # Index loops (rather than slices) so n_head == 0 / n_tail == 0 naturally produce no entries.
            for i in range(n_head):
                lines.append(f" ({self.obs_frequencies[i+1]*100:.1f}%) {self.vocabulary[i+1]}")
            lines.append(" ...")
            for i in range(n_tail):
                lines.append(f" ({self.obs_frequencies[-n_tail+i]*100:.1f}%) {self.vocabulary[-n_tail+i]}")

        # Remember each line's indentation so it survives wrapping / truncation below.
        line_indents = [num_initial_spaces(line) for line in lines]
        if wrap_lines:
            new_lines = []
            for line, ind in zip(lines, line_indents):
                new_lines.extend(
                    wrap(line, width=line_width, initial_indent="", subsequent_indent=(" " * ind))
                )
            lines = new_lines
        else:
            # shorten() collapses whitespace, so the indent is re-applied explicitly.
            lines = [
                shorten(line, width=line_width, initial_indent=(" " * ind))
                for line, ind in zip(lines, line_indents)
            ]

        desc = "\n".join(lines)
        if stream is None:
            print(desc)
            return
        return stream.write(desc)