EventStream.transformer.lightning_modules.zero_shot_evaluator module

class EventStream.transformer.lightning_modules.zero_shot_evaluator.ESTForZeroShotClassificationLM(config: StructuredTransformerConfig | dict[str, Any], pretrained_weights_fp: Path, labeling_function: Labeler, max_new_events: int = 10)[source]

Bases: LightningModule

A PyTorch Lightning Module for zero-shot classification via generation with an EST model.
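
Classification "via generation" can plausibly be read as follows: sample several continuations of up to max_new_events events, apply the labeling function to each, and average the resulting labels. The sketch below shows that idea in plain Python. Everything in it is an assumption for illustration: zero_shot_probability, toy_labeler, LabelFn and the event codes are hypothetical and are not EventStream's Labeler interface. The "unpredictable" flag is an assumed analogue of the unpredictable mask passed to log_metrics.

```python
from typing import Callable, Optional, Sequence

# HYPOTHETICAL labeler type: maps one generated continuation (a sequence of
# event codes) to 1/0, or None when the continuation cannot be labeled.
# The real EventStream Labeler interface may differ.
LabelFn = Callable[[Sequence[str]], Optional[int]]


def zero_shot_probability(
    continuations: Sequence[Sequence[str]],
    labeler: LabelFn,
    max_new_events: int = 10,
) -> tuple[float, bool]:
    """Estimate P(label == 1) by labeling sampled continuations.

    Returns (probability, unpredictable); `unpredictable` is True when none
    of the sampled continuations could be labeled.
    """
    # Only the first `max_new_events` generated events are considered.
    labels = [labeler(c[:max_new_events]) for c in continuations]
    labeled = [y for y in labels if y is not None]
    if not labeled:
        return 0.0, True
    # Empirical fraction of labelable samples that came out positive.
    return sum(labeled) / len(labeled), False


def toy_labeler(events: Sequence[str]) -> Optional[int]:
    """Toy rule: 1 if DEATH occurs before DISCHARGE, 0 if DISCHARGE first."""
    for event in events:
        if event == "DEATH":
            return 1
        if event == "DISCHARGE":
            return 0
    return None  # neither outcome observed: this sample cannot be labeled


samples = [["LAB", "DEATH"], ["DISCHARGE"], ["LAB", "LAB"], ["MED", "DISCHARGE"]]
prob, unpredictable = zero_shot_probability(samples, toy_labeler)
# prob == 1/3: one DEATH among the three labelable samples; unpredictable is False
```

Averaging over several sampled trajectories turns a generative model into a probabilistic classifier without any fine-tuning, which is what makes the evaluation "zero-shot".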

build_metrics()[source]

Build the various torchmetrics we’ll use to track performance.

get_generative_predictions(batch: PytorchBatch) StreamClassificationModelOutput[source]

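Produces zero-shot classification predictions for batch via generation, first capturing num_samples, the number of samples to generate.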

log_metrics(results: StreamClassificationModelOutput, unpredictable: BoolTensor, skip_metrics: Sequence[str], prefix: str)[source]

Logs metric results for a given model output.

Parameters:
results: StreamClassificationModelOutput

The results to assess across the suite of metrics.

skip_metrics: Sequence[str]

A list of metrics to skip. Entries are partial names rather than full metric names: any metric whose name contains an element of skip_metrics is skipped. For example, if skip_metrics = ['AUROC', 'AUPRC'], then metrics named 'macro_AUROC' or 'micro_AUPRC' are skipped, whereas a metric named 'weighted_accuracy' is not.

prefix: str

The prefix used when logging metric results, for example 'train', 'tuning', or 'held_out'.
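
The substring rule for skip_metrics and the use of prefix can be sketched with a plain dictionary of metric values. This is an illustration under assumptions, not the module's code: select_metrics is a hypothetical helper, the results here are a toy dict rather than a StreamClassificationModelOutput, and the "<prefix>_<name>" key format is assumed.

```python
from typing import Mapping, Sequence


def select_metrics(
    results: Mapping[str, float],
    skip_metrics: Sequence[str],
    prefix: str,
) -> dict[str, float]:
    """Keep metrics whose names contain no skip_metrics entry, then prefix them."""
    return {
        # ASSUMED key format "<prefix>_<name>"; the real module may differ.
        f"{prefix}_{name}": value
        for name, value in results.items()
        # Substring match: "AUROC" skips "macro_AUROC", "weighted_AUROC", ...
        if not any(skip in name for skip in skip_metrics)
    }


# Toy values for illustration only.
results = {"macro_AUROC": 0.81, "micro_AUPRC": 0.42, "weighted_accuracy": 0.77}
print(select_metrics(results, ["AUROC", "AUPRC"], "train"))
# {'train_weighted_accuracy': 0.77}
print(select_metrics(results, [], "held_out"))  # empty skip list: nothing skipped
```

An empty skip_metrics keeps every metric, which matches the behavior the validation_step and test_step docstrings describe ("does not skip metrics").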

test_step(batch, batch_idx)[source]

Test step.

Differs from training only in that it does not skip metrics.

validation_step(batch, batch_idx)[source]

Validation step.

Differs from training only in that it does not skip metrics.

EventStream.transformer.lightning_modules.zero_shot_evaluator.import_class_from_file(module_path, class_name)[source]
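
Only the signature of import_class_from_file is documented. Given the labeling_function: Labeler parameter above, one plausible use is loading a user-supplied class (such as a labeler) from a Python source file. The sketch below shows that pattern with the standard-library importlib, under that assumption. load_class_from_file is a hypothetical name, and this is not the module's actual implementation.

```python
import importlib.util
import sys
from pathlib import Path


def load_class_from_file(module_path, class_name):
    """Load `class_name` from the Python source file at `module_path`."""
    module_path = Path(module_path)
    # Build an import spec directly from the file path; the file stem becomes the module name.
    spec = importlib.util.spec_from_file_location(module_path.stem, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the importlib docs recipe does, so code
    # that looks the module up in sys.modules (e.g. dataclasses) works.
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    # Raises AttributeError if the file does not define `class_name`.
    return getattr(module, class_name)
```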
EventStream.transformer.lightning_modules.zero_shot_evaluator.zero_shot_evaluation(cfg: FinetuneConfig)[source]