EventStream.transformer.lightning_modules.zero_shot_evaluator module
class EventStream.transformer.lightning_modules.zero_shot_evaluator.ESTForZeroShotClassificationLM(config: StructuredTransformerConfig | dict[str, Any], pretrained_weights_fp: Path, labeling_function: Labeler, max_new_events: int = 10)
Bases: LightningModule
A PyTorch Lightning Module for zero-shot classification via generation with an EST model.
- get_generative_predictions(batch: PytorchBatch) → StreamClassificationModelOutput
Captures num_samples, the number of samples to generate.
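The class is described as doing zero-shot classification via generation, and get_generative_predictions refers to num_samples. The following is a minimal, framework-free sketch of that idea, under stated assumptions: each subject gets num_samples generated continuations, a labeling function maps each continuation to a class index or to None ("unpredictable"), and the class probability is the label frequency among the continuations that could be labeled. The names zero_shot_probs and death_labeler and the event codes are hypothetical. The real Labeler interface in EventStream is not documented here and likely operates on batched tensors rather than Python lists.

```python
from typing import Callable, Optional, Sequence

# Hypothetical stand-in for one generated continuation: a list of event codes.
Trajectory = Sequence[str]


def zero_shot_probs(
    samples: Sequence[Trajectory],
    labeler: Callable[[Trajectory], Optional[int]],
    num_classes: int,
) -> Optional[list[float]]:
    """Estimate class probabilities for one subject from its generated samples.

    Each sample is labeled independently; samples the labeler cannot resolve
    (it returns None) are ignored. Returns None, i.e. "unpredictable", when no
    sample could be labeled.
    """
    counts = [0] * num_classes
    n_labeled = 0
    for trajectory in samples:
        label = labeler(trajectory)
        if label is None:  # this continuation did not determine the outcome
            continue
        counts[label] += 1
        n_labeled += 1
    if n_labeled == 0:
        return None
    # Empirical frequency of each label among the resolvable samples.
    return [count / n_labeled for count in counts]


def death_labeler(trajectory: Trajectory) -> Optional[int]:
    """Toy labeler: 1 if DEATH is generated before DISCHARGE, 0 if DISCHARGE comes first."""
    for event in trajectory:
        if event == "DEATH":
            return 1
        if event == "DISCHARGE":
            return 0
    return None  # neither outcome appears in the generated window


# Four generated continuations (num_samples = 4) for a single subject.
samples = [
    ["LAB", "DISCHARGE"],
    ["LAB", "DEATH"],
    ["DISCHARGE"],
    ["LAB", "LAB"],
]
probs = zero_shot_probs(samples, death_labeler, num_classes=2)
print(probs)
```

Ignoring unresolved continuations, rather than counting them as negatives, is one design choice among several; a larger generation budget (such as max_new_events) would make unresolved continuations less frequent.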
- log_metrics(results: StreamClassificationModelOutput, unpredictable: BoolTensor, skip_metrics: Sequence[str], prefix: str)
Logs metric results for a given model output.
- Parameters:
- results: StreamClassificationModelOutput
The results to assess across the suite of metrics.
- skip_metrics: Sequence[str]
A list of metrics to skip. Entries are not full metric names but partial names: any metric whose name contains an element of skip_metrics will be skipped. For example, if skip_metrics = ['AUROC', 'AUPRC'], then a metric named 'macro_AUROC' or 'micro_AUPRC' would be skipped, whereas a metric named 'weighted_accuracy' would not.
- prefix: str
The prefix to use when logging metric results, for example ‘train’, ‘tuning’, or ‘held_out’.
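The skip_metrics rule is substring matching: a metric is skipped if any entry of skip_metrics occurs anywhere in its name. A small sketch of that rule, reproducing the example from the parameter description, is below. The helper names should_skip and metrics_to_log are hypothetical, and so is the "{prefix}_{name}" key format; the actual log keys produced by log_metrics are not documented here.

```python
from typing import Sequence


def should_skip(metric_name: str, skip_metrics: Sequence[str]) -> bool:
    """Return True if any entry of skip_metrics is a substring of metric_name."""
    return any(partial in metric_name for partial in skip_metrics)


def metrics_to_log(
    metrics: dict[str, float], skip_metrics: Sequence[str], prefix: str
) -> dict[str, float]:
    """Drop skipped metrics and prepend the split prefix (hypothetical key format)."""
    return {
        f"{prefix}_{name}": value
        for name, value in metrics.items()
        if not should_skip(name, skip_metrics)
    }


metrics = {"macro_AUROC": 0.81, "micro_AUPRC": 0.42, "weighted_accuracy": 0.77}
logged = metrics_to_log(metrics, skip_metrics=["AUROC", "AUPRC"], prefix="tuning")
print(logged)  # {'tuning_weighted_accuracy': 0.77}
```

Because matching is by substring, a short entry such as "AUC" would also skip any metric whose name merely contains those letters, so partial names should be chosen with that in mind.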
- EventStream.transformer.lightning_modules.zero_shot_evaluator.import_class_from_file(module_path, class_name)
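import_class_from_file has no docstring here. Judging only from its name and signature, it loads the class named class_name from the Python source file at module_path. A sketch of that behavior using the standard library's importlib.util is below. It is not the module's actual implementation: the name import_class_from_file_sketch is hypothetical, and the example file and class (my_labeler.py, MyLabeler) are made up for illustration.

```python
import importlib.util
import tempfile
from pathlib import Path


def import_class_from_file_sketch(module_path, class_name):
    """Load the Python file at module_path and return its attribute class_name."""
    module_path = Path(module_path)
    # Build an import spec directly from the file location.
    spec = importlib.util.spec_from_file_location(module_path.stem, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    return getattr(module, class_name)  # raises AttributeError if the class is missing


# Usage: write a throwaway module to a temporary directory and import a class from it.
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "my_labeler.py"
    src.write_text("class MyLabeler:\n    name = 'toy'\n")
    cls = import_class_from_file_sketch(src, "MyLabeler")

print(cls.__name__, cls.name)  # MyLabeler toy
```

Executing arbitrary files this way runs their top-level code, so the path should point only at trusted code.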
- EventStream.transformer.lightning_modules.zero_shot_evaluator.zero_shot_evaluation(cfg: FinetuneConfig)