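"""LLM agent built on top of llm_connector.

Chat and embedding backends are configured from environment variables
(AGENT_*, EMBEDDING_*, plus a currently unused BACKEND_* set).
"""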
import os
import logging
from typing import List, Dict

from llm_connector import LLMClient, LLMBackend

logger = logging.getLogger(__name__)


class LLMAgent:
    client: LLMClient

    def __init__(self, temperature: float = 0.8):
        # TODO: use the user-provided temperature when issuing requests
        # NOTE: `backend` is built from the BACKEND_* variables but is not
        # currently passed to LLMClient.
        backend: LLMBackend = {
            'base_url': os.environ['BACKEND_BASE_URL'],
            'api_token': os.environ['BACKEND_API_TOKEN'],
            'model': os.environ['BACKEND_MODEL']}
        agent_backend: LLMBackend = {
            'base_url': os.environ['AGENT_BASE_URL'],
            'api_token': os.environ['AGENT_API_TOKEN'],
            'model': os.environ['AGENT_MODEL']}
        embedding_backend: LLMBackend = {
            'base_url': os.environ['EMBEDDING_BASE_URL'],
            'api_token': os.environ['EMBEDDING_API_TOKEN'],
            'model': os.environ['EMBEDDING_MODEL']}
        self.client = LLMClient(agent_backend, embedding_backend)
        self.temperature = temperature

    async def chat(self, messages: List[Dict[str, str]], max_tokens: int = 200) -> str:
        logger.info('Chat')
        try:
            response = ''
            async for chunk in self.client.get_response(messages, stream=False):  # type: ignore
                if 'content' in chunk:
                    response += chunk['content']

            # Previous direct OpenAI-style call, kept for reference:
            # response = client.chat.completions.create(
            #     model=self.model,
            #     messages=messages,
            #     temperature=self.temperature,
            #     max_tokens=max_tokens
            # )
            return response.strip()
        except Exception as e:
            return f"[LLM Error: {str(e)}]"

    async def get_embedding(self, text: str) -> List[float]:
        """Get embedding for memory relevance scoring"""
        try:
            response = await self.client.get_embedding(text)

            # Previous direct OpenAI-style call, kept for reference:
            # response = client.embeddings.create(
            #     model="text-embedding-ada-002",
            #     input=text
            # )
            return response
        except Exception as e:
            logger.error(f"Embedding error: {e}")
            return [0.0] * 1536  # Default embedding size
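

# ---------------------------------------------------------------------------
# Example usage: an illustrative sketch, not part of the module itself.
# It assumes the AGENT_* and EMBEDDING_* environment variables above are set
# and that llm_connector's LLMClient behaves as the calls in this class imply.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import asyncio

    async def _demo() -> None:
        agent = LLMAgent(temperature=0.7)

        # Simple chat round-trip with a system and a user message.
        reply = await agent.chat([
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': 'Summarize what an LLM agent does.'},
        ])
        print(reply)

        # Embedding lookup used for memory relevance scoring.
        embedding = await agent.get_embedding('memory relevance scoring')
        print(f'Embedding length: {len(embedding)}')

    asyncio.run(_demo())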