Files
VPNTray/data_loader.py
2025-09-06 11:11:48 +02:00

356 lines
10 KiB
Python

from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml

from models import (
    Customer, CustomerService, Location, Host, Service,
    ServiceType, HostType, VPNType
)
def get_config_dir() -> Path:
    """Return the directory that holds the per-customer YAML files.

    The path is ``~/.vpntray/customers``. This only builds the path; it does
    not check whether the directory exists or create it.
    """
    return Path.home().joinpath(".vpntray", "customers")
def ensure_config_dir() -> Path:
    """Create the customer config directory if needed and return its path.

    Missing parent directories are created as well, and an already existing
    directory is not treated as an error.
    """
    directory = get_config_dir()
    directory.mkdir(exist_ok=True, parents=True)
    return directory
def parse_service_type(service_type_str: str) -> ServiceType:
    """Convert a YAML string to a ``ServiceType`` member, with fallback.

    The human-readable labels below (as used in hand-written YAML) are tried
    first. Failing that, the string is looked up as an enum *value*:
    ``save_customer`` writes ``service_type.value`` to disk, and nothing in
    this module guarantees those values equal the labels. Without this second
    lookup, a save/load round trip could silently change a service's type.

    Args:
        service_type_str: Label or enum value read from the YAML file.

    Returns:
        The matching ``ServiceType``, or ``ServiceType.WEB_GUI`` when the
        string matches neither a label nor an enum value.
    """
    type_mapping = {
        "SSH": ServiceType.SSH,
        "Web GUI": ServiceType.WEB_GUI,
        "RDP": ServiceType.RDP,
        "VNC": ServiceType.VNC,
        "SMB": ServiceType.SMB,
        "Database": ServiceType.DATABASE,
        "FTP": ServiceType.FTP,
    }
    if service_type_str in type_mapping:
        return type_mapping[service_type_str]
    try:
        # Accept whatever save_customer wrote (the member's .value).
        return ServiceType(service_type_str)
    except ValueError:
        return ServiceType.WEB_GUI
def parse_host_type(host_type_str: str) -> HostType:
    """Convert a YAML string to a ``HostType`` member, with fallback.

    The human-readable labels below are tried first, then the string is
    looked up as an enum *value*. The value lookup matters because
    ``save_customer`` writes ``host_type.value``, and this module does not
    guarantee those values equal the labels.

    Args:
        host_type_str: Label or enum value read from the YAML file.

    Returns:
        The matching ``HostType``, or ``HostType.LINUX`` when the string
        matches neither a label nor an enum value.
    """
    type_mapping = {
        "Linux": HostType.LINUX,
        "Windows": HostType.WINDOWS,
        "Windows Server": HostType.WINDOWS_SERVER,
        "Proxmox": HostType.PROXMOX,
        "ESXi": HostType.ESXI,
        "Router": HostType.ROUTER,
        "Switch": HostType.SWITCH,
    }
    if host_type_str in type_mapping:
        return type_mapping[host_type_str]
    try:
        # Accept whatever save_customer wrote (the member's .value).
        return HostType(host_type_str)
    except ValueError:
        return HostType.LINUX
def parse_vpn_type(vpn_type_str: str) -> VPNType:
    """Convert a YAML string to a ``VPNType`` member, with fallback.

    The human-readable labels below are tried first, then the string is
    looked up as an enum *value*. The value lookup matters because
    ``save_customer`` writes ``vpn_type.value``, and this module does not
    guarantee those values equal the labels.

    Args:
        vpn_type_str: Label or enum value read from the YAML file.

    Returns:
        The matching ``VPNType``, or ``VPNType.OPENVPN`` when the string
        matches neither a label nor an enum value.
    """
    type_mapping = {
        "OpenVPN": VPNType.OPENVPN,
        "WireGuard": VPNType.WIREGUARD,
        "IPSec": VPNType.IPSEC,
    }
    if vpn_type_str in type_mapping:
        return type_mapping[vpn_type_str]
    try:
        # Accept whatever save_customer wrote (the member's .value).
        return VPNType(vpn_type_str)
    except ValueError:
        return VPNType.OPENVPN
def parse_host(host_data: Dict[str, Any]) -> Host:
    """Build a ``Host`` (and, recursively, its sub-hosts) from YAML data.

    Required keys are ``name``, ``ip_address`` and ``host_type``. Each
    ``services`` entry needs ``name``, ``service_type`` and ``port``.
    A missing required key raises ``KeyError``.

    Optional keys are ``description`` (default ``''``), ``services`` and
    ``sub_hosts``. A list key written with no items (``services:`` on its
    own line) loads from YAML as ``None``; it is treated as an empty list
    rather than crashing with ``TypeError``.

    Args:
        host_data: One host mapping from the parsed YAML.

    Returns:
        The populated ``Host``.
    """
    services = [
        Service(
            name=service_data['name'],
            service_type=parse_service_type(service_data['service_type']),
            port=service_data['port'],
        )
        for service_data in host_data.get('services') or []
    ]
    host = Host(
        name=host_data['name'],
        ip_address=host_data['ip_address'],
        host_type=parse_host_type(host_data['host_type']),
        description=host_data.get('description', ''),
        services=services,
    )
    # Sub-hosts (VMs) use the same schema as hosts, so recurse.
    for subhost_data in host_data.get('sub_hosts') or []:
        host.sub_hosts.append(parse_host(subhost_data))
    return host
def parse_location(location_data: Dict[str, Any]) -> Location:
    """Build a ``Location`` and its hosts from YAML data.

    Required keys are ``name`` and ``vpn_type``; a missing one raises
    ``KeyError``. Optional keys and their defaults:

    * ``connected`` defaults to ``False``.
    * ``active`` defaults to ``False``.
    * ``vpn_config`` defaults to ``''``.
    * ``hosts`` may be absent or written with no items (YAML ``None``);
      either way it yields an empty host list.

    Args:
        location_data: One location mapping from the parsed YAML.

    Returns:
        The populated ``Location``.
    """
    hosts = [parse_host(host_data)
             for host_data in location_data.get('hosts') or []]
    return Location(
        name=location_data['name'],
        vpn_type=parse_vpn_type(location_data['vpn_type']),
        connected=location_data.get('connected', False),
        active=location_data.get('active', False),
        vpn_config=location_data.get('vpn_config', ''),
        hosts=hosts,
    )
def parse_customer(yaml_file: Path) -> Customer:
    """Load one customer definition from a YAML file.

    The top level must be a mapping with a ``name`` key. The optional
    ``services`` and ``locations`` lists may be absent or empty.

    Args:
        yaml_file: Path to the customer file. It is read as UTF-8.

    Returns:
        The parsed ``Customer`` with its services and locations.

    Raises:
        ValueError: The document is empty or its top level is not a mapping.
        KeyError: A required key (e.g. ``name``, ``url``) is missing.
        yaml.YAMLError: The file is not valid YAML.
        OSError: The file cannot be read.
    """
    # Explicit UTF-8 so results do not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows).
    with open(yaml_file, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty file (and a list/scalar for other
    # documents); report that clearly instead of failing later with an
    # opaque "argument of type 'NoneType' is not iterable".
    if not isinstance(data, dict):
        raise ValueError(
            f"{yaml_file}: expected a mapping at the top level, "
            f"got {type(data).__name__}"
        )

    # `or []` also covers keys written with no items, which load as None.
    services = [
        CustomerService(
            name=service_data['name'],
            url=service_data['url'],
            service_type=service_data['service_type'],
            description=service_data.get('description', ''),
        )
        for service_data in data.get('services') or []
    ]
    locations = [parse_location(location_data)
                 for location_data in data.get('locations') or []]

    return Customer(
        name=data['name'],
        services=services,
        locations=locations,
    )
def load_customers() -> List[Customer]:
    """Load every customer YAML file from the config directory.

    Files matching ``*.yaml`` and ``*.yml`` are loaded in sorted path order.
    ``Path.glob`` itself yields entries in arbitrary directory order, so
    sorting keeps the resulting list, and anything displayed from it, stable
    across runs and filesystems.

    A file that fails to load is reported on stdout and skipped, so one bad
    file does not hide the others. If the directory contains no YAML files
    at all, the demo data from ``get_demo_customers()`` is returned instead.

    Returns:
        The successfully loaded customers, or the demo customers when no
        YAML files exist.
    """
    config_dir = ensure_config_dir()
    yaml_files = sorted(
        [*config_dir.glob("*.yaml"), *config_dir.glob("*.yml")]
    )

    if not yaml_files:
        # No customer files found, initialize with examples
        print(f"No customer files found in {config_dir}")
        print("Run 'python data_loader.py --init' to create example customer files")
        return get_demo_customers()

    customers = []
    for yaml_file in yaml_files:
        try:
            customer = parse_customer(yaml_file)
        except Exception as e:  # best-effort: keep loading the other files
            print(f"Error loading {yaml_file}: {e}")
            continue
        customers.append(customer)
        print(f"Loaded customer: {customer.name} from {yaml_file.name}")
    return customers
def _host_to_dict(host: Host) -> Dict[str, Any]:
    """Serialize a ``Host`` and its sub-hosts (recursively) to plain data."""
    host_data = {
        'name': host.name,
        'ip_address': host.ip_address,
        'host_type': host.host_type.value,
        'description': host.description,
        'services': [
            {
                'name': service.name,
                'service_type': service.service_type.value,
                'port': service.port,
            }
            for service in host.services
        ],
    }
    # Only emit 'sub_hosts' when there are any, keeping leaf hosts compact.
    if host.sub_hosts:
        host_data['sub_hosts'] = [_host_to_dict(subhost)
                                  for subhost in host.sub_hosts]
    return host_data


def _location_to_dict(location: Location) -> Dict[str, Any]:
    """Serialize a ``Location`` and its hosts to plain data."""
    return {
        'name': location.name,
        'vpn_type': location.vpn_type.value,
        'vpn_config': location.vpn_config,
        'active': location.active,
        'connected': location.connected,
        'hosts': [_host_to_dict(host) for host in location.hosts],
    }


def save_customer(customer: Customer, filename: Optional[str] = None) -> None:
    """Save a customer to a YAML file in the config directory.

    Enum fields are written as their ``.value``. Keys keep their insertion
    order (``sort_keys=False``) so files stay readable and diff-friendly.

    Args:
        customer: The customer to write.
        filename: Target file name inside the config directory. When omitted
            it is derived from ``customer.name``: lower-cased, with spaces
            and path separators replaced by ``_``, plus ``.yaml``.
    """
    config_dir = ensure_config_dir()

    if filename is None:
        # "Acme Corp" -> "acme_corp.yaml". Path separators are replaced so a
        # name such as "A/B Ltd" cannot point outside the config directory
        # or into a subdirectory that does not exist.
        stem = customer.name.lower().replace(' ', '_')
        stem = stem.replace('/', '_').replace('\\', '_')
        filename = stem + '.yaml'
    filepath = config_dir / filename

    data = {
        'name': customer.name,
        'services': [
            {
                'name': service.name,
                'url': service.url,
                'service_type': service.service_type,
                'description': service.description,
            }
            for service in customer.services
        ],
        'locations': [_location_to_dict(location)
                      for location in customer.locations],
    }

    # Serialize before opening the file: mode 'w' truncates immediately, so
    # an error raised while dumping would otherwise wipe the existing file.
    text = yaml.dump(data, default_flow_style=False, sort_keys=False)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(text)
    print(f"Saved customer to {filepath}")
def get_demo_customers() -> List[Customer]:
    """Build the fallback customer list used when no config files exist.

    The list has a single "Demo Customer" with:

    * one web portal service;
    * one OpenVPN location, marked active but not connected;
    * one Linux host in that location, exposing SSH on port 22 and a web
      GUI on port 80.

    Nothing is read from or written to disk.
    """
    demo_host = Host(
        name="DEMO-01",
        ip_address="10.0.0.1",
        host_type=HostType.LINUX,
        description="Demo server",
        services=[
            Service("SSH", ServiceType.SSH, 22),
            Service("Web", ServiceType.WEB_GUI, 80),
        ],
    )
    demo_location = Location(
        name="Demo Location",
        vpn_type=VPNType.OPENVPN,
        connected=False,
        active=True,
        vpn_config="/etc/openvpn/demo.ovpn",
        hosts=[demo_host],
    )
    portal = CustomerService(
        name="Demo Portal",
        url="https://demo.example.com",
        service_type="Web Portal",
        description="Demo web portal",
    )
    return [
        Customer(
            name="Demo Customer",
            services=[portal],
            locations=[demo_location],
        )
    ]
# Minimal single-location customer written by initialize_example_customers().
# Nesting matters: services/locations/hosts are lists of mappings, matching
# what parse_customer / parse_location / parse_host expect.
_SIMPLE_CUSTOMER_YAML = """name: Simple Customer
services:
  - name: Company Website
    url: https://simple.example.com
    service_type: Web Portal
    description: Main company website
locations:
  - name: Main Office
    vpn_type: WireGuard
    vpn_config: /etc/wireguard/simple.conf
    active: false
    connected: false
    hosts:
      - name: SERVER-01
        ip_address: 192.168.1.10
        host_type: Linux
        description: Main server
        services:
          - name: SSH
            service_type: SSH
            port: 22
          - name: Web Interface
            service_type: Web GUI
            port: 443
"""


def initialize_example_customers():
    """Create example customer YAML files in the config directory.

    Each example is written only if its target file does not already exist,
    so user edits are never overwritten:

    * ``techcorp_solutions.yaml`` is a copy of ``example_customer.yaml``
      from this module's directory. It is skipped, with a message, if that
      source file is missing.
    * ``simple_customer.yaml`` is a minimal single-location example.
    """
    config_dir = ensure_config_dir()
    created = []

    # Create TechCorp example from the bundled example file.
    techcorp_file = config_dir / "techcorp_solutions.yaml"
    if not techcorp_file.exists():
        example_file = Path(__file__).parent / "example_customer.yaml"
        if example_file.exists():
            # Byte-for-byte copy: no encoding or newline translation.
            techcorp_file.write_bytes(example_file.read_bytes())
            print(f"Created example: {techcorp_file}")
            created.append(techcorp_file)
        else:
            print(f"Skipped TechCorp example: {example_file} not found")

    # Create a simpler example.
    simple_file = config_dir / "simple_customer.yaml"
    if not simple_file.exists():
        simple_file.write_text(_SIMPLE_CUSTOMER_YAML, encoding='utf-8')
        print(f"Created example: {simple_file}")
        created.append(simple_file)

    if created:
        print(f"\nExample customer files created in: {config_dir}")
    else:
        print(f"\nExample customer files already exist in: {config_dir}")
    print("You can now edit these files or create new ones following the same format.")
# Allow running this file directly: "--init" writes the example customer
# files; any other invocation loads the configured customers and prints a
# short summary of each one.
if __name__ == "__main__":
    import sys

    if sys.argv[1:2] == ["--init"]:
        initialize_example_customers()
    else:
        # Smoke-test the loader.
        for customer in load_customers():
            print(f"\nLoaded: {customer.name}")
            print(f" Services: {len(customer.services)}")
            print(f" Locations: {len(customer.locations)}")
            for location in customer.locations:
                print(f" - {location.name}: {len(location.hosts)} hosts")