Skip to content

Commit

Permalink
Add multi-turn metrics and templates support
Browse files Browse the repository at this point in the history
Signed-off-by: elronbandel <[email protected]>
  • Loading branch information
elronbandel committed Feb 10, 2025
1 parent 14401eb commit 12c2118
Show file tree
Hide file tree
Showing 13 changed files with 319 additions and 49 deletions.
27 changes: 27 additions & 0 deletions examples/evaluate_multi_turn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from unitxt import settings
from unitxt.api import evaluate, load_dataset
from unitxt.inference import (
    CrossProviderInferenceEngine,
)

# Example: evaluate a chat model on the multi-turn CoQA card.
# Caching of HF datasets stays enabled for the duration of the context.
with settings.context(
    disable_hf_datasets_cache=False,
):
    engine = CrossProviderInferenceEngine(
        model="llama-3-2-1b-instruct", provider="watsonx"
    )
    test_data = load_dataset(
        card="cards.coqa.multi_turn",
        format="formats.chat_api",
        split="test",
        max_test_instances=100,
    )

    model_outputs = engine.infer(test_data)
    results = evaluate(predictions=model_outputs, data=test_data)

    print("Global Results:")
    print(results.global_scores.summary)

    print("Instance Results:")
    print(results.instance_scores.summary)
14 changes: 4 additions & 10 deletions prepare/cards/coqa_multi_turn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, Dict

from unitxt.blocks import LoadHF, TaskCard
from unitxt.catalog import add_to_catalog
from unitxt.collections_operators import DuplicateBySubLists, Pop, Wrap
from unitxt.dialog_operators import ToDialog
from unitxt.operator import InstanceOperator
from unitxt.operators import AddID, Copy, FieldOperator, ZipFieldValues
from unitxt.operators import AddID, Copy, ZipFieldValues
from unitxt.test_utils.card import test_card


Expand All @@ -14,15 +16,6 @@ def process(
return instance


class ToDialog(FieldOperator):
def process_value(self, value: Any) -> Any:
dialog = []
for question, answer in value:
dialog.append({"role": "user", "content": question})
dialog.append({"role": "agent", "content": answer})
return dialog


card = TaskCard(
loader=LoadHF(path="stanfordnlp/coqa"),
preprocess_steps=[
Expand Down Expand Up @@ -74,3 +67,4 @@ def process_value(self, value: Any) -> Any:
)

test_card(card)
add_to_catalog(card, "cards.coqa.multi_turn", overwrite=True)
60 changes: 60 additions & 0 deletions prepare/metrics/multi_turn.py/accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from unitxt import add_to_catalog
from unitxt.metrics import AccuracyFast, MultiTurnMetric
from unitxt.test_utils.metrics import test_metric

# Wrap the fast accuracy metric so scores are grouped per conversation.
metric = MultiTurnMetric(metric=AccuracyFast())


def _instance(conversation_id, dialog):
    # One task-data entry: the conversation state an answer was produced in.
    return {"conversation": {"id": conversation_id, "dialog": dialog}}


predictions = ["A", "B", "C"]
references = [["B"], ["A"], ["C"]]
task_data = [
    _instance("aa", [{"role": "user", "content": "what is it?"}]),
    _instance(
        "aa",
        [
            {"role": "user", "content": "what is it?"},
            {"role": "agent", "content": "A"},
            {"role": "user", "content": "what is it again?"},
        ],
    ),
    _instance("bb", [{"role": "user", "content": "what is it?"}]),
]

# Per-instance scores: first two predictions miss, the third matches.
instance_targets = [
    {"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},
    {"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},
    {"accuracy": 1.0, "score": 1.0, "score_name": "accuracy"},
]

# Global score of 0.5 reflects per-conversation aggregation:
# conversation "aa" scores 0.0, conversation "bb" scores 1.0.
global_target = {
    "accuracy": 0.5,
    "accuracy_ci_high": 1.0,
    "accuracy_ci_low": 0.0,
    "num_of_instances": 3,
    "score": 0.5,
    "score_ci_high": 1.0,
    "score_ci_low": 0.0,
    "score_name": "accuracy",
}

outputs = test_metric(
    metric=metric,
    predictions=predictions,
    references=references,
    task_data=task_data,
    instance_targets=instance_targets,
    global_target=global_target,
)

add_to_catalog(metric, "metrics.multi_turn.accuracy", overwrite=True)
4 changes: 2 additions & 2 deletions prepare/tasks/qa/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
},
reference_fields={"answers": List[str]},
prediction_type=str,
metrics=["metrics.squad"],
default_template="templates.qa.extractive",
metrics=["metrics.multi_turn.accuracy"],
default_template="templates.qa.multi_turn.with_context.simple",
augmentable_inputs=["context"],
),
"tasks.qa.extractive.multi_turn",
Expand Down
24 changes: 23 additions & 1 deletion prepare/templates/qa/with_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from unitxt.catalog import add_to_catalog
from unitxt.serializers import (
ConversationSerializer,
DialogSerializer,
ImageSerializer,
ListSerializer,
MultiTypeSerializer,
SQLDatabaseAsSchemaSerializer,
TableSerializer,
VideoSerializer,
)
from unitxt.templates import MultiReferenceTemplate, TemplatesList

add_to_catalog(
Expand Down Expand Up @@ -135,8 +145,20 @@

add_to_catalog(
MultiReferenceTemplate(
input_format="Context: {context}\n{conversation}",
instruction="Read the context and answer the last question in the conversation. Answer with the minimal span from the context answering the question.",
input_format="Context: {context}\n\nConversation:\n{conversation}",
references_field="answers",
serializer=MultiTypeSerializer(
serializers=[
ImageSerializer(),
VideoSerializer(),
TableSerializer(),
DialogSerializer(),
ConversationSerializer(),
ListSerializer(),
SQLDatabaseAsSchemaSerializer(),
]
),
),
"templates.qa.multi_turn.with_context.simple",
overwrite=True,
Expand Down
88 changes: 88 additions & 0 deletions src/unitxt/catalog/cards/coqa/multi_turn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
{
"__type__": "task_card",
"loader": {
"__type__": "load_hf",
"path": "stanfordnlp/coqa"
},
"preprocess_steps": [
"splitters.small_no_test",
{
"__type__": "add_id"
},
{
"__type__": "copy",
"field": "id",
"to_field": "conversation/id"
},
{
"__type__": "zip_field_values",
"fields": [
"questions",
"answers/input_text"
],
"to_field": "dialog"
},
{
"__type__": "duplicate_by_sub_lists",
"field": "dialog"
},
{
"__type__": "to_dialog",
"field": "dialog"
},
{
"__type__": "pop",
"field": "dialog",
"item": -1,
"to_field": "last_turn"
},
{
"__type__": "copy",
"field_to_field": {
"last_turn/content": "answer",
"story": "context"
}
},
{
"__type__": "wrap",
"field": "answer",
"inside": "list",
"to_field": "answers"
},
{
"__type__": "copy",
"field": "dialog",
"to_field": "conversation/dialog"
}
],
"task": "tasks.qa.extractive.multi_turn",
"templates": [
"templates.qa.multi_turn.with_context.simple"
],
"__tags__": {
"annotations_creators": "crowdsourced",
"arxiv": [
"1808.07042",
"1704.04683",
"1506.03340"
],
"flags": [
"conversational-qa"
],
"language": "en",
"language_creators": "found",
"license": "other",
"multilinguality": "monolingual",
"region": "us",
"size_categories": "1K<n<10K",
"source_datasets": [
"extended|race",
"extended|cnn_dailymail",
"extended|wikipedia",
"extended|other"
],
"task_categories": "question-answering",
"task_ids": "extractive-qa"
},
"__description__": "CoQA is a large-scale dataset for building Conversational Question Answering systems. \nOur dataset contains 127k questions with answers, obtained from 8k conversations about text passages from seven diverse domains. The questions are conversational, and the answers are free-form text with their corresponding evidence highlighted in the passage. Supported Tasks and Leaderboards More Information Needed… See the full description on the dataset page: https://huggingface.co/datasets/stanfordnlp/coqa."
}
6 changes: 6 additions & 0 deletions src/unitxt/catalog/metrics/multi_turn/accuracy.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"__type__": "multi_turn_metric",
"metric": {
"__type__": "accuracy_fast"
}
}
4 changes: 2 additions & 2 deletions src/unitxt/catalog/tasks/qa/extractive/multi_turn.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
},
"prediction_type": "str",
"metrics": [
"metrics.squad"
"metrics.multi_turn.accuracy"
],
"default_template": "templates.qa.extractive",
"default_template": "templates.qa.multi_turn.with_context.simple",
"augmentable_inputs": [
"context"
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,32 @@
{
"__type__": "multi_reference_template",
"input_format": "Context: {context}\n{conversation}",
"references_field": "answers"
"instruction": "Read the context and answer the last question in the conversation. Answer with the minimal span from the context answering the question.",
"input_format": "Context: {context}\n\nConversation:\n{conversation}",
"references_field": "answers",
"serializer": {
"__type__": "multi_type_serializer",
"serializers": [
{
"__type__": "image_serializer"
},
{
"__type__": "video_serializer"
},
{
"__type__": "table_serializer"
},
{
"__type__": "dialog_serializer"
},
{
"__type__": "conversation_serializer"
},
{
"__type__": "list_serializer"
},
{
"__type__": "sql_database_as_schema_serializer"
}
]
}
}
11 changes: 10 additions & 1 deletion src/unitxt/dialog_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
from typing import Any, Dict, List, Optional

from .formats import SystemFormat
from .operators import InstanceFieldOperator
from .operators import FieldOperator, InstanceFieldOperator


class ToDialog(FieldOperator):
    """Convert a list of (question, answer) pairs into a chat-style dialog.

    Each pair becomes two consecutive turns: a ``user`` message carrying
    the question, then an ``agent`` message carrying the answer.
    """

    def process_value(self, value: Any) -> Any:
        turns = []
        for question, answer in value:
            turns.extend(
                (
                    {"role": "user", "content": question},
                    {"role": "agent", "content": answer},
                )
            )
        return turns


class SerializeDialog(InstanceFieldOperator):
Expand Down
Loading

0 comments on commit 12c2118

Please sign in to comment.