"""Load and save VPNTray customer definitions stored as YAML files.

Customer files live in ``~/.vpntray/customers`` (see ``get_config_dir``).
Each file describes one customer: its web services and its locations, where
every location carries a VPN configuration and a tree of hosts (hosts may
nest further hosts under ``sub_hosts``, e.g. VMs).

Run this module directly with ``--init`` to write example customer files, or
without arguments to print a summary of the customers that load.
"""
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar

import yaml

from models import (
    Customer, CustomerService, Location, Host, Service,
    ServiceType, HostType, VPNType
)

_E = TypeVar("_E")

# Human-readable labels accepted in YAML files, mapped to enum members.
# Matching against these keys is case-insensitive (see _parse_enum).
_SERVICE_TYPE_LABELS = {
    "SSH": ServiceType.SSH,
    "Web GUI": ServiceType.WEB_GUI,
    "RDP": ServiceType.RDP,
    "VNC": ServiceType.VNC,
    "SMB": ServiceType.SMB,
    "Database": ServiceType.DATABASE,
    "FTP": ServiceType.FTP,
}

_HOST_TYPE_LABELS = {
    "Linux": HostType.LINUX,
    "Windows": HostType.WINDOWS,
    "Windows Server": HostType.WINDOWS_SERVER,
    "Proxmox": HostType.PROXMOX,
    "ESXi": HostType.ESXI,
    "Router": HostType.ROUTER,
    "Switch": HostType.SWITCH,
}

_VPN_TYPE_LABELS = {
    "OpenVPN": VPNType.OPENVPN,
    "WireGuard": VPNType.WIREGUARD,
    "IPSec": VPNType.IPSEC,
}

# Template written by initialize_example_customers() as simple_customer.yaml.
_SIMPLE_CUSTOMER_YAML = """name: Simple Customer
services:
  - name: Company Website
    url: https://simple.example.com
    service_type: Web Portal
    description: Main company website
locations:
  - name: Main Office
    vpn_type: WireGuard
    vpn_config: /etc/wireguard/simple.conf
    active: false
    connected: false
    hosts:
      - name: SERVER-01
        ip_address: 192.168.1.10
        host_type: Linux
        description: Main server
        services:
          - name: SSH
            service_type: SSH
            port: 22
          - name: Web Interface
            service_type: Web GUI
            port: 443
"""


def get_config_dir() -> Path:
    """Return the VPNTray customer directory, ``~/.vpntray/customers``.

    The directory is not created; use ``ensure_config_dir`` for that.
    """
    return Path.home() / ".vpntray" / "customers"


def ensure_config_dir() -> Path:
    """Create the customer directory (and parents) if needed and return it."""
    config_dir = get_config_dir()
    config_dir.mkdir(parents=True, exist_ok=True)
    return config_dir


def _parse_enum(raw: Any, labels: Dict[str, _E], enum_cls: Type[_E],
                default: _E) -> _E:
    """Resolve *raw* to a member of *enum_cls*, falling back to *default*.

    Resolution order:

    1. a case-insensitive, whitespace-trimmed match against the keys of
       *labels* (so ``"ssh"`` and ``"SSH"`` are equivalent);
    2. an exact match against a member's value via ``enum_cls(raw)``, so
       values written by ``save_customer`` (which stores ``.value``) load
       back even when they are not listed in *labels*;
    3. *default* for anything else (including ``None``).
    """
    if isinstance(raw, str):
        wanted = raw.strip().casefold()
        for label, member in labels.items():
            if label.casefold() == wanted:
                return member
    try:
        return enum_cls(raw)
    except (ValueError, TypeError):
        return default


def parse_service_type(service_type_str: str) -> ServiceType:
    """Convert a YAML ``service_type`` string to a ``ServiceType``.

    Unknown values fall back to ``ServiceType.WEB_GUI``.
    """
    return _parse_enum(service_type_str, _SERVICE_TYPE_LABELS,
                       ServiceType, ServiceType.WEB_GUI)


def parse_host_type(host_type_str: str) -> HostType:
    """Convert a YAML ``host_type`` string to a ``HostType``.

    Unknown values fall back to ``HostType.LINUX``.
    """
    return _parse_enum(host_type_str, _HOST_TYPE_LABELS,
                       HostType, HostType.LINUX)


def parse_vpn_type(vpn_type_str: str) -> VPNType:
    """Convert a YAML ``vpn_type`` string to a ``VPNType``.

    Unknown values fall back to ``VPNType.OPENVPN``.
    """
    return _parse_enum(vpn_type_str, _VPN_TYPE_LABELS,
                       VPNType, VPNType.OPENVPN)


def _list_field(data: Dict[str, Any], key: str) -> List[Any]:
    """Return ``data[key]``, treating a missing or null entry as ``[]``.

    In YAML a key written with no items (``services:``) parses as ``None``;
    iterating that directly would raise ``TypeError``.
    """
    return data.get(key) or []


def parse_host(host_data: Dict[str, Any]) -> Host:
    """Build a ``Host`` (including nested ``sub_hosts``) from YAML data.

    Required keys are ``name``, ``ip_address`` and ``host_type``; each entry
    under ``services`` requires ``name``, ``service_type`` and ``port``.
    A missing required key raises ``KeyError``.
    """
    services = [
        Service(
            name=service_data['name'],
            service_type=parse_service_type(service_data['service_type']),
            port=service_data['port'],
        )
        for service_data in _list_field(host_data, 'services')
    ]

    host = Host(
        name=host_data['name'],
        ip_address=host_data['ip_address'],
        host_type=parse_host_type(host_data['host_type']),
        description=host_data.get('description') or '',
        services=services,
    )

    # Sub-hosts (VMs) use the same schema as their parent, so recurse.
    for subhost_data in _list_field(host_data, 'sub_hosts'):
        host.sub_hosts.append(parse_host(subhost_data))

    return host


def parse_location(location_data: Dict[str, Any]) -> Location:
    """Build a ``Location`` and its hosts from YAML data.

    Required keys are ``name`` and ``vpn_type``; ``connected`` and ``active``
    default to ``False`` and ``vpn_config`` to ``''``.
    """
    hosts = [parse_host(host_data)
             for host_data in _list_field(location_data, 'hosts')]

    return Location(
        name=location_data['name'],
        vpn_type=parse_vpn_type(location_data['vpn_type']),
        connected=location_data.get('connected', False),
        active=location_data.get('active', False),
        vpn_config=location_data.get('vpn_config') or '',
        hosts=hosts,
    )


def parse_customer(yaml_file: Path) -> Customer:
    """Parse one customer YAML file (read as UTF-8) into a ``Customer``.

    Raises:
        OSError: if the file cannot be read.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the document is not a mapping (e.g. an empty file).
        KeyError: if a required key is missing.
    """
    with open(yaml_file, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # An empty file loads as None; a bare scalar/list is equally unusable.
    if not isinstance(data, dict):
        raise ValueError(
            f"{yaml_file}: expected a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )

    # Customer-level service_type is kept as a free-form string (not an enum).
    services = [
        CustomerService(
            name=service_data['name'],
            url=service_data['url'],
            service_type=service_data['service_type'],
            description=service_data.get('description') or '',
        )
        for service_data in _list_field(data, 'services')
    ]

    locations = [parse_location(location_data)
                 for location_data in _list_field(data, 'locations')]

    return Customer(
        name=data['name'],
        services=services,
        locations=locations,
    )


def load_customers() -> List[Customer]:
    """Load every ``*.yaml`` / ``*.yml`` customer file from the config dir.

    Files are processed in sorted path order so the result is stable across
    filesystems. A file that fails to parse is reported and skipped. When the
    directory holds no customer files, demo customers are returned instead.
    """
    config_dir = ensure_config_dir()

    yaml_files = sorted(
        list(config_dir.glob("*.yaml")) + list(config_dir.glob("*.yml"))
    )

    if not yaml_files:
        # No customer files found, initialize with examples
        print(f"No customer files found in {config_dir}")
        print("Run 'python data_loader.py --init' to create example customer files")
        return get_demo_customers()

    customers = []
    for yaml_file in yaml_files:
        try:
            customer = parse_customer(yaml_file)
        except Exception as e:
            # Best-effort boundary: one broken file must not hide the others.
            print(f"Error loading {yaml_file}: {e}")
            continue
        customers.append(customer)
        print(f"Loaded customer: {customer.name} from {yaml_file.name}")

    return customers


def _host_to_dict(host: Host) -> Dict[str, Any]:
    """Serialize a host (recursively with its sub-hosts) to the YAML schema."""
    host_data = {
        'name': host.name,
        'ip_address': host.ip_address,
        'host_type': host.host_type.value,
        'description': host.description,
        'services': [
            {
                'name': service.name,
                'service_type': service.service_type.value,
                'port': service.port,
            }
            for service in host.services
        ],
    }
    # Only emit sub_hosts when present, keeping leaf hosts compact.
    if host.sub_hosts:
        host_data['sub_hosts'] = [_host_to_dict(subhost)
                                  for subhost in host.sub_hosts]
    return host_data


def _location_to_dict(location: Location) -> Dict[str, Any]:
    """Serialize a location and its hosts to the YAML schema."""
    return {
        'name': location.name,
        'vpn_type': location.vpn_type.value,
        'vpn_config': location.vpn_config,
        'active': location.active,
        'connected': location.connected,
        'hosts': [_host_to_dict(host) for host in location.hosts],
    }


def _default_filename(customer_name: str) -> str:
    """Derive a YAML filename from a customer name.

    Lower-cases and turns spaces into underscores
    (``"TechCorp Solutions"`` -> ``"techcorp_solutions.yaml"``). Path
    separators are replaced as well so the file always lands directly inside
    the config directory.
    """
    stem = customer_name.lower().replace(' ', '_')
    for separator in ('/', '\\'):
        stem = stem.replace(separator, '_')
    return stem + '.yaml'


def save_customer(customer: Customer, filename: Optional[str] = None) -> None:
    """Write *customer* to a UTF-8 YAML file in the config directory.

    Args:
        customer: the customer to serialize.
        filename: target file name inside the config directory; derived from
            ``customer.name`` when omitted.
    """
    config_dir = ensure_config_dir()

    if filename is None:
        filename = _default_filename(customer.name)

    filepath = config_dir / filename

    data = {
        'name': customer.name,
        'services': [
            {
                'name': service.name,
                'url': service.url,
                'service_type': service.service_type,
                'description': service.description,
            }
            for service in customer.services
        ],
        'locations': [_location_to_dict(location)
                      for location in customer.locations],
    }

    # sort_keys=False keeps the field order above; allow_unicode writes
    # non-ASCII names as-is instead of \u escapes, since users edit these files.
    with open(filepath, 'w', encoding='utf-8') as f:
        yaml.dump(data, f, default_flow_style=False, sort_keys=False,
                  allow_unicode=True)

    print(f"Saved customer to {filepath}")


def get_demo_customers() -> List[Customer]:
    """Return a single in-memory demo customer, used when no files exist."""
    demo_customer = Customer(name="Demo Customer")
    demo_customer.services = [
        CustomerService(
            name="Demo Portal",
            url="https://demo.example.com",
            service_type="Web Portal",
            description="Demo web portal",
        )
    ]

    demo_location = Location(
        name="Demo Location",
        vpn_type=VPNType.OPENVPN,
        connected=False,
        active=True,
        vpn_config="/etc/openvpn/demo.ovpn",
    )

    demo_host = Host(
        name="DEMO-01",
        ip_address="10.0.0.1",
        host_type=HostType.LINUX,
        description="Demo server",
        services=[
            Service("SSH", ServiceType.SSH, 22),
            Service("Web", ServiceType.WEB_GUI, 80),
        ],
    )

    demo_location.hosts = [demo_host]
    demo_customer.locations = [demo_location]

    return [demo_customer]


def initialize_example_customers() -> None:
    """Create example customer YAML files in the config directory.

    Existing files are never overwritten. ``techcorp_solutions.yaml`` is
    copied from ``example_customer.yaml`` next to this module (if present);
    ``simple_customer.yaml`` is written from a built-in template.
    """
    config_dir = ensure_config_dir()

    techcorp_file = config_dir / "techcorp_solutions.yaml"
    if not techcorp_file.exists():
        example_file = Path(__file__).parent / "example_customer.yaml"
        if example_file.exists():
            techcorp_file.write_text(example_file.read_text(encoding='utf-8'),
                                     encoding='utf-8')
            print(f"Created example: {techcorp_file}")
        else:
            # Previously skipped silently; tell the user why it is missing.
            print(f"Skipped TechCorp example: template not found at {example_file}")

    simple_file = config_dir / "simple_customer.yaml"
    if not simple_file.exists():
        simple_file.write_text(_SIMPLE_CUSTOMER_YAML, encoding='utf-8')
        print(f"Created example: {simple_file}")

    print(f"\nExample customer files created in: {config_dir}")
    print("You can now edit these files or create new ones following the same format.")


# Allow running this file directly to initialize examples
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == "--init":
        initialize_example_customers()
    else:
        # Test loading
        customers = load_customers()
        for customer in customers:
            print(f"\nLoaded: {customer.name}")
            print(f" Services: {len(customer.services)}")
            print(f" Locations: {len(customer.locations)}")
            for location in customer.locations:
                print(f" - {location.name}: {len(location.hosts)} hosts")