492 lines
16 KiB
Python
492 lines
16 KiB
Python
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml

from models import (
    Customer, CustomerService, Location, Host, Service, NetworkSegment, HostIP, PortForwarding,
    ServiceType, HostType, VPNType
)
|
|
|
|
|
|
def get_config_dir() -> Path:
    """Return ``~/.vpntray/customers``, where customer YAML files live.

    The directory is only computed here, not created; ``ensure_config_dir``
    creates it.
    """
    return Path.home().joinpath(".vpntray", "customers")
|
|
|
|
|
|
def ensure_config_dir() -> Path:
    """Return the customer configuration directory, creating it if needed.

    Missing parent directories are created too; an existing directory is
    left as-is.
    """
    customers_dir = get_config_dir()
    customers_dir.mkdir(exist_ok=True, parents=True)
    return customers_dir
|
|
|
|
|
|
def parse_service_type(service_type_str: str) -> ServiceType:
    """Convert a service-type string from YAML into a ``ServiceType``.

    Lookup order:

    1. the display names in the table below (e.g. ``"Web GUI"``);
    2. ``ServiceType``'s own enum values, so the ``service_type.value`` that
       ``save_customer`` writes always loads back as the same member;
    3. ``ServiceType.WEB_GUI`` for anything unrecognized (including ``None``).
    """
    type_mapping = {
        "SSH": ServiceType.SSH,
        "Web GUI": ServiceType.WEB_GUI,
        "RDP": ServiceType.RDP,
        "VNC": ServiceType.VNC,
        "SMB": ServiceType.SMB,
        "Database": ServiceType.DATABASE,
        "FTP": ServiceType.FTP,
    }
    if service_type_str in type_mapping:
        return type_mapping[service_type_str]
    try:
        # Enum call-by-value raises ValueError for unknown values.
        return ServiceType(service_type_str)
    except ValueError:
        return ServiceType.WEB_GUI
|
|
|
|
|
|
def parse_host_type(host_type_str: str) -> HostType:
    """Convert a host-type string from YAML into a ``HostType``.

    Lookup order:

    1. the display names in the table below (e.g. ``"Windows Server"``);
    2. ``HostType``'s own enum values, so the ``host_type.value`` that
       ``save_customer`` writes always loads back as the same member;
    3. ``HostType.LINUX`` for anything unrecognized (including ``None``).
    """
    type_mapping = {
        "Linux": HostType.LINUX,
        "Windows": HostType.WINDOWS,
        "Windows Server": HostType.WINDOWS_SERVER,
        "Proxmox": HostType.PROXMOX,
        "ESXi": HostType.ESXI,
        "Router": HostType.ROUTER,
        "Switch": HostType.SWITCH,
    }
    if host_type_str in type_mapping:
        return type_mapping[host_type_str]
    try:
        # Enum call-by-value raises ValueError for unknown values.
        return HostType(host_type_str)
    except ValueError:
        return HostType.LINUX
|
|
|
|
|
|
def parse_vpn_type(vpn_type_str: str) -> VPNType:
    """Convert a VPN-type string from YAML into a ``VPNType``.

    Lookup order:

    1. the display names in the table below (e.g. ``"WireGuard"``);
    2. ``VPNType``'s own enum values, so the ``vpn_type.value`` that
       ``save_customer`` writes always loads back as the same member;
    3. ``VPNType.OPENVPN`` for anything unrecognized (including ``None``).
    """
    type_mapping = {
        "OpenVPN": VPNType.OPENVPN,
        "WireGuard": VPNType.WIREGUARD,
        "IPSec": VPNType.IPSEC,
    }
    if vpn_type_str in type_mapping:
        return type_mapping[vpn_type_str]
    try:
        # Enum call-by-value raises ValueError for unknown values.
        return VPNType(vpn_type_str)
    except ValueError:
        return VPNType.OPENVPN
|
|
|
|
|
|
def parse_host(host_data: Dict[str, Any]) -> Host:
    """Parse a host, and recursively its sub-hosts (VMs), from YAML data.

    Three IP-address layouts are accepted, newest first:

    * ``ip_addresses`` as a list of mappings
      ``{ip_address: ..., network_segment: ..., is_primary: ...}``;
      ``network_segment`` defaults to ``"LAN"`` and ``is_primary`` to False.
    * ``ip_addresses`` as a list of plain strings (legacy): each goes into
      the ``"LAN"`` segment and only an entry added while the list is still
      empty is marked primary.
    * a single ``ip_address`` string (oldest format): ``"LAN"``, primary.

    Args:
        host_data: Mapping loaded from YAML for one host. ``name`` and
            ``host_type`` are required; each entry under ``services`` needs
            ``name``, ``service_type`` and ``port``.

    Returns:
        The populated ``Host``.

    Raises:
        KeyError: If a required key is missing.
    """
    # A YAML key with no value (e.g. ``services:``) loads as None; ``or []``
    # treats it like an absent key instead of failing on iteration.
    services = [
        Service(
            name=service_data['name'],
            service_type=parse_service_type(service_data['service_type']),
            port=service_data['port'],
        )
        for service_data in host_data.get('services') or []
    ]

    ip_addresses = []
    if 'ip_addresses' in host_data:
        for ip_data in host_data['ip_addresses'] or []:
            if isinstance(ip_data, dict):
                # Current HostIP format.
                ip_addresses.append(HostIP(
                    ip_address=ip_data['ip_address'],
                    network_segment=ip_data.get('network_segment', 'LAN'),
                    is_primary=ip_data.get('is_primary', False),
                ))
            else:
                # Legacy format: plain IP strings; the first one is primary.
                ip_addresses.append(HostIP(
                    ip_address=ip_data,
                    network_segment='LAN',
                    is_primary=not ip_addresses,
                ))
    elif 'ip_address' in host_data:
        # Oldest format: one IP string for the whole host.
        ip_addresses.append(HostIP(
            ip_address=host_data['ip_address'],
            network_segment='LAN',
            is_primary=True,
        ))

    host = Host(
        name=host_data['name'],
        ip_addresses=ip_addresses,
        host_type=parse_host_type(host_data['host_type']),
        icon=host_data.get('icon'),  # Custom icon name
        description=host_data.get('description', ''),
        services=services,
    )

    for subhost_data in host_data.get('sub_hosts') or []:
        host.sub_hosts.append(parse_host(subhost_data))

    return host
|
|
|
|
|
|
def parse_location(location_data: Dict[str, Any]) -> Location:
    """Parse a location (network segments, port forwards, hosts) from YAML.

    Args:
        location_data: Mapping loaded from YAML for one location. ``name``
            and ``vpn_type`` are required. Each network segment needs
            ``name`` and ``cidr``; each port forwarding needs
            ``external_port``, ``internal_ip`` and ``internal_port``.

    Returns:
        A ``Location`` with ``connected`` and ``active`` set to False. Those
        are runtime state and are never read from the file.

    Raises:
        KeyError: If a required key is missing.
    """
    # A YAML key with no value (e.g. ``hosts:``) loads as None; ``or []``
    # treats it like an absent key instead of failing on iteration.
    network_segments = [
        NetworkSegment(
            name=segment_data['name'],
            cidr=segment_data['cidr'],
            vlan_id=segment_data.get('vlan_id'),
            zone=segment_data.get('zone', 'general'),
            gateway=segment_data.get('gateway'),
            description=segment_data.get('description', ''),
        )
        for segment_data in location_data.get('network_segments') or []
    ]

    port_forwardings = [
        PortForwarding(
            external_port=pf_data['external_port'],
            internal_ip=pf_data['internal_ip'],
            internal_port=pf_data['internal_port'],
            protocol=pf_data.get('protocol', 'tcp'),
            description=pf_data.get('description', ''),
            enabled=pf_data.get('enabled', True),
        )
        for pf_data in location_data.get('port_forwardings') or []
    ]

    hosts = [parse_host(host_data) for host_data in location_data.get('hosts') or []]

    return Location(
        name=location_data['name'],
        vpn_type=parse_vpn_type(location_data['vpn_type']),
        connected=False,  # Runtime state - always starts disconnected
        active=False,  # Runtime state - always starts inactive
        vpn_config=location_data.get('vpn_config') or '',
        hosts=hosts,
        network_segments=network_segments,
        networks=location_data.get('networks') or [],  # Legacy support
        external_addresses=location_data.get('external_addresses') or [],
        port_forwardings=port_forwardings,
        vpn_credentials=location_data.get('vpn_credentials'),
        nmcli_connection_name=location_data.get('nmcli_connection_name'),
        auto_import=location_data.get('auto_import', True),
    )
|
|
|
|
|
|
def parse_customer(yaml_file: Path) -> Customer:
    """Parse one customer definition from a YAML file.

    Args:
        yaml_file: Path to the customer's YAML file, read as UTF-8.

    Returns:
        The ``Customer`` with its services and locations.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the file is not UTF-8, or its top-level document is
            not a mapping (an empty file, for example).
        KeyError: If a required key such as ``name`` is missing.
    """
    # Explicit UTF-8 so parsing does not depend on the user's locale.
    with open(yaml_file, 'r', encoding='utf-8') as f:
        # safe_load builds only plain Python types from the file.
        data = yaml.safe_load(f)

    # An empty file loads as None and a bare list/scalar has no keys; report
    # that clearly instead of failing later with an opaque TypeError.
    if not isinstance(data, dict):
        raise ValueError(
            f"expected a mapping at the top level, got {type(data).__name__}"
        )

    # A YAML key with no value loads as None; treat it as an empty section.
    services = [
        CustomerService(
            name=service_data['name'],
            url=service_data['url'],
            service_type=service_data['service_type'],
            description=service_data.get('description', ''),
        )
        for service_data in data.get('services') or []
    ]

    locations = [
        parse_location(location_data)
        for location_data in data.get('locations') or []
    ]

    return Customer(
        name=data['name'],
        services=services,
        locations=locations,
    )
|
|
|
|
|
|
def load_customers() -> List[Customer]:
    """Load every customer defined in the VPNTray config directory.

    All ``*.yaml`` and ``*.yml`` files in the directory returned by
    ``ensure_config_dir`` are parsed in filename order, so the result is
    stable across runs. A file that fails to parse is reported on stdout and
    skipped; the remaining files are still loaded.

    Returns:
        The parsed customers, or the demo customers from
        ``get_demo_customers`` if the directory contains no YAML files.
    """
    config_dir = ensure_config_dir()

    # glob() yields entries in arbitrary filesystem order; sort them for a
    # stable customer order, and skip directories that match the pattern.
    yaml_files = sorted(
        path
        for pattern in ("*.yaml", "*.yml")
        for path in config_dir.glob(pattern)
        if path.is_file()
    )

    if not yaml_files:
        # No customer files found, initialize with examples
        print(f"No customer files found in {config_dir}")
        print("Run 'python data_loader.py --init' to create example customer files")
        return get_demo_customers()

    customers = []
    for yaml_file in yaml_files:
        try:
            customer = parse_customer(yaml_file)
        except Exception as e:  # one broken file must not block the others
            print(f"Error loading {yaml_file}: {e}")
            continue
        customers.append(customer)
        print(f"Loaded customer: {customer.name} from {yaml_file.name}")

    return customers
|
|
|
|
|
|
def _host_to_dict(host: "Host") -> Dict[str, Any]:
    """Convert a host and, recursively, its sub-hosts into YAML-ready dicts."""
    host_data: Dict[str, Any] = {
        'name': host.name,
        'ip_addresses': [
            {
                'ip_address': host_ip.ip_address,
                'network_segment': host_ip.network_segment,
                'is_primary': host_ip.is_primary,
            }
            for host_ip in host.ip_addresses
        ],
        'host_type': host.host_type.value,
        'description': host.description,
        'services': [
            {
                'name': service.name,
                'service_type': service.service_type.value,
                'port': service.port,
            }
            for service in host.services
        ],
    }

    # Optional keys are written only when set, keeping the files minimal.
    if host.icon:
        host_data['icon'] = host.icon
    if host.sub_hosts:
        host_data['sub_hosts'] = [_host_to_dict(sub) for sub in host.sub_hosts]

    return host_data


def _location_to_dict(location: "Location") -> Dict[str, Any]:
    """Convert a location into a dict with the keys ``parse_location`` reads."""
    network_segments = []
    for segment in location.network_segments:
        segment_data: Dict[str, Any] = {
            'name': segment.name,
            'cidr': segment.cidr,
            'zone': segment.zone,
            'description': segment.description,
        }
        if segment.vlan_id is not None:
            segment_data['vlan_id'] = segment.vlan_id
        if segment.gateway is not None:
            segment_data['gateway'] = segment.gateway
        network_segments.append(segment_data)

    port_forwardings = []
    for pf in location.port_forwardings:
        pf_data: Dict[str, Any] = {
            'external_port': pf.external_port,
            'internal_ip': pf.internal_ip,
            'internal_port': pf.internal_port,
            'protocol': pf.protocol,
            'enabled': pf.enabled,
        }
        if pf.description:
            pf_data['description'] = pf.description
        port_forwardings.append(pf_data)

    location_data: Dict[str, Any] = {
        'name': location.name,
        'vpn_type': location.vpn_type.value,
        'vpn_config': location.vpn_config,
        'network_segments': network_segments,
        'external_addresses': location.external_addresses,
        'port_forwardings': port_forwardings,
        'hosts': [_host_to_dict(host) for host in location.hosts],
    }

    # Add legacy networks if they exist
    if location.networks:
        location_data['networks'] = location.networks

    # parse_location also reads these keys; without them a load/save
    # round-trip silently dropped them. getattr tolerates a Location object
    # that lacks one of these attributes.
    vpn_credentials = getattr(location, 'vpn_credentials', None)
    if vpn_credentials is not None:
        location_data['vpn_credentials'] = vpn_credentials
    nmcli_connection_name = getattr(location, 'nmcli_connection_name', None)
    if nmcli_connection_name is not None:
        location_data['nmcli_connection_name'] = nmcli_connection_name
    auto_import = getattr(location, 'auto_import', True)
    if auto_import is not True:  # True is parse_location's default
        location_data['auto_import'] = auto_import

    return location_data


def save_customer(customer: Customer, filename: Optional[str] = None) -> None:
    """Write a customer to a YAML file in the VPNTray config directory.

    The output uses the keys that ``parse_customer`` reads, so a saved file
    can be loaded back.

    Args:
        customer: The customer to save.
        filename: File name inside the config directory. Defaults to the
            customer name lower-cased, with spaces and path separators
            replaced by ``_``, plus ``.yaml``.

    Raises:
        yaml.representer.RepresenterError: If the data contains a value that
            ``yaml.safe_load`` could not read back; the file is not touched.
        OSError: If the file cannot be written.
    """
    config_dir = ensure_config_dir()

    if filename is None:
        # Separators in the name would otherwise point into a (missing)
        # subdirectory or outside the config directory.
        filename = (
            customer.name.lower()
            .replace(' ', '_')
            .replace('/', '_')
            .replace('\\', '_')
            + '.yaml'
        )

    filepath = config_dir / filename

    data = {
        'name': customer.name,
        'services': [
            {
                'name': service.name,
                'url': service.url,
                'service_type': service.service_type,
                'description': service.description,
            }
            for service in customer.services
        ],
        'locations': [_location_to_dict(location) for location in customer.locations],
    }

    # Serialize before opening the file, so a serialization error cannot
    # leave a truncated config behind. safe_dump mirrors the safe_load used
    # for reading: it refuses values that could not be loaded back.
    text = yaml.safe_dump(data, default_flow_style=False, sort_keys=False)
    filepath.write_text(text, encoding='utf-8')

    print(f"Saved customer to {filepath}")
|
|
|
|
|
|
def get_demo_customers() -> List[Customer]:
    """Build the built-in demo data used when no customer files exist.

    Returns:
        A one-element list holding "Demo Customer": one web-portal service
        and one OpenVPN location ("Demo Location", config ``demo.ovpn``)
        with a single Linux host, DEMO-01, in the 10.0.0.0/24 LAN segment.
    """
    portal = CustomerService(
        name="Demo Portal",
        url="https://demo.example.com",
        service_type="Web Portal",
        description="Demo web portal",
    )

    lan = NetworkSegment(
        name="LAN",
        cidr="10.0.0.0/24",
        gateway="10.0.0.1",
        zone="production",
        description="Demo network",
    )

    primary_ip = HostIP(ip_address="10.0.0.1", network_segment="LAN", is_primary=True)
    server = Host(
        name="DEMO-01",
        ip_addresses=[primary_ip],
        host_type=HostType.LINUX,
        description="Demo server",
        services=[
            Service("SSH", ServiceType.SSH, 22),
            Service("Web", ServiceType.WEB_GUI, 80),
        ],
    )

    site = Location(
        name="Demo Location",
        vpn_type=VPNType.OPENVPN,
        connected=False,
        active=True,
        vpn_config="demo.ovpn",  # File in ~/.vpntray/vpn/
    )
    site.hosts = [server]
    site.network_segments = [lan]

    demo = Customer(name="Demo Customer")
    demo.services = [portal]
    demo.locations = [site]

    return [demo]
|
|
|
|
|
|
def initialize_example_customers(config_dir: Optional[Path] = None) -> None:
    """Create example customer YAML files without overwriting existing ones.

    Two files are written into *config_dir* if they do not exist yet:

    * ``techcorp_solutions.yaml``: a byte-for-byte copy of
      ``example_customer.yaml`` from this module's directory, if that
      template exists (otherwise a note is printed);
    * ``simple_customer.yaml``: a small built-in example.

    Args:
        config_dir: Target directory, created if missing. Defaults to the
            directory returned by ``ensure_config_dir``.
    """
    if config_dir is None:
        config_dir = ensure_config_dir()
    else:
        config_dir = Path(config_dir)
        config_dir.mkdir(parents=True, exist_ok=True)

    # Create TechCorp example
    techcorp_file = config_dir / "techcorp_solutions.yaml"
    if not techcorp_file.exists():
        example_file = Path(__file__).parent / "example_customer.yaml"
        if example_file.exists():
            # Copy raw bytes so the template's encoding is preserved exactly.
            techcorp_file.write_bytes(example_file.read_bytes())
            print(f"Created example: {techcorp_file}")
        else:
            print(f"Skipped TechCorp example: template not found at {example_file}")

    # Create a simpler example
    simple_file = config_dir / "simple_customer.yaml"
    if not simple_file.exists():
        # Indentation inside this literal is YAML structure; keep it intact.
        simple_yaml = """name: Simple Customer

services:
  - name: Company Website
    url: https://simple.example.com
    service_type: Web Portal
    description: Main company website

locations:
  - name: Main Office
    vpn_type: WireGuard
    vpn_config: simple.conf  # File in ~/.vpntray/vpn/

    network_segments:
      - name: LAN
        cidr: 192.168.1.0/24
        gateway: 192.168.1.1
        zone: production
        description: Main office network

    external_addresses:
      - simple.vpn.example.com

    hosts:
      - name: SERVER-01
        ip_addresses:
          - ip_address: 192.168.1.10
            network_segment: LAN
            is_primary: true
        host_type: Linux
        description: Main server
        services:
          - name: SSH
            service_type: SSH
            port: 22
          - name: Web Interface
            service_type: Web GUI
            port: 443
"""
        simple_file.write_text(simple_yaml, encoding='utf-8')
        print(f"Created example: {simple_file}")

    print(f"\nExample customer files created in: {config_dir}")
    print("You can now edit these files or create new ones following the same format.")
|
|
|
|
|
|
# Allow running this file directly to initialize examples
if __name__ == "__main__":
    import sys

    # `--init` as the first argument writes the example files; anything else
    # does a smoke-test load and prints a short summary per customer.
    if sys.argv[1:2] == ["--init"]:
        initialize_example_customers()
    else:
        for customer in load_customers():
            print(f"\nLoaded: {customer.name}")
            print(f" Services: {len(customer.services)}")
            print(f" Locations: {len(customer.locations)}")
            for location in customer.locations:
                print(f" - {location.name}: {len(location.hosts)} hosts")
|