diff --git a/src/pydantic_censoring.py b/src/pydantic_censoring.py
index c0145ba..4323c5a 100644
--- a/src/pydantic_censoring.py
+++ b/src/pydantic_censoring.py
@@ -1,6 +1,6 @@
 import asyncio
-from typing import Set, List, Tuple
-from pydantic import BaseModel
+from typing import List, Dict, Annotated
+from pydantic import BaseModel, Field
 from pydantic_ai import Agent
 from pydantic_ai.models.openai import OpenAIChatModel, Model
 from pydantic_ai.providers.openai import OpenAIProvider
@@ -27,7 +27,9 @@ messages_to_censor = [
 
 class SensitiveData(BaseModel):
     """Structure for identifying sensitive data that should be censored"""
-    sensitive_items: Set[str]
+    sensitive_items: List[str] = Field(
+        description="A list of sensitive data items (names, IDs, emails, etc.) that should be censored from the text"
+    )
 
 
 async def simple_test_response(model: Model):
@@ -45,7 +47,7 @@ async def simple_test_response(model: Model):
     print()  # Add a final newline
 
 
-async def get_words_to_censor(model: Model, prompt: str) -> Set[str]:
+async def get_words_to_censor(model: Model, prompt: str) -> List[str]:
     # Make a structured response that will return everything from the prompt that needs censoring.
 
     # Create an agent that returns structured data
@@ -56,16 +58,24 @@ async def get_words_to_censor(model: Model, prompt: str) -> Set[str]:
     )
 
     result = await censor_agent.run(prompt)
-    sensitive_items = set(sorted(list(result.output.sensitive_items)))  # for sorting
+    sensitive_items = sorted(result.output.sensitive_items)  # for sorting
     return sensitive_items
 
 
+SensitiveTextList = Annotated[
+    List[str],
+    Field(description="List of original sensitive text values that should be replaced with this placeholder")
+]
+
+
 class PlaceholderMapping(BaseModel):
     """Structure for mapping sensitive data to placeholders"""
-    mappings: List[Tuple[str, List[str]]]  # List of (placeholder, [original_texts]) tuples
+    mappings: Dict[str, SensitiveTextList] = Field(
+        description="Dictionary where keys are placeholder names (e.g., '[PERSON_1]', '[EMAIL_1]') and values are lists of original sensitive text items"
+    )
 
 
-async def generate_placeholders(model: Model, censored_words: Set[str]) -> List[Tuple[str, List[str]]]:
+async def generate_placeholders(model: Model, censored_words: List[str]) -> Dict[str, List[str]]:
     """Generate placeholders for censored words"""
     placeholder_agent = Agent(
         model,
@@ -74,8 +84,7 @@ async def generate_placeholders(model: Model, censored_words: Set[str]) -> List[
     )
 
-    # Convert set to sorted list for consistent ordering
-    words_list = sorted(list(censored_words))
-    result = await placeholder_agent.run(words_list)
+    # censored_words is already a plain list, so no set conversion is needed
+    result = await placeholder_agent.run(censored_words)
 
     return result.output.mappings
 
@@ -90,9 +99,9 @@ async def main():
         )
     )
 
-    censored_words: Set = set()
+    censored_words: List[str] = []
     for message in messages_to_censor:
-        censored_words = censored_words | await get_words_to_censor(model, message)
+        censored_words += await get_words_to_censor(model, message)
 
     print("\nWords to censor:")
     for word in censored_words:
@@ -103,8 +112,14 @@ async def main():
     placeholder_mappings = await generate_placeholders(model, censored_words)
 
     print("\nPlaceholder mappings:")
-    for placeholder, originals in placeholder_mappings:
-        print(f"  {placeholder} → {', '.join(originals)}")
+    for placeholder, sensitive_items in placeholder_mappings.items():
+        print(f"  {placeholder}: {sensitive_items}")
+
+    print(f'Sensitive words: {len(censored_words)}')
+    total_count = 0
+    for _, value_list in placeholder_mappings.items():
+        total_count += len(value_list)
+    print(f'Sensitive words in placeholders: {total_count}')
 
 
 if __name__ == '__main__':
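The schema changes above can be sanity-checked without any LLM call: the `Field` descriptions and the `Annotated` alias are baked into the JSON schema that pydantic-ai passes to the model as its structured-output contract, and the same models validate the dict-shaped payload coming back. A minimal sketch, assuming Pydantic v2 (`model_json_schema`, `model_validate`); the sample payload and placeholder names are illustrative, not taken from the patch:

```python
from typing import Annotated, Dict, List

from pydantic import BaseModel, Field

# Mirrors the types introduced in the patch above.
SensitiveTextList = Annotated[
    List[str],
    Field(description="List of original sensitive text values that should be replaced with this placeholder"),
]


class PlaceholderMapping(BaseModel):
    """Structure for mapping sensitive data to placeholders"""
    mappings: Dict[str, SensitiveTextList] = Field(
        description="Dictionary where keys are placeholder names (e.g., '[PERSON_1]', '[EMAIL_1]') and values are lists of original sensitive text items"
    )


# The Field descriptions are emitted into the JSON schema, which is what
# the LLM ultimately sees as the output contract.
schema = PlaceholderMapping.model_json_schema()
print(schema["properties"]["mappings"]["description"])

# Validate a hypothetical model response (payload is illustrative only).
payload = {"mappings": {"[PERSON_1]": ["John Smith"], "[EMAIL_1]": ["jsmith@example.com"]}}
parsed = PlaceholderMapping.model_validate(payload)
assert parsed.mappings["[PERSON_1]"] == ["John Smith"]
```

Replacing the old `List[Tuple[str, List[str]]]` with a `Dict` keyed by placeholder also has the side benefit that the schema becomes a plain JSON object, which structured-output models tend to fill in more reliably than tuple-shaped arrays.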