# Patch summary (review preamble; text before the first "diff --git" header is not part of any hunk).
#
# Purpose: introduce an optional per-sample "metric_keywords" list and make the benchmark
# runner hand it to metrics under the existing "dictionary" keyword argument.
#
# src/openbench/dataset/dataset_transcription.py
#   * TranscriptionExtraInfo and TranscriptionRow each gain a `metric_keywords: list[str]` field
#     (NotRequired in TranscriptionRow).
#   * TranscriptionSample gains a `metric_keywords` property returning
#     `self.extra_info.get("metric_keywords")`, i.e. None when the key is absent.
#   * prepare_sample reads `row.get("metric-keywords") or row.get("metric_keywords")` and stores
#     the result in extra_info["metric_keywords"] only when it is truthy.
#     NOTE(review): an empty list is therefore dropped, whereas "dictionary" is copied whenever
#     the key is present in the row -- confirm this asymmetry is intended.
#     NOTE(review): the hyphenated "metric-keywords" key is not declared in TranscriptionRow;
#     presumably it matches a column name in some datasets -- verify against the data source.
#
# src/openbench/runner/benchmark.py
#   * _process_single_sample: kwargs is now a shallow copy `dict(sample.extra_info)` instead of
#     the extra_info object itself; when "metric_keywords" is present it is popped from the copy
#     and assigned to kwargs["dictionary"], replacing any "dictionary" entry in the copy.
#   * _run_pipeline_on_dataset_parallel: kwargs starts empty; "uem" is forwarded when present,
#     and kwargs["dictionary"] is taken from "metric_keywords" when present, otherwise from
#     "dictionary" when present.
#     NOTE(review): this path forwards only "uem" and "dictionary", while the
#     _process_single_sample hunk copies every extra_info key into kwargs -- the rest of both
#     functions is outside the hunks, so confirm the two paths are meant to differ.
#
# NOTE(review): the patch text below has lost its original line breaks and indentation (hunk
# lines are joined by single spaces), so it presumably cannot be applied with `git apply` or
# `patch` as-is -- regenerate it from the commit (index d89ae2f / c05050d) before applying.
diff --git a/src/openbench/dataset/dataset_transcription.py b/src/openbench/dataset/dataset_transcription.py index 706ef23..d89ae2f 100644 --- a/src/openbench/dataset/dataset_transcription.py +++ b/src/openbench/dataset/dataset_transcription.py @@ -13,6 +13,7 @@ class TranscriptionExtraInfo(TypedDict, total=False): language: str dictionary: list[str] + metric_keywords: list[str] class TranscriptionRow(TypedDict): @@ -24,6 +25,7 @@ class TranscriptionRow(TypedDict): word_timestamps_end: NotRequired[list[float]] language: NotRequired[str] dictionary: NotRequired[list[str]] + metric_keywords: NotRequired[list[str]] class TranscriptionSample(BaseSample[Transcript, TranscriptionExtraInfo]): @@ -39,6 +41,11 @@ def dictionary(self) -> list[str] | None: """Convenience property to access dictionary from extra_info.""" return self.extra_info.get("dictionary") + @property + def metric_keywords(self) -> list[str] | None: + """Convenience property to access metric keywords from extra_info.""" + return self.extra_info.get("metric_keywords") + class TranscriptionDataset(BaseDataset[TranscriptionSample]): """Dataset for transcription pipelines with optional keyword support.""" @@ -60,4 +67,7 @@ def prepare_sample(self, row: TranscriptionRow) -> tuple[Transcript, Transcripti extra_info["language"] = row["language"] if "dictionary" in row: extra_info["dictionary"] = row["dictionary"] + metric_keywords = row.get("metric-keywords") or row.get("metric_keywords") + if metric_keywords: + extra_info["metric_keywords"] = metric_keywords return reference, extra_info diff --git a/src/openbench/runner/benchmark.py b/src/openbench/runner/benchmark.py index 439c15c..c05050d 100644 --- a/src/openbench/runner/benchmark.py +++ b/src/openbench/runner/benchmark.py @@ -125,7 +125,11 @@ def _process_single_sample( for metric_name, metric in metrics_dict.items(): reference = sample.reference - kwargs = sample.extra_info + kwargs = dict(sample.extra_info) + + # Use metric_keywords for metric calculation 
if available, falling back to dictionary + if "metric_keywords" in kwargs: + kwargs["dictionary"] = kwargs.pop("metric_keywords") # The metric returns a dictionary that is also stored in the metric object as a state to compute the global result # We copy to avoid any side effects that may happen while interacting with dictionary for reporting @@ -254,10 +258,15 @@ def _run_pipeline_on_dataset_parallel( # Update metric with all results for sample_result in per_sample_results: sample = dataset[sample_result.sample_id] - # Get UEM from extra_info if available kwargs = {} - if hasattr(sample, "extra_info") and "uem" in sample.extra_info: - kwargs["uem"] = sample.extra_info["uem"] + if hasattr(sample, "extra_info"): + if "uem" in sample.extra_info: + kwargs["uem"] = sample.extra_info["uem"] + # Use metric_keywords for metric calculation if available, falling back to dictionary + if "metric_keywords" in sample.extra_info: + kwargs["dictionary"] = sample.extra_info["metric_keywords"] + elif "dictionary" in sample.extra_info: + kwargs["dictionary"] = sample.extra_info["dictionary"] metric(hypothesis=sample_result.prediction, reference=sample.reference, detailed=True, **kwargs)