first steps
This commit is contained in:
3
src/services/__init__.py
Normal file
3
src/services/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .model_manager import ModelManager
|
||||
|
||||
__all__ = ["ModelManager"]
|
||||
113
src/services/model_manager.py
Normal file
113
src/services/model_manager.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime, timezone
|
||||
from models import Model, ModelStatus
|
||||
from services.persistence import persistence_manager
|
||||
|
||||
|
||||
class ModelManager:
    """In-memory registry of vLLM model definitions, persisted through
    ``persistence_manager`` and rehydrated on construction."""

    def __init__(self):
        # Keyed by model id.
        self.models: Dict[str, Model] = {}
        # Next TCP port to hand out; bumped past any persisted port on load.
        self.next_port = 8001
        self._load_models()

    def list_models(self) -> List[Model]:
        """Return every registered model."""
        return [*self.models.values()]

    def get_model(self, model_id: str) -> Optional[Model]:
        """Return the model with *model_id*, or None if unknown."""
        try:
            return self.models[model_id]
        except KeyError:
            return None

    def create_model(self, model: Model) -> Model:
        """Register *model*: assign a port, mark it loading, stamp timestamps,
        and persist the registry. Returns the (mutated) model."""
        model.port = self._allocate_port()
        model.status = ModelStatus.LOADING
        model.created_at = datetime.now(timezone.utc)
        model.updated_at = datetime.now(timezone.utc)
        self.models[model.id] = model
        self._save_models()
        return model

    def update_model(self, model_id: str, updates: Dict) -> Optional[Model]:
        """Apply the whitelisted, non-None entries of *updates* to the model.

        Returns the updated model, or None when *model_id* is unknown.
        """
        model = self.models.get(model_id)
        if model is None:
            return None

        # Only these fields may be changed through the API.
        mutable = frozenset((
            "name", "model", "tensor_parallel_size", "pipeline_parallel_size",
            "max_model_len", "dtype", "quantization", "trust_remote_code",
            "gpu_memory_utilization", "max_num_seqs", "config", "capabilities",
        ))

        for key, new_value in updates.items():
            if key in mutable and new_value is not None:
                setattr(model, key, new_value)

        model.updated_at = datetime.now(timezone.utc)
        self._save_models()
        return model

    def delete_model(self, model_id: str) -> bool:
        """Remove a model and persist; True if it existed, False otherwise."""
        if model_id not in self.models:
            return False
        doomed = self.models[model_id]
        if doomed.port:
            self._release_port(doomed.port)
        del self.models[model_id]
        self._save_models()
        return True

    def _allocate_port(self) -> int:
        """Hand out the next free port (monotonically increasing)."""
        allocated, self.next_port = self.next_port, self.next_port + 1
        return allocated

    def _release_port(self, port: int) -> None:
        """Release a port when a model is deleted.

        Currently a no-op; a real implementation might track and reuse ports.
        """
        pass

    def _save_models(self) -> None:
        """Persist every model (admin format) to disk."""
        persistence_manager.save_models(
            [m.to_admin_format() for m in self.models.values()]
        )

    def _load_models(self) -> None:
        """Rehydrate the registry from disk on startup."""
        for record in persistence_manager.load_models():
            # Rebuild the Model object from its saved dict form.
            model = Model(
                id=record.get("id"),
                name=record.get("name", ""),
                model=record.get("model", ""),
                status=ModelStatus(record.get("status", "loading")),
                created_at=record.get("created_at"),
                updated_at=record.get("updated_at"),
                tensor_parallel_size=record.get("tensor_parallel_size", 1),
                pipeline_parallel_size=record.get("pipeline_parallel_size", 1),
                max_model_len=record.get("max_model_len"),
                dtype=record.get("dtype", "auto"),
                quantization=record.get("quantization"),
                trust_remote_code=record.get("trust_remote_code", False),
                gpu_memory_utilization=record.get("gpu_memory_utilization", 0.9),
                max_num_seqs=record.get("max_num_seqs", 256),
                port=record.get("port"),
                process_id=record.get("process_id"),
                config=record.get("config", {}),
                capabilities=record.get("capabilities", []),
            )
            self.models[model.id] = model

            # Keep port allocation ahead of any port already in use.
            if model.port and model.port >= self.next_port:
                self.next_port = model.port + 1
|
||||
|
||||
|
||||
# Global instance
# Module-level singleton shared by all importers. Note that constructing it
# runs _load_models(), so importing this module touches the persistence layer.
model_manager = ModelManager()
|
||||
67
src/services/persistence.py
Normal file
67
src/services/persistence.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)


class PersistenceManager:
    """Persists the model registry as a JSON file under a data directory.

    Persistence is best-effort by design: failures are logged, never raised,
    so a broken disk does not take down the service.
    """

    def __init__(self, data_dir: str = None):
        """Create the manager and ensure the data directory exists.

        Args:
            data_dir: Directory for ``models.json``. Defaults to the
                ``DATA_DIR`` environment variable, falling back to
                ``<package>/data`` resolved relative to this file so behavior
                does not depend on the current working directory.
        """
        if data_dir is None:
            # Use absolute path to ensure consistency regardless of where
            # the script is run from.
            default_data = Path(__file__).parent.parent / "data"
            data_dir = os.getenv("DATA_DIR", str(default_data))

        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.models_file = self.data_dir / "models.json"

    def save_models(self, models: List[Dict[str, Any]]) -> None:
        """Serialize *models* to ``models.json``.

        datetime values in ``created_at``/``updated_at`` are converted to
        ISO-8601 strings. The file is written atomically (temp file +
        ``os.replace``) so a crash mid-write can never leave a truncated or
        corrupt ``models.json`` behind — previously an in-place write could
        destroy the entire registry. Errors are logged, not raised.
        """
        try:
            serializable_models = []
            for model in models:
                # Shallow copy so the caller's dicts are not mutated.
                model_copy = model.copy()
                for key in ("created_at", "updated_at"):
                    if isinstance(model_copy.get(key), datetime):
                        model_copy[key] = model_copy[key].isoformat()
                serializable_models.append(model_copy)

            # Write to a sibling temp file, then atomically swap it in.
            tmp_path = self.models_file.with_name(self.models_file.name + ".tmp")
            with open(tmp_path, 'w', encoding="utf-8") as f:
                json.dump(serializable_models, f, indent=2)
            os.replace(tmp_path, self.models_file)  # atomic on POSIX and Windows

            logger.info(f"Saved {len(models)} models to {self.models_file}")
        except Exception as e:
            # Deliberately best-effort: persistence failure must not crash callers.
            logger.error(f"Failed to save models: {e}")

    def load_models(self) -> List[Dict[str, Any]]:
        """Load models from ``models.json``.

        Returns the list of model dicts with ISO-8601 timestamp strings
        converted back to datetime objects (the inverse of save_models),
        or an empty list when the file is missing or unreadable.
        """
        if not self.models_file.exists():
            logger.info("No existing models file found")
            return []

        try:
            with open(self.models_file, 'r', encoding="utf-8") as f:
                models = json.load(f)

            for model in models:
                for key in ("created_at", "updated_at"):
                    if isinstance(model.get(key), str):
                        model[key] = datetime.fromisoformat(model[key])

            logger.info(f"Loaded {len(models)} models from {self.models_file}")
            return models
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            return []
|
||||
|
||||
|
||||
# Global instance
# Module-level singleton shared by all importers. Note that constructing it
# creates the data directory (DATA_DIR env var or <package>/data) as an
# import-time side effect.
persistence_manager = PersistenceManager()
|
||||
Reference in New Issue
Block a user