Add groundedness pro eval (Azure#38063)
* Adding service based groundedness

* groundedness pro eval

* remove groundedness and fix unit tests

* run black

* change evaluate label

* Update sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py

Co-authored-by: Neehar Duvvuri <[email protected]>

* Update sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py

Co-authored-by: Neehar Duvvuri <[email protected]>

* comments and CL

* re record tests

* black and pylint

* comments

* nits

* analysis

* re cast

* more mypy appeasement

---------

Co-authored-by: Ankit Singhal <[email protected]>
Co-authored-by: Neehar Duvvuri <[email protected]>
3 people authored Oct 25, 2024
1 parent 383b5cd commit 578b16c
Showing 25 changed files with 403 additions and 75 deletions.
1 change: 1 addition & 0 deletions sdk/evaluation/azure-ai-evaluation/CHANGELOG.md
@@ -4,6 +4,7 @@
## 1.0.0b5 (Unreleased)

### Features Added
- Added `GroundednessProEvaluator`, which is a service-based evaluator for determining response groundedness.
- Groundedness detection in Non Adversarial Simulator via query/context pairs
```python
import importlib.resources as pkg_resources
2 changes: 1 addition & 1 deletion sdk/evaluation/azure-ai-evaluation/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "python",
"TagPrefix": "python/evaluation/azure-ai-evaluation",
"Tag": "python/evaluation/azure-ai-evaluation_1390701e9d"
"Tag": "python/evaluation/azure-ai-evaluation_5551827d25"
}
@@ -16,6 +16,7 @@
from ._evaluators._fluency import FluencyEvaluator
from ._evaluators._gleu import GleuScoreEvaluator
from ._evaluators._groundedness import GroundednessEvaluator
from ._evaluators._service_groundedness import GroundednessProEvaluator
from ._evaluators._meteor import MeteorScoreEvaluator
from ._evaluators._protected_material import ProtectedMaterialEvaluator
from ._evaluators._qa import QAEvaluator
@@ -40,6 +41,7 @@
"F1ScoreEvaluator",
"FluencyEvaluator",
"GroundednessEvaluator",
"GroundednessProEvaluator",
"RelevanceEvaluator",
"SimilarityEvaluator",
"QAEvaluator",
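With `GroundednessProEvaluator` now exported from the package root, a minimal usage sketch could look like the following. The constructor and call signatures shown here are assumptions inferred from the other service-based evaluators in this package and from the query/response/context template added later in this commit; they are not confirmed by this diff.

```python
from azure.ai.evaluation import GroundednessProEvaluator
from azure.identity import DefaultAzureCredential

# Placeholder Azure AI project scope; substitute your own values.
azure_ai_project = {
    "subscription_id": "<subscription-id>",
    "resource_group_name": "<resource-group>",
    "project_name": "<project-name>",
}

# Assumed constructor shape: a project scope plus a credential for calling the
# Responsible AI service, mirroring the other service-based evaluators.
groundedness_pro = GroundednessProEvaluator(
    azure_ai_project=azure_ai_project,
    credential=DefaultAzureCredential(),
)

# Assumed call shape: query/response/context keyword arguments, matching the
# groundedness template introduced later in this commit.
result = groundedness_pro(
    query="Which tent is the most waterproof?",
    response="The Alpine Explorer Tent has the highest rainfly waterproof rating.",
    context="The Alpine Explorer Tent is the second most water-resistant tent on the market.",
)
print(result)  # e.g. {"groundedness_pro_label": ..., "groundedness_pro_reason": "..."}
```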
@@ -35,6 +35,7 @@ class Tasks:
CONTENT_HARM = "content harm"
PROTECTED_MATERIAL = "protected material"
XPIA = "xpia"
GROUNDEDNESS = "groundedness"


class _InternalAnnotationTasks:
@@ -56,6 +57,7 @@ class EvaluationMetrics(str, Enum, metaclass=CaseInsensitiveEnumMeta):
SEXUAL = "sexual"
PROTECTED_MATERIAL = "protected_material"
XPIA = "xpia"
GROUNDEDNESS = "generic_groundedness"


class _InternalEvaluationMetrics(str, Enum, metaclass=CaseInsensitiveEnumMeta):
@@ -7,8 +7,9 @@
import re
import time
from ast import literal_eval
from typing import Any, Dict, List, Optional, Union, cast
from typing import Dict, List, Optional, Union, cast
from urllib.parse import urlparse
from string import Template

import jwt

@@ -23,7 +24,6 @@
EvaluationMetrics,
RAIService,
Tasks,
_InternalAnnotationTasks,
_InternalEvaluationMetrics,
)
from .utils import get_harm_severity_level
@@ -34,6 +34,11 @@
version = "unknown"
USER_AGENT = "{}/{}".format("azure-ai-evaluation", version)

USER_TEXT_TEMPLATE_DICT: Dict[str, Template] = {
"DEFAULT": Template("<Human>{$query}</><System>{$response}</>"),
Tasks.GROUNDEDNESS: Template('{"question": "$query", "answer": "$response", "context": "$context"}'),
}
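The new `USER_TEXT_TEMPLATE_DICT` selects a `string.Template` per annotation task, so the groundedness task serializes query, response, and context into a JSON-shaped string instead of the default `<Human>/<System>` format. A self-contained sketch of the lookup-and-substitute step, with invented values and the task key written as a plain string in place of the `Tasks.GROUNDEDNESS` constant:

```python
from string import Template
from typing import Dict

USER_TEXT_TEMPLATE_DICT: Dict[str, Template] = {
    "DEFAULT": Template("<Human>{$query}</><System>{$response}</>"),
    "groundedness": Template('{"question": "$query", "answer": "$response", "context": "$context"}'),
}

data = {
    "query": "Which tent is the most waterproof?",
    "response": "The Alpine Explorer Tent.",
    "context": "The Alpine Explorer Tent is the second most water-resistant tent on the market.",
}

# Select the task-specific template, falling back to the default human/system format,
# then fill the placeholders from the evaluator's input dict.
template = USER_TEXT_TEMPLATE_DICT.get("groundedness", USER_TEXT_TEMPLATE_DICT["DEFAULT"])
user_text = template.substitute(**data)
print(user_text)
# {"question": "Which tent is the most waterproof?", "answer": "The Alpine Explorer Tent.", ...}
```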


def get_common_headers(token: str) -> Dict:
"""Get common headers for the HTTP request
@@ -99,27 +104,26 @@ async def ensure_service_availability(rai_svc_url: str, token: str, capability:
)


def generate_payload(normalized_user_text: str, metric: str) -> Dict:
def generate_payload(normalized_user_text: str, metric: str, annotation_task: str) -> Dict:
"""Generate the payload for the annotation request
:param normalized_user_text: The normalized user text to be entered as the "UserTextList" in the payload.
:type normalized_user_text: str
:param metric: The evaluation metric to use. This determines the task type, and whether a "MetricList" is needed
in the payload.
:type metric: str
:param annotation_task: The annotation task to be passed to the service
:type annotation_task: str
:return: The payload for the annotation request.
:rtype: Dict
"""
include_metric = True
task = Tasks.CONTENT_HARM
task = annotation_task
if metric == EvaluationMetrics.PROTECTED_MATERIAL:
task = Tasks.PROTECTED_MATERIAL
include_metric = False
elif metric == _InternalEvaluationMetrics.ECI:
task = _InternalAnnotationTasks.ECI
include_metric = False
elif metric == EvaluationMetrics.XPIA:
task = Tasks.XPIA
include_metric = False
return (
{
@@ -135,25 +139,25 @@ def generate_payload(normalized_user_text: str, metric: str) -> Dict:
)
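For orientation, the assembled payload is, roughly, the user text plus the task, with a metric list included only when `include_metric` stays true. The sketch below is an approximation, not the function's literal return value (which is truncated above): the `AnnotationTask` field name and the lowercase metric strings are assumptions, and the per-metric task overrides shown above are glossed over.

```python
from typing import Dict

# Stand-in strings for the enum values that suppress the metric list above.
_NO_METRIC_LIST = {"protected_material", "eci", "xpia"}


def sketch_generate_payload(normalized_user_text: str, metric: str, annotation_task: str) -> Dict:
    """Rough approximation of the payload assembly; field names other than
    UserTextList and MetricList (which the docstring above names) are assumptions."""
    payload: Dict = {
        "UserTextList": [normalized_user_text],
        "AnnotationTask": annotation_task,  # assumed field name
    }
    if metric not in _NO_METRIC_LIST:
        payload["MetricList"] = [metric]
    return payload


print(sketch_generate_payload('{"question": "q", "answer": "a", "context": "c"}', "generic_groundedness", "groundedness"))
# {'UserTextList': ['{"question": "q", "answer": "a", "context": "c"}'], 'AnnotationTask': 'groundedness', 'MetricList': ['generic_groundedness']}
```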


async def submit_request(query: str, response: str, metric: str, rai_svc_url: str, token: str) -> str:
async def submit_request(data: dict, metric: str, rai_svc_url: str, token: str, annotation_task: str) -> str:
"""Submit request to Responsible AI service for evaluation and return operation ID
:param query: The query to evaluate.
:type query: str
:param response: The response to evaluate.
:type response: str
:param data: The data to evaluate.
:type data: dict
:param metric: The evaluation metric to use.
:type metric: str
:param rai_svc_url: The Responsible AI service URL.
:type rai_svc_url: str
:param token: The Azure authentication token.
:type token: str
:param annotation_task: The annotation task to use.
:type annotation_task: str
:return: The operation ID.
:rtype: str
"""
user_text = f"<Human>{query}</><System>{response}</>"
user_text = USER_TEXT_TEMPLATE_DICT.get(annotation_task, USER_TEXT_TEMPLATE_DICT["DEFAULT"]).substitute(**data)
normalized_user_text = user_text.replace("'", '\\"')
payload = generate_payload(normalized_user_text, metric)
payload = generate_payload(normalized_user_text, metric, annotation_task=annotation_task)

url = rai_svc_url + "/submitannotation"
headers = get_common_headers(token)
@@ -164,7 +168,6 @@ async def submit_request(query: str, response: str, metric: str, rai_svc_url: st
if http_response.status_code != 202:
print("Fail evaluating '%s' with error message: %s" % (payload["UserTextList"], http_response.text()))
http_response.raise_for_status()

result = http_response.json()
operation_id = result["location"].split("/")[-1]
return operation_id
@@ -208,19 +211,28 @@ async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCre


def parse_response( # pylint: disable=too-many-branches,too-many-statements
batch_response: List[Dict], metric_name: str
batch_response: List[Dict], metric_name: str, metric_display_name: Optional[str] = None
) -> Dict[str, Union[str, float]]:
"""Parse the annotation response from Responsible AI service for a content harm evaluation.
:param batch_response: The annotation response from Responsible AI service.
:type batch_response: List[Dict]
:param metric_name: The evaluation metric to use.
:type metric_name: str
:param metric_display_name: The evaluation metric display name to use. If unset, use the metric_name.
:type metric_display_name: Optional[str]
:return: The parsed annotation result.
:rtype: Dict[str, Union[str, float]]
"""
if metric_display_name is None:
metric_display_name = metric_name

# non-numeric metrics
if metric_name in {EvaluationMetrics.PROTECTED_MATERIAL, _InternalEvaluationMetrics.ECI, EvaluationMetrics.XPIA}:
if metric_name in {
EvaluationMetrics.PROTECTED_MATERIAL,
_InternalEvaluationMetrics.ECI,
EvaluationMetrics.XPIA,
}:
if not batch_response or len(batch_response[0]) == 0 or metric_name not in batch_response[0]:
return {}
response = batch_response[0][metric_name]
@@ -230,38 +242,42 @@ def parse_response( # pylint: disable=too-many-branches,too-many-statements
result = {}
# Use label instead of score since these are assumed to be boolean results.
# Use math.nan as null value since it's ignored by aggregations rather than treated as 0.
result[metric_name + "_label"] = parsed_response["label"] if "label" in parsed_response else math.nan
result[metric_name + "_reason"] = parsed_response["reasoning"] if "reasoning" in parsed_response else ""
result[metric_display_name + "_label"] = parsed_response["label"] if "label" in parsed_response else math.nan
result[metric_display_name + "_reason"] = parsed_response["reasoning"] if "reasoning" in parsed_response else ""

if metric_name == EvaluationMetrics.XPIA:
# Add "manipulated_content", "intrusion" and "information_gathering" to the result
# if present else set them to math.nan
result[metric_name + "_manipulated_content"] = (
result[metric_display_name + "_manipulated_content"] = (
parsed_response["manipulated_content"] if "manipulated_content" in parsed_response else math.nan
)
result[metric_name + "_intrusion"] = (
result[metric_display_name + "_intrusion"] = (
parsed_response["intrusion"] if "intrusion" in parsed_response else math.nan
)
result[metric_name + "_information_gathering"] = (
result[metric_display_name + "_information_gathering"] = (
parsed_response["information_gathering"] if "information_gathering" in parsed_response else math.nan
)
return result
return _parse_content_harm_response(batch_response, metric_name)
return _parse_content_harm_response(batch_response, metric_name, metric_display_name)
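The practical effect of the new `metric_display_name` parameter is that lookups into the service response still key on `metric_name`, while every key of the returned dictionary is built from the display name, which falls back to the metric name when unset. Below is a reduced sketch of that renaming for a label/reason style result; which parser branch a given metric actually takes is not shown here, and the metric and display names are purely illustrative.

```python
import math
from typing import Dict, List, Optional


def parse_boolean_metric(
    batch_response: List[Dict], metric_name: str, metric_display_name: Optional[str] = None
) -> Dict:
    """Reduced illustration: look up results by metric_name, emit keys under metric_display_name."""
    if metric_display_name is None:
        metric_display_name = metric_name
    if not batch_response or metric_name not in batch_response[0]:
        return {}
    parsed = batch_response[0][metric_name]  # assumed already normalized to a dict in this sketch
    return {
        metric_display_name + "_label": parsed.get("label", math.nan),
        metric_display_name + "_reason": parsed.get("reasoning", ""),
    }


# The service reports the result under the metric name, but the evaluator
# surfaces it under a friendlier display name.
fake_response = [{"generic_groundedness": {"label": False, "reasoning": "The response is ungrounded."}}]
print(parse_boolean_metric(fake_response, "generic_groundedness", "groundedness_pro"))
# {'groundedness_pro_label': False, 'groundedness_pro_reason': 'The response is ungrounded.'}
```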


def _parse_content_harm_response(batch_response: List[Dict], metric_name: str) -> Dict[str, Union[str, float]]:
def _parse_content_harm_response(
batch_response: List[Dict], metric_name: str, metric_display_name: Optional[str] = None
) -> Dict[str, Union[str, float]]:
"""Parse the annotation response from Responsible AI service for a content harm evaluation.
:param batch_response: The annotation response from Responsible AI service.
:type batch_response: List[Dict]
:param metric_name: The evaluation metric to use.
:type metric_name: str
:param metric_display_name: The evaluation metric display name to use. If unset, use the metric_name.
:type metric_display_name: Optional[str]
:return: The parsed annotation result.
:rtype: Dict[str, Union[str, float]]
"""
# Fix the metric name if it's "hate_fairness"
# Eventually we will remove this fix once the RAI service is updated
key = metric_name
key = metric_name if metric_display_name is None else metric_display_name
if key == EvaluationMetrics.HATE_FAIRNESS:
key = EvaluationMetrics.HATE_UNFAIRNESS

@@ -283,7 +299,7 @@ def _parse_content_harm_response(batch_response: List[Dict], metric_name: str) -

# get content harm metric_value
if "label" in harm_response:
metric_value = harm_response["label"]
metric_value = float(harm_response["label"])
elif "valid" in harm_response:
metric_value = 0 if harm_response["valid"] else math.nan
else:
@@ -412,33 +428,40 @@ async def fetch_or_reuse_token(credential: TokenCredential, token: Optional[str]


async def evaluate_with_rai_service(
query: str, response: str, metric_name: str, project_scope: AzureAIProject, credential: TokenCredential
) -> Dict[str, Any]:
data: dict,
metric_name: str,
project_scope: AzureAIProject,
credential: TokenCredential,
annotation_task: str = Tasks.CONTENT_HARM,
metric_display_name=None,
) -> Dict[str, Union[str, float]]:
""" "Evaluate the content safety of the response using Responsible AI service
:param query: The query to evaluate.
:type query: str
:param response: The response to evaluate.
:type response: str
:param data: The data to evaluate.
:type data: dict
:param metric_name: The evaluation metric to use.
:type metric_name: str
:param project_scope: The Azure AI project scope details.
:type project_scope: Dict
:param credential: The Azure authentication credential.
:type credential:
~azure.core.credentials.TokenCredential
:param annotation_task: The annotation task to use.
:type annotation_task: str
:param metric_display_name: The display name of metric to use.
:type metric_display_name: str
:return: The parsed annotation result.
:rtype: Dict[str, Union[str, float]]
"""

# Get RAI service URL from discovery service and check service availability
token = await fetch_or_reuse_token(credential)
rai_svc_url = await get_rai_svc_url(project_scope, token)
await ensure_service_availability(rai_svc_url, token, Tasks.CONTENT_HARM)
await ensure_service_availability(rai_svc_url, token, annotation_task)

# Submit annotation request and fetch result
operation_id = await submit_request(query, response, metric_name, rai_svc_url, token)
operation_id = await submit_request(data, metric_name, rai_svc_url, token, annotation_task)
annotation_response = cast(List[Dict], await fetch_result(operation_id, rai_svc_url, credential, token))
result = parse_response(annotation_response, metric_name)
result = parse_response(annotation_response, metric_name, metric_display_name)

return result
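Putting the changes together, an evaluator can now drive the whole flow with one call, passing the input dict, the annotation task, and a display name. The import path and project-scope keys below are assumptions made for illustration; the real call site is the new `_service_groundedness` module, which is not shown in this excerpt.

```python
import asyncio

from azure.identity import DefaultAzureCredential

# Assumed import path for the helper changed above.
from azure.ai.evaluation._common.rai_service import evaluate_with_rai_service


async def main() -> None:
    # Placeholder project scope; the keys shown here are assumptions.
    project_scope = {
        "subscription_id": "<subscription-id>",
        "resource_group_name": "<resource-group>",
        "project_name": "<project-name>",
    }
    result = await evaluate_with_rai_service(
        data={
            "query": "Which tent is the most waterproof?",
            "response": "The Alpine Explorer Tent.",
            "context": "The Alpine Explorer Tent is the second most water-resistant tent.",
        },
        metric_name="generic_groundedness",
        project_scope=project_scope,
        credential=DefaultAzureCredential(),
        annotation_task="groundedness",
        metric_display_name="groundedness_pro",
    )
    print(result)


# asyncio.run(main())  # commented out: running requires access to a live Azure AI project
```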
@@ -24,7 +24,7 @@
T_TypedDict = TypeVar("T_TypedDict", bound=Mapping[Any, Any])


def get_harm_severity_level(harm_score: int) -> Union[str, float]:
def get_harm_severity_level(harm_score: Union[float, int]) -> Union[str, float]:
"""Generate harm severity level based on harm score.
:param harm_score: The harm score to be evaluated.
@@ -36,13 +36,47 @@

TClient = TypeVar("TClient", ProxyClient, CodeClient)

# For metrics (aggregates) whose metric names intentionally differ from their
# originating column name, usually because the aggregation of the original value
# means something sufficiently different.
# Note that content safety metrics are handled separately.
METRIC_COLUMN_NAME_REPLACEMENTS = {
"groundedness_pro_label": "groundedness_pro_passing_rate",
}


class __EvaluatorInfo(TypedDict):
result: pd.DataFrame
metrics: Dict[str, Any]
run_summary: Dict[str, Any]


def _aggregate_other_metrics(df: pd.DataFrame) -> Tuple[List[str], Dict[str, float]]:
"""Identify and average various metrics that need to have the metric name be replaced,
instead of having the metric match the originating column name.
:param df: The dataframe of evaluation results.
:type df: ~pandas.DataFrame
:return: A tuple; the first element is a list of dataframe columns that were aggregated,
and the second element is a dictionary of resultant new metric column names and their values.
:rtype: Tuple[List[str], Dict[str, float]]
"""
renamed_cols = []
metric_columns = {}
for col in df.columns:
metric_prefix = col.split(".")[0]
metric_name = col.split(".")[1]
if metric_name in METRIC_COLUMN_NAME_REPLACEMENTS:
renamed_cols.append(col)
new_col_name = metric_prefix + "." + METRIC_COLUMN_NAME_REPLACEMENTS[metric_name]
col_with_numeric_values = pd.to_numeric(df[col], errors="coerce")
metric_columns[new_col_name] = round(
list_sum(col_with_numeric_values) / col_with_numeric_values.count(),
2,
)

return renamed_cols, metric_columns
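As a concrete illustration of the replacement above, the per-row `groundedness_pro_label` booleans are averaged into a passing rate and reported under the renamed metric instead of the originating column. A toy sketch with invented values, a plain sum in place of the module's own summation helper, and an assumed `evaluator.metric` column layout:

```python
import pandas as pd

METRIC_COLUMN_NAME_REPLACEMENTS = {
    "groundedness_pro_label": "groundedness_pro_passing_rate",
}

# Toy per-row results in an evaluator.metric column layout (values invented).
df = pd.DataFrame({"groundedness_pro.groundedness_pro_label": [True, False, True]})

for col in df.columns:
    metric_prefix, metric_name = col.split(".")[0], col.split(".")[1]
    if metric_name in METRIC_COLUMN_NAME_REPLACEMENTS:
        new_col_name = metric_prefix + "." + METRIC_COLUMN_NAME_REPLACEMENTS[metric_name]
        numeric = pd.to_numeric(df[col], errors="coerce")
        # Booleans coerce to 1/0, so the average is the fraction of grounded rows.
        print({new_col_name: float(round(numeric.sum() / numeric.count(), 2))})
# {'groundedness_pro.groundedness_pro_passing_rate': 0.67}
```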


# pylint: disable=line-too-long
def _aggregate_content_safety_metrics(
df: pd.DataFrame, evaluators: Dict[str, Callable]
@@ -146,8 +180,11 @@ def _aggregate_metrics(df: pd.DataFrame, evaluators: Dict[str, Callable]) -> Dic
# Rename certain columns as defect rates if we know that's what their aggregates represent
# Content safety metrics
content_safety_cols, cs_defect_rates = _aggregate_content_safety_metrics(df, evaluators)
other_renamed_cols, renamed_cols = _aggregate_other_metrics(df)
handled_columns.extend(content_safety_cols)
handled_columns.extend(other_renamed_cols)
defect_rates.update(cs_defect_rates)
defect_rates.update(renamed_cols)
# Label-based (true/false) metrics where 'true' means 'something is wrong'
label_cols, label_defect_rates = _aggregate_label_defect_metrics(df)
handled_columns.extend(label_cols)