First steps

This commit is contained in:
2025-09-09 06:48:51 +02:00
parent 6e2b456dbb
commit 838f1c737e
21 changed files with 1411 additions and 0 deletions

3
src/services/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
# Package public API: re-export ModelManager at the package root.
from .model_manager import ModelManager
__all__ = ["ModelManager"]

View File

@@ -0,0 +1,113 @@
from typing import Dict, List, Optional
from datetime import datetime, timezone
from models import Model, ModelStatus
from services.persistence import persistence_manager
class ModelManager:
def __init__(self):
self.models: Dict[str, Model] = {}
self.next_port = 8001
self._load_models()
def list_models(self) -> List[Model]:
"""List all registered models"""
return list(self.models.values())
def get_model(self, model_id: str) -> Optional[Model]:
"""Get a specific model by ID"""
return self.models.get(model_id)
def create_model(self, model: Model) -> Model:
"""Create a new model"""
model.port = self._allocate_port()
model.status = ModelStatus.LOADING
model.created_at = datetime.now(timezone.utc)
model.updated_at = datetime.now(timezone.utc)
self.models[model.id] = model
self._save_models()
return model
def update_model(self, model_id: str, updates: Dict) -> Optional[Model]:
"""Update an existing model"""
if model_id not in self.models:
return None
model = self.models[model_id]
# Update allowed fields
allowed_fields = {
"name", "model", "tensor_parallel_size", "pipeline_parallel_size",
"max_model_len", "dtype", "quantization", "trust_remote_code",
"gpu_memory_utilization", "max_num_seqs", "config", "capabilities"
}
for field, value in updates.items():
if field in allowed_fields and value is not None:
setattr(model, field, value)
model.updated_at = datetime.now(timezone.utc)
self._save_models()
return model
def delete_model(self, model_id: str) -> bool:
"""Delete a model"""
if model_id in self.models:
model = self.models[model_id]
if model.port:
self._release_port(model.port)
del self.models[model_id]
self._save_models()
return True
return False
def _allocate_port(self) -> int:
"""Allocate a port for a new model"""
port = self.next_port
self.next_port += 1
return port
def _release_port(self, port: int) -> None:
"""Release a port when a model is deleted"""
# In a real implementation, we might want to track and reuse ports
pass
def _save_models(self) -> None:
"""Save all models to disk"""
models_data = [model.to_admin_format() for model in self.models.values()]
persistence_manager.save_models(models_data)
def _load_models(self) -> None:
"""Load models from disk on startup"""
models_data = persistence_manager.load_models()
for model_data in models_data:
# Reconstruct Model object from saved data
model = Model(
id=model_data.get("id"),
name=model_data.get("name", ""),
model=model_data.get("model", ""),
status=ModelStatus(model_data.get("status", "loading")),
created_at=model_data.get("created_at"),
updated_at=model_data.get("updated_at"),
tensor_parallel_size=model_data.get("tensor_parallel_size", 1),
pipeline_parallel_size=model_data.get("pipeline_parallel_size", 1),
max_model_len=model_data.get("max_model_len"),
dtype=model_data.get("dtype", "auto"),
quantization=model_data.get("quantization"),
trust_remote_code=model_data.get("trust_remote_code", False),
gpu_memory_utilization=model_data.get("gpu_memory_utilization", 0.9),
max_num_seqs=model_data.get("max_num_seqs", 256),
port=model_data.get("port"),
process_id=model_data.get("process_id"),
config=model_data.get("config", {}),
capabilities=model_data.get("capabilities", []),
)
self.models[model.id] = model
# Update next_port to avoid conflicts
if model.port and model.port >= self.next_port:
self.next_port = model.port + 1
# Module-level singleton; importing this module loads persisted models.
model_manager = ModelManager()

View File

@@ -0,0 +1,67 @@
import json
import os
from pathlib import Path
from typing import Dict, Any, List
from datetime import datetime
import logging
logger = logging.getLogger(__name__)


class PersistenceManager:
    """Saves and restores model records as a JSON file on disk.

    The data directory is resolved from the constructor argument, the
    DATA_DIR environment variable, or a ``data/`` directory next to the
    package — in that order of precedence.
    """

    def __init__(self, data_dir: "str | None" = None):
        """Create the manager and ensure the data directory exists.

        Args:
            data_dir: Directory for persisted files. When None, falls
                back to $DATA_DIR, then to ``<package parent>/data``.
        """
        if data_dir is None:
            # Use an absolute path so behavior is consistent regardless of
            # where the script is run from.
            default_data = Path(__file__).parent.parent / "data"
            data_dir = os.getenv("DATA_DIR", str(default_data))
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.models_file = self.data_dir / "models.json"

    def save_models(self, models: List[Dict[str, Any]]) -> None:
        """Write *models* to models.json, ISO-encoding datetime stamps.

        Best-effort: failures are logged (with traceback) but never
        raised, so a persistence problem cannot crash the caller.
        """
        try:
            serializable_models = []
            for model in models:
                model_copy = model.copy()
                # JSON cannot encode datetime; store ISO-8601 strings.
                for ts_field in ("created_at", "updated_at"):
                    if isinstance(model_copy.get(ts_field), datetime):
                        model_copy[ts_field] = model_copy[ts_field].isoformat()
                serializable_models.append(model_copy)
            # Explicit UTF-8: the platform default encoding may differ.
            with open(self.models_file, 'w', encoding='utf-8') as f:
                json.dump(serializable_models, f, indent=2)
            logger.info(f"Saved {len(models)} models to {self.models_file}")
        except Exception:
            # logger.exception keeps the traceback that error() dropped.
            logger.exception("Failed to save models")

    def load_models(self) -> List[Dict[str, Any]]:
        """Read models.json and return its records.

        ISO timestamp strings are converted back to datetime objects.
        Returns an empty list when the file is missing or unreadable.
        """
        if not self.models_file.exists():
            logger.info("No existing models file found")
            return []
        try:
            with open(self.models_file, 'r', encoding='utf-8') as f:
                models = json.load(f)
            for model in models:
                for ts_field in ("created_at", "updated_at"):
                    if isinstance(model.get(ts_field), str):
                        model[ts_field] = datetime.fromisoformat(model[ts_field])
            logger.info(f"Loaded {len(models)} models from {self.models_file}")
            return models
        except Exception:
            logger.exception("Failed to load models")
            return []
# Module-level singleton; creating it also creates the data directory.
persistence_manager = PersistenceManager()