EventStream.transformer.structured_attention module

This module contains the class for running “structured attention” that respects a dependency graph.

class EventStream.transformer.structured_attention.StructuredAttention(seq_module: Module, dep_graph_module: Module)

Bases: Module

A module for performing dependency-graph structured attention calculations.

This module is a container for shuffling input tensors to pass them to the nested modules for pooling events, processing event sequences, and processing intra-event dependency graph objects.

Parameters:
seq_module: Module

The module responsible for processing sequences.

dep_graph_module: Module

The module responsible for processing dependency graphs.

forward(hidden_states: Tensor, seq_attention_mask: Tensor | None = None, event_mask: Tensor | None = None, seq_module_kwargs: dict[str, Any] | None = None, dep_graph_module_kwargs: dict[str, Any] | None = None, prepend_graph_with_history_embeddings: bool = True, update_last_graph_el_to_history_embedding: bool = True) → tuple[Tensor, dict[str, dict[str, Tensor | None] | None]]

Performs a structured attention forward pass.

This method implements several steps, which include summarizing input events into event embeddings, contextualizing these events via the history, and producing output embeddings by processing the historical context and dependency graph structure.

Parameters:
hidden_states: Tensor

The input embeddings corresponding to the different elements of the structured dependency graph, with the last element of the graph corresponding to a whole-event embedding.
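
To make this layout concrete, the sketch below assumes `hidden_states` has shape `(batch_size, seq_len, dep_graph_len, hidden_size)`. That shape and the variable names are assumptions made for illustration; they are not stated in the documentation above. Under that assumption, the whole-event embedding is the slice at the last dependency graph position.

```python
import torch

# Assumed layout (not stated in the docs above):
# (batch_size, seq_len, dep_graph_len, hidden_size)
batch_size, seq_len, dep_graph_len, hidden_size = 2, 5, 3, 4
hidden_states = torch.randn(batch_size, seq_len, dep_graph_len, hidden_size)

# The last dependency graph element holds the whole-event embedding,
# so slicing it out yields one vector per event.
whole_event_embeddings = hidden_states[:, :, -1, :]

print(whole_event_embeddings.shape)  # torch.Size([2, 5, 4])
```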

seq_attention_mask: Tensor | None = None

Mask to avoid processing on padding token indices.

event_mask: Tensor | None = None

Mask to avoid processing on padding token indices.

seq_module_kwargs: dict[str, Any] | None = None

Additional keyword arguments to pass to the sequence module.

dep_graph_module_kwargs: dict[str, Any] | None = None

Additional keyword arguments to pass to the dependency graph module.
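
Both keyword-argument dictionaries default to `None`. A common way to forward such optional dictionaries to a nested callable is sketched below; `call_with_optional_kwargs` and `toy_module` are hypothetical names used only for this illustration and are not part of EventStream.

```python
from __future__ import annotations

from typing import Any


def call_with_optional_kwargs(module, inputs, kwargs: dict[str, Any] | None = None):
    """Hypothetical helper: treat None as 'no extra arguments'."""
    if kwargs is None:
        kwargs = {}
    return module(inputs, **kwargs)


def toy_module(x, scale=1):
    """Stand-in for a nested module that accepts an extra keyword argument."""
    return x * scale


default_result = call_with_optional_kwargs(toy_module, 3)               # uses scale=1
scaled_result = call_with_optional_kwargs(toy_module, 3, {"scale": 2})  # forwards scale=2
```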

prepend_graph_with_history_embeddings: bool = True

If true, the history embeddings will be prepended to the dependency graph sequence. Default is True. This is set to false during generation if caching is enabled, as the prepended portion is contained in the cached past history.
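
The prepending step can be pictured as a concatenation along the dependency graph axis. The sketch below is a minimal illustration, not the library's implementation: it assumes a `(batch_size, seq_len, dep_graph_len, hidden_size)` layout, uses a random tensor as a stand-in for the history embeddings, and `build_dep_graph_input` is a hypothetical helper.

```python
import torch

# Assumed layout: (batch_size, seq_len, dep_graph_len, hidden_size)
batch_size, seq_len, dep_graph_len, hidden_size = 2, 5, 3, 4
hidden_states = torch.randn(batch_size, seq_len, dep_graph_len, hidden_size)

# Stand-in for the per-event history embeddings (one vector per event).
history_embeddings = torch.randn(batch_size, seq_len, hidden_size)


def build_dep_graph_input(hidden_states, history_embeddings, prepend=True):
    """Hypothetical helper: optionally prepend history to each event's graph."""
    if not prepend:
        return hidden_states
    # Add a length-1 graph axis so the history vector becomes position 0.
    return torch.cat((history_embeddings.unsqueeze(2), hidden_states), dim=2)


with_history = build_dep_graph_input(hidden_states, history_embeddings)
without_history = build_dep_graph_input(hidden_states, history_embeddings, prepend=False)

print(with_history.shape)     # torch.Size([2, 5, 4, 4])
print(without_history.shape)  # torch.Size([2, 5, 3, 4])
```

With prepending enabled, each event's graph sequence grows by one position; with it disabled, the input passes through unchanged.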

update_last_graph_el_to_history_embedding: bool = True

If true, the last element of the dependency graph sequence will be updated with the history embedding. Default is True. This is set to false during generation, when the last element of the dependency graph sequence may not be the final element of the overall chain (if we are generating an internal element).
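
As a rough illustration of this update, the sketch below overwrites the last dependency graph position of every event with a per-event history vector. It assumes a `(batch_size, seq_len, dep_graph_len, hidden_size)` layout and uses a random stand-in for the history embeddings; `update_last_element` is a hypothetical helper, not the library's code.

```python
import torch

# Assumed layout: (batch_size, seq_len, dep_graph_len, hidden_size)
batch_size, seq_len, dep_graph_len, hidden_size = 2, 5, 3, 4
dep_graph_seq = torch.randn(batch_size, seq_len, dep_graph_len, hidden_size)

# Stand-in for the per-event history embeddings (one vector per event).
history_embeddings = torch.randn(batch_size, seq_len, hidden_size)


def update_last_element(dep_graph_seq, history_embeddings, update=True):
    """Hypothetical helper: replace the last graph position with the history embedding."""
    if not update:
        return dep_graph_seq
    updated = dep_graph_seq.clone()  # avoid mutating the caller's tensor
    updated[:, :, -1, :] = history_embeddings
    return updated


updated = update_last_element(dep_graph_seq, history_embeddings)
untouched = update_last_element(dep_graph_seq, history_embeddings, update=False)
```

Only the final graph position changes; earlier positions keep their original values.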

Returns:

A tuple containing the dependency graph output and a dictionary of additional return values from the sequence and dependency graph modules.