EventStream.transformer.structured_attention module
This module contains the class for running “structured attention” that respects a dependency graph.
- class EventStream.transformer.structured_attention.StructuredAttention(seq_module: Module, dep_graph_module: Module)
Bases: Module
A module for performing dependency-graph structured attention calculations.
This module is a container that rearranges input tensors and passes them to the nested modules for pooling events, processing event sequences, and processing intra-event dependency graph objects.
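- Parameters:
- seq_module: Module
The nested module that processes the sequence of event embeddings.
- dep_graph_module: Module
The nested module that processes each event's intra-event dependency graph.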
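As a rough illustration of the kind of tensor rearrangement this container performs, the sketch below merges the batch and event axes so that each event's dependency graph becomes an independent sequence, then restores the original nesting. It is framework-free: nested Python lists stand in for a `(batch, seq_len, dep_graph_len, hidden)` tensor, the helpers `flatten_events` and `unflatten_events` are hypothetical, and the assumption that `StructuredAttention` uses exactly this layout is not confirmed by this page.

```python
from typing import List

# Hypothetical aliases: nested lists standing in for tensors.
Vector = List[float]              # [hidden]
Graph = List[Vector]              # [dep_graph_len][hidden]
Batch4D = List[List[Graph]]       # [batch][seq_len][dep_graph_len][hidden]


def flatten_events(hidden_states: Batch4D) -> List[Graph]:
    """(batch, seq_len, G, H) -> (batch * seq_len, G, H): one graph per event."""
    return [graph for subject in hidden_states for graph in subject]


def unflatten_events(flat: List[Graph], seq_len: int) -> Batch4D:
    """Inverse of flatten_events: regroup consecutive graphs by subject."""
    if seq_len <= 0 or len(flat) % seq_len != 0:
        raise ValueError("len(flat) must be a positive multiple of seq_len")
    return [flat[i : i + seq_len] for i in range(0, len(flat), seq_len)]


# Toy input: batch=2, seq_len=3, dep_graph_len=2, hidden=1.
x = [
    [[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]],
    [[[7.0], [8.0]], [[9.0], [10.0]], [[11.0], [12.0]]],
]
flat = flatten_events(x)
print(len(flat))                       # 6 independent graph sequences
print(unflatten_events(flat, 3) == x)  # True
```

In a tensor library this is a single reshape; merging the axes lets a dependency graph module treat every event as one batch element.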
forward(hidden_states: Tensor, seq_attention_mask: Tensor | None = None, event_mask: Tensor | None = None, seq_module_kwargs: dict[str, Any] | None = None, dep_graph_module_kwargs: dict[str, Any] | None = None, prepend_graph_with_history_embeddings: bool = True, update_last_graph_el_to_history_embedding: bool = True) -> tuple[Tensor, dict[str, dict[str, Tensor | None] | None]]
Performs a structured attention forward pass.
This method performs three steps: it summarizes the input events into per-event embeddings, contextualizes these events over the sequence history, and produces output embeddings by processing the historical context together with the dependency graph structure.
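The first two steps can be sketched without any framework. Below, each event is pooled by keeping the last dependency graph element (the whole-event embedding, as the `hidden_states` description states), a toy running mean stands in for a causal `seq_module`, and the contextualized outputs are shifted by one event so that event i only sees earlier events. The names `pool_events`, `causal_mean`, and `history_embeddings` are hypothetical, and the one-event shift is an assumption about how the history is aligned, not a description of the library's code.

```python
from typing import List

Vector = List[float]


def pool_events(hidden_states: List[List[List[Vector]]]) -> List[List[Vector]]:
    """(batch, seq_len, G, H) -> (batch, seq_len, H): keep each event's last
    dependency graph element, i.e. its whole-event embedding."""
    return [[graph[-1] for graph in subject] for subject in hidden_states]


def causal_mean(events: List[Vector]) -> List[Vector]:
    """Toy stand-in for a causal seq_module: output i averages events 0..i."""
    if not events:
        return []
    out: List[Vector] = []
    running = [0.0] * len(events[0])
    for i, event in enumerate(events, start=1):
        running = [r + e for r, e in zip(running, event)]
        out.append([r / i for r in running])
    return out


def history_embeddings(contextualized: List[Vector]) -> List[Vector]:
    """Shift right by one event, with zeros for the first event, so the
    history for event i summarizes events 0..i-1 only (assumed alignment)."""
    if not contextualized:
        return []
    zeros = [0.0] * len(contextualized[0])
    return [zeros] + contextualized[:-1]


# One subject with three events, each a two-element dependency graph (H = 1).
x = [[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]]]
pooled = pool_events(x)                # [[[2.0], [4.0], [6.0]]]
context = causal_mean(pooled[0])       # [[2.0], [3.0], [4.0]]
history = history_embeddings(context)  # [[0.0], [2.0], [3.0]]
```

A real `seq_module` would be a learned, masked sequence model; the running mean only demonstrates the causal constraint that each output depends on the current and earlier events.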
- Parameters:
- hidden_states: Tensor
The input embeddings corresponding to the different elements of the structured dependency graph, with the last element of the graph corresponding to a whole-event embedding.
- seq_attention_mask: Tensor | None = None
Attention mask for the sequence module, used to avoid attending to padding token indices.
- event_mask: Tensor | None = None
Mask indicating which events are present, used to avoid processing padding events.
- seq_module_kwargs: dict[str, Any] | None = None
Additional keyword arguments to pass to the sequence module.
- dep_graph_module_kwargs: dict[str, Any] | None = None
Additional keyword arguments to pass to the dependency graph module.
- prepend_graph_with_history_embeddings: bool = True
If True, the history embeddings are prepended to the dependency graph sequence. Defaults to True. This is set to False during generation when caching is enabled, because the prepended portion is already contained in the cached past history.
- update_last_graph_el_to_history_embedding: bool = True
If True, the last element of the dependency graph sequence is updated with the history embedding. Defaults to True. This is set to False during generation, when the last element of the dependency graph sequence may not be the final element of the overall chain (that is, when an internal element is being generated).
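The effect of the two flags on a single event's input to the dependency graph module can be sketched as follows. This is a framework-free illustration: `build_dep_graph_sequence` and its `prepended_history` and `updated_last` arguments are hypothetical, and which embedding the real module places in each slot is left to the caller rather than assumed.

```python
from typing import List, Optional

Vector = List[float]


def build_dep_graph_sequence(
    graph: List[Vector],
    prepended_history: Optional[Vector],
    updated_last: Optional[Vector],
    prepend_graph_with_history_embeddings: bool = True,
    update_last_graph_el_to_history_embedding: bool = True,
) -> List[Vector]:
    """Assemble one event's input to the dependency graph module.

    `graph` is [g_0, ..., g_{G-1}], where g_{G-1} is the whole-event embedding.
    """
    seq = list(graph)  # shallow copy so the caller's list is not mutated
    if update_last_graph_el_to_history_embedding:
        if updated_last is None:
            raise ValueError("updated_last is required when updating the last element")
        seq[-1] = updated_last
    if prepend_graph_with_history_embeddings:
        if prepended_history is None:
            raise ValueError("prepended_history is required when prepending")
        seq = [prepended_history] + seq
    return seq


graph = [[1.0], [2.0], [9.0]]  # last element = whole-event embedding

# Default behavior: both flags True.
print(build_dep_graph_sequence(graph, [0.5], [7.0]))
# [[0.5], [1.0], [2.0], [7.0]]

# Cached generation of an internal element: both flags False.
print(build_dep_graph_sequence(graph, None, None, False, False))
# [[1.0], [2.0], [9.0]]
```

With both flags off, the graph passes through unchanged, which matches the generation case described above where the history is already held in the cache.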
- Returns:
A tuple of the dependency graph output and a nested dictionary of additional return values from both the sequence and dependency graph modules.
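The second element of the returned tuple is nested and may contain `None` at either level, per the return annotation in the signature. A caller might flatten it defensively as in the sketch below; the helper `present_extras` and the keys used in the example (`"seq_module"`, `"dep_graph_module"`, `"cache"`, `"attentions"`) are hypothetical placeholders, not the keys the library actually returns, and strings stand in for tensors.

```python
from typing import Any, Dict, Optional, Tuple

# Mirrors dict[str, dict[str, Tensor | None] | None], with Any in place of Tensor.
ExtraReturns = Dict[str, Optional[Dict[str, Optional[Any]]]]


def present_extras(extra: ExtraReturns) -> Dict[Tuple[str, str], Any]:
    """Collect (module_name, key) -> value, skipping None at both levels."""
    found: Dict[Tuple[str, str], Any] = {}
    for module_name, returns in extra.items():
        if returns is None:  # a module may report no extra returns at all
            continue
        for key, value in returns.items():
            if value is not None:
                found[(module_name, key)] = value
    return found


# Hypothetical keys; strings stand in for tensors.
extra = {
    "seq_module": {"cache": "seq-cache", "attentions": None},
    "dep_graph_module": None,
}
print(present_extras(extra))  # {('seq_module', 'cache'): 'seq-cache'}
```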