Compare commits
2 Commits
7563a8b9d5...c0f90bdd43
| Author | SHA1 | Date |
|---|---|---|
| | c0f90bdd43 | |
| | 3d687e652d | |
```diff
@@ -1,6 +1,6 @@
 import asyncio
-from typing import Set, List, Tuple
-from pydantic import BaseModel
+from typing import List, Dict, Annotated
+from pydantic import BaseModel, Field
 from pydantic_ai import Agent
 from pydantic_ai.models.openai import OpenAIChatModel, Model
 from pydantic_ai.providers.openai import OpenAIProvider
```
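The import swap drops `Set` and `Tuple` in favor of `List`, `Dict`, and `Annotated`. One practical motivation, shown in a standalone sketch that is not part of this commit (assumes Pydantic v2): set-typed fields dump in arbitrary order, while list-typed fields round-trip through JSON predictably, which matters when comparing LLM structured output across runs.

```python
# Standalone sketch, not from this commit (assumes Pydantic v2).
# Set-typed fields serialize in arbitrary order; list-typed fields
# keep insertion order, so output is stable across runs.
from typing import List, Set

from pydantic import BaseModel


class WithSet(BaseModel):
    items: Set[str]


class WithList(BaseModel):
    items: List[str]


print(WithSet(items={"b", "a"}).model_dump_json())   # order not guaranteed
print(WithList(items=["b", "a"]).model_dump_json())  # {"items":["b","a"]}
```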
```diff
@@ -27,7 +27,9 @@ messages_to_censor = [
 
 class SensitiveData(BaseModel):
     """Structure for identifying sensitive data that should be censored"""
-    sensitive_items: Set[str]
+    sensitive_items: List[str] = Field(
+        description="A list of sensitive data items (names, IDs, emails, etc.) that should be censored from the text"
+    )
 
 
 async def simple_test_response(model: Model):
```
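`sensitive_items` now carries a `Field` description. In Pydantic v2 that description is embedded in the model's JSON schema, which is what structured-output tooling ultimately presents to the LLM, so the description doubles as an extraction instruction. A minimal sketch of how to inspect it:

```python
# Sketch (assumes Pydantic v2): the Field description becomes part of
# the generated JSON schema for the model.
from typing import List

from pydantic import BaseModel, Field


class SensitiveData(BaseModel):
    """Structure for identifying sensitive data that should be censored"""
    sensitive_items: List[str] = Field(
        description="A list of sensitive data items (names, IDs, emails, "
                    "etc.) that should be censored from the text"
    )


schema = SensitiveData.model_json_schema()
print(schema["properties"]["sensitive_items"]["description"])
```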
```diff
@@ -45,7 +47,7 @@ async def simple_test_response(model: Model):
     print()  # Add a final newline
 
 
-async def get_words_to_censor(model: Model, prompt: str) -> Set[str]:
+async def get_words_to_censor(model: Model, prompt: str) -> List[str]:
     # Make a structured response that will return everything from the prompt that needs censoring.
 
     # Create an agent that returns structured data
```
```diff
@@ -56,27 +58,33 @@ async def get_words_to_censor(model: Model, prompt: str) -> Set[str]:
     )
 
     result = await censor_agent.run(prompt)
-    sensitive_items = set(sorted(list(result.output.sensitive_items)))  # for sorting
+    sensitive_items = sorted(result.output.sensitive_items)  # for sorting
     return sensitive_items
 
 
+SensitiveTextList = Annotated[
+    List[str],
+    Field(description="List of original sensitive text values that should be replaced with this placeholder")
+]
+
+
 class PlaceholderMapping(BaseModel):
     """Structure for mapping sensitive data to placeholders"""
-    mappings: List[Tuple[str, List[str]]]  # List of (placeholder, [original_texts]) tuples
+    mappings: Dict[str, SensitiveTextList] = Field(
+        description="Dictionary where keys are placeholder names (e.g., '[PERSON_1]', '[EMAIL_1]') and values are lists of original sensitive text items"
+    )
 
 
-async def generate_placeholders(model: Model, censored_words: Set[str]) -> List[Tuple[str, List[str]]]:
+async def generate_placeholders(model: Model, censored_words: List[str]) -> Dict[str, List[str]]:
     """Generate placeholders for censored words"""
     placeholder_agent = Agent(
         model,
         output_type=PlaceholderMapping,
-        system_prompt=PromptManager.get_prompt('generate_placeholders', {})
+        system_prompt=PromptManager.get_prompt('generate_placeholders')
     )
 
     # Convert set to sorted list for consistent ordering
-    words_list = sorted(list(censored_words))
-
-    result = await placeholder_agent.run(words_list)
+    result = await placeholder_agent.run(censored_words)
     return result.output.mappings
 
 
```
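`PlaceholderMapping` moves from a list of `(placeholder, originals)` tuples to `Dict[str, SensitiveTextList]`, with the `Annotated` alias attaching a description to every dict value. A sketch of the new shape validating (assumes Pydantic v2; the placeholder values below are hypothetical):

```python
# Sketch (assumes Pydantic v2; example values are hypothetical).
# The Annotated alias carries the Field description for each dict value.
from typing import Annotated, Dict, List

from pydantic import BaseModel, Field

SensitiveTextList = Annotated[
    List[str],
    Field(description="List of original sensitive text values that should "
                      "be replaced with this placeholder")
]


class PlaceholderMapping(BaseModel):
    """Structure for mapping sensitive data to placeholders"""
    mappings: Dict[str, SensitiveTextList]


m = PlaceholderMapping.model_validate(
    {"mappings": {"[PERSON_1]": ["Alice Smith"],
                  "[EMAIL_1]": ["alice@example.com"]}}
)
print(m.mappings["[PERSON_1]"])  # ['Alice Smith']
```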
```diff
@@ -91,9 +99,9 @@ async def main():
         )
     )
 
-    censored_words: Set = set()
+    censored_words: List[str] = []
     for message in messages_to_censor:
-        censored_words = censored_words | await get_words_to_censor(model, message)
+        censored_words += await get_words_to_censor(model, message)
 
     print("\nWords to censor:")
     for word in censored_words:
```
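In `main()`, set union becomes list concatenation, which preserves order but no longer deduplicates words that appear in several messages. A standalone sketch of the behavioral difference, with `dict.fromkeys` as one order-preserving way to dedupe if that is still wanted (the per-message results below are hypothetical):

```python
# Standalone sketch with hypothetical per-message results.
# List concatenation keeps duplicates that the old set union removed;
# dict.fromkeys offers order-preserving deduplication if still desired.
batches = [["Alice", "Bob"], ["Bob", "Carol"]]

censored_words: list[str] = []
for batch in batches:
    censored_words += batch

print(censored_words)                       # ['Alice', 'Bob', 'Bob', 'Carol']
print(list(dict.fromkeys(censored_words)))  # ['Alice', 'Bob', 'Carol']
```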
```diff
@@ -104,8 +112,14 @@ async def main():
     placeholder_mappings = await generate_placeholders(model, censored_words)
 
     print("\nPlaceholder mappings:")
-    for placeholder, originals in placeholder_mappings:
-        print(f"  {placeholder} → {', '.join(originals)}")
+    for placeholder, sensitive_items in placeholder_mappings.items():
+        print(f"  {placeholder}: {sensitive_items}")
+
+    print(f'Sensitive words: {len(censored_words)}')
+    total_count = 0
+    for _, value_list in placeholder_mappings.items():
+        total_count += len(value_list)
+    print(f'Sensitive words in placeholders: {total_count}')
 
 
 if __name__ == '__main__':
```
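The added prints compare the number of collected sensitive words against the total number of items the placeholder mapping covers, a quick sanity check that the second agent neither dropped nor invented items. A sketch with hypothetical data:

```python
# Sketch with hypothetical data: the new prints act as a coverage check.
# If the two totals differ, the mapping dropped or invented items.
censored_words = ["Alice Smith", "alice@example.com", "555-0100"]
placeholder_mappings = {
    "[PERSON_1]": ["Alice Smith"],
    "[EMAIL_1]": ["alice@example.com"],
    "[PHONE_1]": ["555-0100"],
}

print(f'Sensitive words: {len(censored_words)}')
total_count = 0
for _, value_list in placeholder_mappings.items():
    total_count += len(value_list)
print(f'Sensitive words in placeholders: {total_count}')  # both are 3 here
```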