Compare commits

...

2 Commits

Author SHA1 Message Date
c0f90bdd43 fields added. sets to tuples 2025-09-08 22:53:40 +02:00
3d687e652d test 2025-09-08 06:24:17 +02:00

View File

@@ -1,6 +1,6 @@
import asyncio
from typing import Set, List, Tuple
from pydantic import BaseModel
from typing import List, Dict, Annotated
from pydantic import BaseModel, Field
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel, Model
from pydantic_ai.providers.openai import OpenAIProvider
@@ -27,7 +27,9 @@ messages_to_censor = [
class SensitiveData(BaseModel):
"""Structure for identifying sensitive data that should be censored"""
sensitive_items: Set[str]
sensitive_items: List[str] = Field(
description="A list of sensitive data items (names, IDs, emails, etc.) that should be censored from the text"
)
async def simple_test_response(model: Model):
@@ -45,7 +47,7 @@ async def simple_test_response(model: Model):
print() # Add a final newline
async def get_words_to_censor(model: Model, prompt: str) -> Set[str]:
async def get_words_to_censor(model: Model, prompt: str) -> List[str]:
# Make a structured response that will return everything from the prompt that needs censoring.
# Create an agent that returns structured data
@@ -56,27 +58,33 @@ async def get_words_to_censor(model: Model, prompt: str) -> Set[str]:
)
result = await censor_agent.run(prompt)
sensitive_items = set(sorted(list(result.output.sensitive_items))) # for sorting
sensitive_items = sorted(result.output.sensitive_items) # for sorting
return sensitive_items
SensitiveTextList = Annotated[
List[str],
Field(description="List of original sensitive text values that should be replaced with this placeholder")
]
class PlaceholderMapping(BaseModel):
"""Structure for mapping sensitive data to placeholders"""
mappings: List[Tuple[str, List[str]]] # List of (placeholder, [original_texts]) tuples
mappings: Dict[str, SensitiveTextList] = Field(
description="Dictionary where keys are placeholder names (e.g., '[PERSON_1]', '[EMAIL_1]') and values are lists of original sensitive text items"
)
async def generate_placeholders(model: Model, censored_words: Set[str]) -> List[Tuple[str, List[str]]]:
async def generate_placeholders(model: Model, censored_words: List[str]) -> Dict[str, List[str]]:
"""Generate placeholders for censored words"""
placeholder_agent = Agent(
model,
output_type=PlaceholderMapping,
system_prompt=PromptManager.get_prompt('generate_placeholders', {})
system_prompt=PromptManager.get_prompt('generate_placeholders')
)
# Convert set to sorted list for consistent ordering
words_list = sorted(list(censored_words))
result = await placeholder_agent.run(words_list)
result = await placeholder_agent.run(censored_words)
return result.output.mappings
@@ -91,9 +99,9 @@ async def main():
)
)
censored_words: Set = set()
censored_words: List[str] = []
for message in messages_to_censor:
censored_words = censored_words | await get_words_to_censor(model, message)
censored_words += await get_words_to_censor(model, message)
print("\nWords to censor:")
for word in censored_words:
@@ -104,8 +112,14 @@ async def main():
placeholder_mappings = await generate_placeholders(model, censored_words)
print("\nPlaceholder mappings:")
for placeholder, originals in placeholder_mappings:
print(f" {placeholder} {', '.join(originals)}")
for placeholder, sensitive_items in placeholder_mappings.items():
print(f" {placeholder}: {sensitive_items}")
print(f'Sensitive words: {len(censored_words)}')
total_count = 0
for _, value_list in placeholder_mappings.items():
total_count += len(value_list)
print(f'Sensitive words in placeholders: {total_count}')
if __name__ == '__main__':