EventStream.transformer.nested_attention_model module
The nested attention core event stream GPT model.
- class EventStream.transformer.nested_attention_model.NAPPTForGenerativeSequenceModeling(config: StructuredTransformerConfig)
Bases: StructuredGenerationMixin, StructuredTransformerPreTrainedModel
The end-to-end model for nested attention generative sequence modelling.
This model is a subclass of StructuredTransformerPreTrainedModel and is designed for generative pre-training over “event-stream” data, with inputs in the form of PytorchBatch objects. It is trained to solve the generative, multivariate, masked temporal point process problem over the measurements defined in the input data. It does so while respecting the intra-event causal dependencies given by measurements_per_dep_graph_level in the config (also called the dependency graph).
This model mostly just passes the input data through a NestedAttentionPointProcessTransformer followed by a NestedAttentionGenerativeOutputLayer.
- Parameters:
- config: StructuredTransformerConfig
The overall model configuration.
- Raises:
ValueError – If the model configuration does not indicate nested attention mode.
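To make the intra-event causal ordering concrete, here is a minimal sketch. It assumes that a measurement at dependency graph level j may condition only on measurements in levels 0 through j-1 of the same event (dependence on prior events is not modelled here). The function conditioning_sets and the measurement names are hypothetical; only the config field name measurements_per_dep_graph_level comes from this page.

```python
def conditioning_sets(measurements_per_dep_graph_level: list[list[str]]) -> dict[str, set[str]]:
    """Hypothetical helper: map each measurement to the same-event measurements it may condition on.

    A measurement at level j sees every measurement from levels 0..j-1, but nothing
    from its own level or later levels (so same-level measurements are mutually blind).
    """
    allowed: dict[str, set[str]] = {}
    seen: set[str] = set()
    for level in measurements_per_dep_graph_level:
        for measurement in level:
            # Copy so later updates to `seen` do not leak into this level's set.
            allowed[measurement] = set(seen)
        seen.update(level)
    return allowed


# Hypothetical dependency graph with three levels.
levels = [["time_of_day"], ["event_type"], ["lab_name", "lab_value"]]
deps = conditioning_sets(levels)
```

In this sketch, lab_name and lab_value share a level, so neither may condition on the other; both may condition on time_of_day and event_type.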
- forward(batch: PytorchBatch, is_generation: bool = False, **kwargs) → GenerativeSequenceModelOutput
This runs the full forward pass of the model.
- Parameters:
- batch: PytorchBatch
The batch of data to be transformed.
- is_generation: bool = False
Whether or not the model is being used for generation.
- **kwargs
Additional keyword arguments, which are used for output structuring and are forwarded to the encoder. The model specifically looks for the use_cache, output_attentions, and output_hidden_states keyword arguments, which control whether additional properties are added to the output. It also looks for the dep_graph_el_generation_target keyword argument, which is passed to the output layer.
- Returns:
The output of the model, as a GenerativeSequenceModelOutput object.
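A minimal sketch of how the forward keyword arguments described above might be routed. It assumes plain callables in place of the real encoder and output layer, and plain dicts in place of GenerativeSequenceModelOutput. The function sketch_forward, the toy components, and the flag-to-field mapping (for example use_cache to past_key_values, borrowed from Hugging Face naming) are all hypothetical, not taken from this module.

```python
from typing import Any, Callable

# Hypothetical mapping from a forward() flag to the output field it enables.
OPTIONAL_OUTPUTS = {
    "use_cache": "past_key_values",
    "output_hidden_states": "hidden_states",
    "output_attentions": "attentions",
}


def sketch_forward(
    batch: Any,
    encoder: Callable[..., dict[str, Any]],
    output_layer: Callable[..., dict[str, Any]],
    is_generation: bool = False,
    **kwargs: Any,
) -> dict[str, Any]:
    # Every keyword argument is forwarded to the encoder.
    encoder_out = encoder(batch, **kwargs)
    # Only the generation target is routed to the output layer.
    output = output_layer(
        batch,
        encoder_out["last_hidden_state"],
        is_generation=is_generation,
        dep_graph_el_generation_target=kwargs.get("dep_graph_el_generation_target"),
    )
    # Attach optional encoder outputs only when the matching flag is truthy.
    for flag, field in OPTIONAL_OUTPUTS.items():
        if kwargs.get(flag, False):
            output[field] = encoder_out.get(field)
    return output


def toy_encoder(batch: list[int], **kwargs: Any) -> dict[str, Any]:
    # Hypothetical encoder: doubles each input and exposes fake optional outputs.
    return {
        "last_hidden_state": [x * 2 for x in batch],
        "hidden_states": ("h0", "h1"),
        "attentions": ("a0",),
        "past_key_values": ("p0",),
    }


def toy_output_layer(batch, encoded, is_generation, dep_graph_el_generation_target):
    # Hypothetical output layer: echoes what it received.
    return {"preds": encoded, "target": dep_graph_el_generation_target}


out = sketch_forward(
    [1, 2], toy_encoder, toy_output_layer, output_hidden_states=True, dep_graph_el_generation_target=2
)
```

Here only hidden_states is attached, because output_hidden_states is the only flag set.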
- prepare_inputs_for_generation(batch: PytorchBatch, past: dict[str, tuple] | None = None, **kwargs) → dict[str, Any]
Returns model keyword arguments that have been modified for generation purposes.
- Parameters:
- batch: PytorchBatch
The batch of data to be transformed.
- past: dict[str, tuple] | None = None
The past state of the model, if any. If specified, it must be a dictionary containing both a seq_past key (the past of the sequential attention module) and a dep_graph_past key (the past of the dependency graph attention module). These inner past encodings are tuples containing the past values over prior layers and heads.
- **kwargs
Additional keyword arguments. If use_cache is set to False in the kwargs, the past state is ignored. Otherwise, the past state is passed through the model to accelerate generation: if past is not None, the batch is trimmed to the last element in the sequence, and the sequential attention mask is pre-computed.
- Raises:
ValueError – If the past state is malformed or if there is a dep_graph_el_generation_target in the kwargs that is not None.
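The input rules for prepare_inputs_for_generation can be sketched as follows. This is a hypothetical stand-in, not the module's implementation: a plain list of events replaces the PytorchBatch, the layout of the returned dict is assumed, the default of use_cache=True is assumed, and the attention-mask pre-computation is omitted.

```python
from typing import Any, Optional

REQUIRED_PAST_KEYS = ("seq_past", "dep_graph_past")


def sketch_prepare_inputs(
    events: list[Any],
    past: Optional[dict[str, tuple]] = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Hypothetical sketch; `events` stands in for a PytorchBatch."""
    # A dependency graph generation target is not accepted at this stage.
    if kwargs.get("dep_graph_el_generation_target") is not None:
        raise ValueError("dep_graph_el_generation_target must be None here")
    # use_cache=False means the past state is ignored entirely.
    if not kwargs.get("use_cache", True):
        past = None
    if past is not None:
        # The past must carry both the sequential and dependency graph caches.
        if not isinstance(past, dict) or any(k not in past for k in REQUIRED_PAST_KEYS):
            raise ValueError(f"past must be a dict with keys {REQUIRED_PAST_KEYS}")
        # With a cache available, only the newest event needs to be encoded.
        events = events[-1:]
    return {**kwargs, "batch": events, "past": past}
```

For example, sketch_prepare_inputs([1, 2, 3], past={"seq_past": (), "dep_graph_past": ()}) keeps only the last event, while passing use_cache=False leaves all three events and drops the past.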
- class EventStream.transformer.nested_attention_model.NestedAttentionGenerativeOutputLayer(config: StructuredTransformerConfig)
Bases: GenerativeOutputLayerBase
The output layer for the nested attention event stream model.
- TODO(mmcdermott):
Allow for use of NLL-beta throughout? https://github.com/mmcdermott/EventStreamGPT/issues/26
- Parameters:
- config: StructuredTransformerConfig
The overall model configuration.
- Raises:
ValueError – If the model configuration does not indicate nested attention mode.
- forward(batch: PytorchBatch, encoded: FloatTensor, is_generation: bool = False, dep_graph_el_generation_target: int | None = None) → GenerativeSequenceModelOutput
Returns the overall model output for the input batch.
It takes the final hidden states from the encoder and runs them through various output layers to predict subsequent event timing and contents. It differs from a conditionally independent variant mainly in that it predicts dependency graph elements sequentially, relying on the prior graph elements at each stage.
- Parameters:
- batch: PytorchBatch
The batch of data to process.
- encoded: FloatTensor
The encoded representation of the input data. This is of shape (batch size, sequence length, dependency graph len, config.hidden_size). The last element of the dependency graph is always the whole-event embedding, and the first element of the dependency graph is always assumed to capture the time of the event.
- is_generation: bool = False
Whether or not we are in generation mode. If so, the output predictions are for the next event, for both time and event contents; if not, the event contents prediction is shifted back by one event to align with the labels.
- dep_graph_el_generation_target: int | None = None
If is_generation is True, this is the index of the dependency graph element being generated in this pass. If None, all elements of the dependency graph are generated at once (although, for a nested attention model, this is generally wrong, as dependency graph elements need to be generated in dependency graph order).
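A small sketch of the documented conventions for encoded and dep_graph_el_generation_target, using a Python list of per-element embeddings for a single event instead of a FloatTensor. The helper names split_encoded and dep_graph_indices are hypothetical, and the sketch deliberately does not reproduce the label-alignment shift used outside generation mode.

```python
from typing import Any, Optional, Sequence


def split_encoded(encoded_event: Sequence[Any]) -> tuple[Any, Any]:
    """Hypothetical: split one event's dependency graph embeddings.

    Index 0 is assumed to capture the event time; index -1 is the whole-event embedding.
    """
    if len(encoded_event) < 2:
        raise ValueError("expected at least a time element and a whole-event element")
    return encoded_event[0], encoded_event[-1]


def dep_graph_indices(dep_graph_len: int, target: Optional[int] = None) -> list[int]:
    """Hypothetical: which dependency graph elements a single output pass covers."""
    if target is None:
        # All elements at once; generally wrong for nested attention generation.
        return list(range(dep_graph_len))
    if not 0 <= target < dep_graph_len:
        raise ValueError(f"target {target} out of range for length {dep_graph_len}")
    return [target]


# One event's (hypothetical) dependency graph embeddings, as labels.
event = ["time_emb", "lvl1_emb", "lvl2_emb", "whole_event_emb"]
time_emb, whole_emb = split_encoded(event)
all_at_once = dep_graph_indices(len(event))           # [0, 1, 2, 3]
one_element = dep_graph_indices(len(event), target=2)  # [2]
```

Passing an explicit target per pass is what lets generation proceed element by element in dependency graph order, rather than all at once.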