Compare commits
2 Commits
7563a8b9d5
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| c0f90bdd43 | |||
| 3d687e652d |
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Set, List, Tuple
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Annotated
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai.models.openai import OpenAIChatModel, Model
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
@@ -27,7 +27,9 @@ messages_to_censor = [
|
||||
|
||||
class SensitiveData(BaseModel):
|
||||
"""Structure for identifying sensitive data that should be censored"""
|
||||
sensitive_items: Set[str]
|
||||
sensitive_items: List[str] = Field(
|
||||
description="A list of sensitive data items (names, IDs, emails, etc.) that should be censored from the text"
|
||||
)
|
||||
|
||||
|
||||
async def simple_test_response(model: Model):
|
||||
@@ -45,7 +47,7 @@ async def simple_test_response(model: Model):
|
||||
print() # Add a final newline
|
||||
|
||||
|
||||
async def get_words_to_censor(model: Model, prompt: str) -> Set[str]:
|
||||
async def get_words_to_censor(model: Model, prompt: str) -> List[str]:
|
||||
# Make a structured response that will return everything from the prompt that needs censoring.
|
||||
|
||||
# Create an agent that returns structured data
|
||||
@@ -56,27 +58,33 @@ async def get_words_to_censor(model: Model, prompt: str) -> Set[str]:
|
||||
)
|
||||
|
||||
result = await censor_agent.run(prompt)
|
||||
sensitive_items = set(sorted(list(result.output.sensitive_items))) # for sorting
|
||||
sensitive_items = sorted(result.output.sensitive_items) # for sorting
|
||||
return sensitive_items
|
||||
|
||||
|
||||
SensitiveTextList = Annotated[
|
||||
List[str],
|
||||
Field(description="List of original sensitive text values that should be replaced with this placeholder")
|
||||
]
|
||||
|
||||
|
||||
class PlaceholderMapping(BaseModel):
|
||||
"""Structure for mapping sensitive data to placeholders"""
|
||||
mappings: List[Tuple[str, List[str]]] # List of (placeholder, [original_texts]) tuples
|
||||
mappings: Dict[str, SensitiveTextList] = Field(
|
||||
description="Dictionary where keys are placeholder names (e.g., '[PERSON_1]', '[EMAIL_1]') and values are lists of original sensitive text items"
|
||||
)
|
||||
|
||||
|
||||
async def generate_placeholders(model: Model, censored_words: Set[str]) -> List[Tuple[str, List[str]]]:
|
||||
async def generate_placeholders(model: Model, censored_words: List[str]) -> Dict[str, List[str]]:
|
||||
"""Generate placeholders for censored words"""
|
||||
placeholder_agent = Agent(
|
||||
model,
|
||||
output_type=PlaceholderMapping,
|
||||
system_prompt=PromptManager.get_prompt('generate_placeholders', {})
|
||||
system_prompt=PromptManager.get_prompt('generate_placeholders')
|
||||
)
|
||||
|
||||
# Convert set to sorted list for consistent ordering
|
||||
words_list = sorted(list(censored_words))
|
||||
|
||||
result = await placeholder_agent.run(words_list)
|
||||
result = await placeholder_agent.run(censored_words)
|
||||
return result.output.mappings
|
||||
|
||||
|
||||
@@ -91,9 +99,9 @@ async def main():
|
||||
)
|
||||
)
|
||||
|
||||
censored_words: Set = set()
|
||||
censored_words: List[str] = []
|
||||
for message in messages_to_censor:
|
||||
censored_words = censored_words | await get_words_to_censor(model, message)
|
||||
censored_words += await get_words_to_censor(model, message)
|
||||
|
||||
print("\nWords to censor:")
|
||||
for word in censored_words:
|
||||
@@ -104,8 +112,14 @@ async def main():
|
||||
placeholder_mappings = await generate_placeholders(model, censored_words)
|
||||
|
||||
print("\nPlaceholder mappings:")
|
||||
for placeholder, originals in placeholder_mappings:
|
||||
print(f" {placeholder} → {', '.join(originals)}")
|
||||
for placeholder, sensitive_items in placeholder_mappings.items():
|
||||
print(f" {placeholder}: {sensitive_items}")
|
||||
|
||||
print(f'Sensitive words: {len(censored_words)}')
|
||||
total_count = 0
|
||||
for _, value_list in placeholder_mappings.items():
|
||||
total_count += len(value_list)
|
||||
print(f'Sensitive words in placeholders: {total_count}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user