-
Notifications
You must be signed in to change notification settings - Fork 725
MAINT: Add labels to attack results #1624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b63361c
e8164aa
4ec494a
5e04ca7
fa5a0d7
aafefc1
fca41cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |
| from datetime import datetime, timedelta, timezone | ||
| from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union | ||
|
|
||
| from sqlalchemy import and_, create_engine, event, exists, text | ||
| from sqlalchemy import and_, create_engine, event, exists, or_, text | ||
| from sqlalchemy.engine.base import Engine | ||
| from sqlalchemy.exc import SQLAlchemyError | ||
| from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker | ||
|
|
@@ -448,32 +448,52 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: | |
| """ | ||
| Get the SQL Azure implementation for filtering AttackResults by labels. | ||
|
|
||
| Matches if the labels are found on the AttackResultEntry directly | ||
| OR on an associated PromptMemoryEntry (via conversation_id). | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we get rid of PromptMemoryEntry labels?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I'd rather have the route to only attack result labels to simplify things. And a data migration path for the databases
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am putting that in a separate PR (deprecate labels on message piece and its memory entries) and then, this OR will be removed too. Maybe I'm being over-cautious? should I do the flip in one go? |
||
|
|
||
| Uses JSON_VALUE() function specific to SQL Azure with parameterized queries. | ||
|
|
||
| Args: | ||
| labels (dict[str, str]): Dictionary of label key-value pairs to filter by. | ||
|
|
||
| Returns: | ||
| Any: SQLAlchemy exists subquery condition with bound parameters. | ||
| """ | ||
| # Build JSON conditions for all labels with parameterized queries | ||
| label_conditions = [] | ||
| bindparams_dict = {} | ||
| for key, value in labels.items(): | ||
| param_name = f"label_{key}" | ||
| label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") | ||
| bindparams_dict[param_name] = str(value) | ||
| Any: SQLAlchemy condition with bound parameters. | ||
| """ | ||
| # --- Direct match on AttackResultEntry.labels --- | ||
| ar_label_conditions = [] | ||
| ar_bindparams: dict[str, str] = {} | ||
| for i, (key, value) in enumerate(labels.items()): | ||
| path_param = f"ar_label_path_{i}" | ||
| value_param = f"ar_label_val_{i}" | ||
| ar_label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}') | ||
| ar_bindparams[path_param] = f"$.{key}" | ||
| ar_bindparams[value_param] = str(value) | ||
|
|
||
| ar_combined = " AND ".join(ar_label_conditions) | ||
| direct_condition = and_( | ||
| AttackResultEntry.labels.isnot(None), | ||
| text(f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}').bindparams(**ar_bindparams), | ||
| ) | ||
|
|
||
| combined_conditions = " AND ".join(label_conditions) | ||
| # --- Conversation-level match on PromptMemoryEntry.labels --- | ||
| pme_label_conditions = [] | ||
| pme_bindparams: dict[str, str] = {} | ||
| for key, value in labels.items(): | ||
| param_name = f"pme_label_{key}" | ||
| pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") | ||
| pme_bindparams[param_name] = str(value) | ||
|
|
||
| return exists().where( | ||
| pme_combined = " AND ".join(pme_label_conditions) | ||
| conversation_condition = exists().where( | ||
| and_( | ||
| PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, | ||
| PromptMemoryEntry.labels.isnot(None), | ||
| text(f"ISJSON(labels) = 1 AND {combined_conditions}").bindparams(**bindparams_dict), | ||
| text(f"ISJSON(labels) = 1 AND {pme_combined}").bindparams(**pme_bindparams), | ||
| ) | ||
| ) | ||
|
|
||
| return or_(direct_condition, conversation_condition) | ||
|
|
||
| def get_unique_attack_class_names(self) -> list[str]: | ||
| """ | ||
| Azure SQL implementation: extract unique class_name values from | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -419,7 +419,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories | |||||||||
| def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: | ||||||||||
| """ | ||||||||||
| Return a database-specific condition for filtering AttackResults by labels | ||||||||||
| in the associated PromptMemoryEntry records. | ||||||||||
| stored directly on the AttackResultEntry OR on an associated PromptMemoryEntry (via conversation_id). | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| labels: Dictionary of labels that must ALL be present. | ||||||||||
|
|
@@ -1480,9 +1480,9 @@ def get_attack_results( | |||||||||
| not necessarily one(s) that were found in the response. | ||||||||||
| By providing a list, this means ALL categories in the list must be present. | ||||||||||
| Defaults to None. | ||||||||||
| labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. | ||||||||||
| These labels are associated with the prompts themselves, used for custom tagging and tracking. | ||||||||||
| Defaults to None. | ||||||||||
| labels (Optional[dict[str, str]], optional): A dictionary of labels to filter results by. | ||||||||||
| These labels are stored on the AttackResult or associated PromptMemoryEntry (via conversation_id) | ||||||||||
| . All specified key-value pairs must be present (AND logic). Defaults to None. | ||||||||||
|
Comment on lines
+1484
to
+1485
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| identifier_filters (Optional[Sequence[IdentifierFilter]], optional): | ||||||||||
| A sequence of IdentifierFilter objects that allows filtering by various attack identifier | ||||||||||
| JSON properties. Defaults to None. | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Confused about this comment. Aren't attack-result labels conversation-level?