# Source code for EventStream.transformer.generation.generation_stopping_criteria

# Sourced from
# https://github.com/huggingface/transformers/blob/v4.23.1/src/transformers/generation_stopping_criteria.py
# Then modified

from abc import ABC

from transformers.utils import add_start_docstrings

from ...data.types import PytorchBatch
from ..model_output import GenerativeSequenceModelPredictions

# Shared argument/return documentation for stopping-criteria `__call__` methods.
# It is passed to `add_start_docstrings` on every `__call__` defined in this module,
# so the call contract is written down in exactly one place.
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
    Args:
        `batch` (`PytorchBatch`): The input batch.
        `outputs` (`GenerativeSequenceModelPredictions`): The predicted outputs.
        kwargs:
            Additional stopping criteria specific kwargs.
    Return:
        `bool`. `False` indicates we should continue, `True` indicates we should stop.
"""


class StoppingCriteria(ABC):
    """Base interface for every stopping criterion usable during generation.

    Concrete criteria subclass this type and override `__call__`; calling the base
    implementation directly always raises `NotImplementedError`.
    """

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(
        self,
        batch: PytorchBatch,
        outputs: GenerativeSequenceModelPredictions,
        **kwargs,
    ) -> bool:
        # The base class carries no decision logic of its own.
        message = "StoppingCriteria needs to be subclassed"
        raise NotImplementedError(message)
class MaxLengthCriteria(StoppingCriteria):
    """Stop generation once the batch's sequence length reaches `max_length`.

    The check compares `batch.sequence_length` against `max_length` with `>=`, so
    generation stops as soon as the limit is reached. For decoder-only transformers
    the counted length includes the initial prompted events.

    Args:
        max_length (`int`):
            The maximum length that the output sequence can have in number of events.
    """

    def __init__(self, max_length: int):
        self.max_length = max_length

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(
        self,
        batch: PytorchBatch,
        outputs: GenerativeSequenceModelPredictions,
        **kwargs,
    ) -> bool:
        # Only the batch is consulted; `outputs` and `kwargs` are accepted for
        # interface compatibility with the base class.
        current_length = batch.sequence_length
        return current_length >= self.max_length
class StoppingCriteriaList(list):
    """A callable list of stopping criteria.

    Calling the list evaluates its members in order and reports that generation
    should stop as soon as any one of them returns a truthy value. An empty list
    never requests a stop.
    """

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(
        self,
        batch: PytorchBatch,
        outputs: GenerativeSequenceModelPredictions,
        **kwargs,
    ) -> bool:
        # Forward `kwargs` to every criterion: the shared call contract documents
        # them as criteria-specific arguments, so dropping them here would make
        # them unreachable. `any` short-circuits on the first criterion that fires.
        return any(criteria(batch, outputs, **kwargs) for criteria in self)