jinja2 support
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Set, Optional
|
||||
import logging
|
||||
from jinja2 import Template, Environment, meta, TemplateSyntaxError, UndefinedError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,8 +17,9 @@ class ManagedPrompt:
|
||||
name: str
|
||||
variables: Set[str]
|
||||
schema: Optional[Dict[str, Any]] = None
|
||||
_filled_prompt: Optional[str] = None
|
||||
_context: Optional[Dict[str, Any]] = None
|
||||
_filled_prompt: Optional[str] = field(default=None, init=False, repr=False)
|
||||
_context: Optional[Dict[str, Any]] = field(default=None, init=False, repr=False)
|
||||
_jinja_template: Optional[Template] = field(default=None, init=False, repr=False)
|
||||
|
||||
def validate(self, **kwargs) -> bool:
|
||||
"""Validate that all required variables are provided.
|
||||
@@ -39,7 +41,7 @@ class ManagedPrompt:
|
||||
return self.variables - provided_vars
|
||||
|
||||
def fill(self, **kwargs) -> str:
|
||||
"""Fill the template with provided variables.
|
||||
"""Fill the template with provided variables using Jinja2.
|
||||
|
||||
Args:
|
||||
**kwargs: Variables to fill in the template
|
||||
@@ -48,8 +50,15 @@ class ManagedPrompt:
|
||||
The filled prompt string
|
||||
|
||||
Raises:
|
||||
ValueError: If required variables are missing
|
||||
ValueError: If required variables are missing or template syntax error
|
||||
"""
|
||||
# Create Jinja2 template if not already created
|
||||
if self._jinja_template is None:
|
||||
try:
|
||||
self._jinja_template = Template(self.template)
|
||||
except TemplateSyntaxError as e:
|
||||
raise ValueError(f"Invalid template syntax in prompt '{self.name}': {e}")
|
||||
|
||||
# If no variables required and none provided, return template as-is
|
||||
if not self.variables and not kwargs:
|
||||
self._filled_prompt = self.template
|
||||
@@ -63,15 +72,11 @@ class ManagedPrompt:
|
||||
f"Required: {self.variables}, Provided: {set(kwargs.keys())}"
|
||||
)
|
||||
|
||||
# Only process the template if there are actually variables to replace
|
||||
if self.variables:
|
||||
result = self.template
|
||||
for key, value in kwargs.items():
|
||||
if key in self.variables: # Only replace known variables
|
||||
placeholder = f"{{{{{key}}}}}" # {{key}}
|
||||
result = result.replace(placeholder, str(value))
|
||||
else:
|
||||
result = self.template
|
||||
try:
|
||||
# Render the template with Jinja2
|
||||
result = self._jinja_template.render(**kwargs)
|
||||
except UndefinedError as e:
|
||||
raise ValueError(f"Error rendering template '{self.name}': {e}")
|
||||
|
||||
# Cache the filled result
|
||||
self._filled_prompt = result
|
||||
@@ -182,10 +187,19 @@ class PromptManager:
|
||||
return Path.cwd() / 'prompts'
|
||||
|
||||
def _extract_variables(self, template: str) -> Set[str]:
|
||||
"""Extract all {{variable}} placeholders from template"""
|
||||
pattern = r'\{\{(\w+)\}\}'
|
||||
variables = set(re.findall(pattern, template))
|
||||
return variables
|
||||
"""Extract all variables from Jinja2 template"""
|
||||
try:
|
||||
# Create a Jinja2 environment and parse the template
|
||||
env = Environment()
|
||||
ast = env.parse(template)
|
||||
# Get all undeclared variables from the template
|
||||
variables = meta.find_undeclared_variables(ast)
|
||||
return variables
|
||||
except TemplateSyntaxError:
|
||||
# Fallback to simple regex for backwards compatibility
|
||||
pattern = r'\{\{\s*(\w+)\s*\}\}'
|
||||
variables = set(re.findall(pattern, template))
|
||||
return variables
|
||||
|
||||
def _validate_context(self, prompt_name: str, context: Dict[str, Any]) -> None:
|
||||
"""Validate that all required variables are provided"""
|
||||
@@ -208,14 +222,12 @@ class PromptManager:
|
||||
logger.warning(f"Extra variables provided for prompt '{prompt_name}': {extra_vars}")
|
||||
|
||||
def _fill_template(self, template: str, context: Dict[str, Any]) -> str:
|
||||
"""Fill template with context variables"""
|
||||
result = template
|
||||
|
||||
for key, value in context.items():
|
||||
placeholder = f"{{{{{key}}}}}" # {{key}}
|
||||
result = result.replace(placeholder, str(value))
|
||||
|
||||
return result
|
||||
"""Fill template with context variables using Jinja2"""
|
||||
try:
|
||||
jinja_template = Template(template)
|
||||
return jinja_template.render(**context)
|
||||
except (TemplateSyntaxError, UndefinedError) as e:
|
||||
raise ValueError(f"Error rendering template: {e}")
|
||||
|
||||
@classmethod
|
||||
def configure(cls, path: Optional[Path] = None, caching: Optional[bool] = None):
|
||||
|
||||
Reference in New Issue
Block a user